diff --git a/go/adk/pkg/agent/agent.go b/go/adk/pkg/agent/agent.go index fa9d633d14..beaa0beb02 100644 --- a/go/adk/pkg/agent/agent.go +++ b/go/adk/pkg/agent/agent.go @@ -13,6 +13,7 @@ import ( "github.com/kagent-dev/kagent/go/adk/pkg/sts" "github.com/kagent-dev/kagent/go/adk/pkg/tools" "github.com/kagent-dev/kagent/go/api/adk" + "github.com/kagent-dev/kagent/go/core/pkg/env" "google.golang.org/adk/agent" "google.golang.org/adk/agent/llmagent" adkmodel "google.golang.org/adk/model" @@ -50,12 +51,16 @@ func CreateGoogleADKAgentWithSubagentSessionIDs(ctx context.Context, agentConfig return nil, nil, fmt.Errorf("agent config is required") } - propagateToken := strings.ToLower(os.Getenv("KAGENT_PROPAGATE_TOKEN")) == "true" + propagateToken := env.KagentPropagateToken.Get() + tokenPrecedence := mcp.StaticTokenWins + if env.KagentPropagateTokenOverridesStatic.Get() { + tokenPrecedence = mcp.ForwardedTokenWins + } var dynamicHeaderProvider mcp.DynamicHeaderProvider if stsPlugin != nil { dynamicHeaderProvider = stsPlugin.HeaderProvider } - toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken, dynamicHeaderProvider) + toolsets := mcp.CreateToolsets(ctx, agentConfig.HttpTools, agentConfig.SseTools, propagateToken, tokenPrecedence, dynamicHeaderProvider) subagentSessionIDs := make(map[string]string) var remoteAgentTools []tool.Tool diff --git a/go/adk/pkg/constants/const.go b/go/adk/pkg/constants/const.go index 2926e96e4f..f4dca27214 100644 --- a/go/adk/pkg/constants/const.go +++ b/go/adk/pkg/constants/const.go @@ -4,4 +4,10 @@ const ( // A2A call context's NewRequestMeta normalizes header names to lowercase. // This is why we use "authorization" instead of "Authorization". AuthorizationHeader = "authorization" + + // ActorTokenHeader carries the agent's own workload token alongside a + // forwarded end-user Authorization, so a downstream gateway can run an + // RFC 8693 delegation (subject=user, actor=agent). It is set on the + // outgoing request, so it uses the canonical header form. + ActorTokenHeader = "X-Actor-Token" ) diff --git a/go/adk/pkg/mcp/registry.go b/go/adk/pkg/mcp/registry.go index 97cf20ea41..270293d3e0 100644 --- a/go/adk/pkg/mcp/registry.go +++ b/go/adk/pkg/mcp/registry.go @@ -7,6 +7,7 @@ import ( "fmt" "net/http" "os" + "strings" "time" "github.com/a2aproject/a2a-go/a2asrv" @@ -23,6 +24,22 @@ import ( // This is used for dynamic token injection (e.g., STS tokens) per session. type DynamicHeaderProvider func(ctx context.Context) map[string]string +// TokenPrecedence selects how a static Authorization configured on an MCP +// server relates to a forwarded or STS-exchanged Authorization. +type TokenPrecedence int + +const ( + // StaticTokenWins keeps a static Authorization at the highest precedence: it + // overrides any forwarded or STS-exchanged Authorization. This is the default. + StaticTokenWins TokenPrecedence = iota + + // ForwardedTokenWins lets a forwarded or STS-exchanged Authorization win over + // a static Authorization. The displaced static token is sent as the actor + // token (X-Actor-Token) so a downstream gateway can run an RFC 8693 + // delegation with subject=end user and actor=agent. + ForwardedTokenWins +) + const ( // Default timeout matching Python KAGENT_REMOTE_AGENT_TIMEOUT defaultTimeout = 30 * time.Minute @@ -69,6 +86,7 @@ type mcpServerParams struct { Headers map[string]string AllowedHeaders []string // header names to forward from incoming request PropagateToken bool // when true, Authorization is forwarded independently of AllowedHeaders + TokenPrecedence TokenPrecedence // how a static Authorization relates to a forwarded/STS Authorization HeaderProvider DynamicHeaderProvider // optional per-request headers derived from invocation context (e.g., STS exchanged access tokens) ServerType string // "http" or "sse" Timeout *float64 @@ -86,6 +104,9 @@ type mcpServerParams struct { // independently of AllowedHeaders, mirroring the Python ADKTokenPropagationPlugin // behaviour triggered by KAGENT_PROPAGATE_TOKEN. // +// tokenPrecedence is a runtime-global policy (KAGENT_PROPAGATE_TOKEN_OVERRIDES_STATIC) +// applied uniformly to every server here; see TokenPrecedence and applyStaticHeaders. +// // Optional headerProvider can be used to inject per-request headers // derived from invocation context (e.g., STS exchanged access tokens). func CreateToolsets( @@ -93,6 +114,7 @@ func CreateToolsets( httpTools []adk.HttpMcpServerConfig, sseTools []adk.SseMcpServerConfig, propagateToken bool, + tokenPrecedence TokenPrecedence, headerProvider DynamicHeaderProvider, ) []tool.Toolset { log := logr.FromContextOrDiscard(ctx) @@ -105,6 +127,7 @@ func CreateToolsets( Headers: httpTool.Params.Headers, AllowedHeaders: httpTool.AllowedHeaders, PropagateToken: propagateToken, + TokenPrecedence: tokenPrecedence, HeaderProvider: headerProvider, ServerType: "http", Timeout: httpTool.Params.Timeout, @@ -127,6 +150,7 @@ func CreateToolsets( Headers: sseTool.Params.Headers, AllowedHeaders: sseTool.AllowedHeaders, PropagateToken: propagateToken, + TokenPrecedence: tokenPrecedence, HeaderProvider: headerProvider, ServerType: "sse", Timeout: sseTool.Params.Timeout, @@ -224,14 +248,20 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp baseTransport.TLSClientConfig = tlsConfig } + if params.TokenPrecedence == ForwardedTokenWins && + params.TLSInsecureSkipVerify != nil && *params.TLSInsecureSkipVerify { + log.Info("WARNING: ForwardedTokenWins sends the static M2M credential as X-Actor-Token, but TLS verification is disabled for this MCP server - the actor token can leak to an unverified endpoint", "url", params.URL) + } + var httpTransport http.RoundTripper = baseTransport if len(params.Headers) > 0 || len(params.AllowedHeaders) > 0 || params.PropagateToken || params.HeaderProvider != nil { httpTransport = &headerRoundTripper{ - base: baseTransport, - headers: params.Headers, - allowedHeaders: params.AllowedHeaders, - propagateToken: params.PropagateToken, - headerProvider: params.HeaderProvider, + base: baseTransport, + headers: params.Headers, + allowedHeaders: params.AllowedHeaders, + propagateToken: params.PropagateToken, + tokenPrecedence: params.TokenPrecedence, + headerProvider: params.HeaderProvider, } } @@ -257,20 +287,22 @@ func createTransport(ctx context.Context, params mcpServerParams) (mcpsdk.Transp } // headerRoundTripper wraps an http.RoundTripper to add custom headers to all -// requests. It supports four sources of headers, applied in this order so that -// higher-priority sources win on collision: +// requests. Header sources are applied lowest to highest precedence: // 1. propagateToken: when true, Authorization is read from the incoming A2A // CallContext and forwarded unconditionally (independent of allowedHeaders). // 2. allowedHeaders: explicit per-header forwarding from the A2A CallContext. // 3. headerProvider: runtime headers derived from ADK context, such as STS tokens. -// 4. headers: static key/value pairs configured on the MCP server spec (highest -// priority — always wins). +// 4. headers: static key/value pairs configured on the MCP server spec. +// +// Static headers (4) have the highest precedence; the one exception is the +// Authorization header under ForwardedTokenWins, resolved in applyStaticHeaders. type headerRoundTripper struct { - base http.RoundTripper - headers map[string]string - allowedHeaders []string // header names (case-insensitive) to forward from A2A context - propagateToken bool // when true, Authorization is forwarded independently - headerProvider DynamicHeaderProvider + base http.RoundTripper + headers map[string]string + allowedHeaders []string // header names (case-insensitive) to forward from A2A context + propagateToken bool // when true, Authorization is forwarded independently + tokenPrecedence TokenPrecedence // resolves static vs forwarded Authorization + headerProvider DynamicHeaderProvider } func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { @@ -300,12 +332,49 @@ func (rt *headerRoundTripper) RoundTrip(req *http.Request) (*http.Response, erro } } - // Apply static headers last — they take precedence over all dynamic sources. + rt.applyStaticHeaders(req) + + return rt.base.RoundTrip(req) +} + +// applyStaticHeaders writes the static headers configured on the MCP server spec +// onto req. Non-Authorization headers always overwrite forwarded values. The +// Authorization header honours tokenPrecedence: StaticTokenWins overwrites any +// forwarded token, while ForwardedTokenWins keeps a forwarded/STS token and, when +// it differs from the static one, carries the displaced static token as the actor +// (X-Actor-Token) for a downstream RFC 8693 delegation. With no forwarded token +// the static Authorization is applied and no actor is added; a forwarded token +// equal to the static one is treated as M2M (no actor); an actor token already +// forwarded via allowedHeaders is left untouched. +func (rt *headerRoundTripper) applyStaticHeaders(req *http.Request) { + // headers is assumed to hold at most one Authorization key; with case-variant + // duplicates map iteration order decides which wins. + staticAuthorization := "" for key, value := range rt.headers { + if strings.EqualFold(key, constants.AuthorizationHeader) { + staticAuthorization = value + continue + } req.Header.Set(key, value) } - return rt.base.RoundTrip(req) + if staticAuthorization == "" { + return + } + + if rt.tokenPrecedence == StaticTokenWins { + req.Header.Set(constants.AuthorizationHeader, staticAuthorization) + return + } + + forwardedAuthorization := req.Header.Get(constants.AuthorizationHeader) + if forwardedAuthorization == "" { + req.Header.Set(constants.AuthorizationHeader, staticAuthorization) + return + } + if forwardedAuthorization != staticAuthorization && req.Header.Get(constants.ActorTokenHeader) == "" { + req.Header.Set(constants.ActorTokenHeader, staticAuthorization) + } } // initializeToolSet fetches tools from an MCP server using Google ADK's mcptoolset. diff --git a/go/adk/pkg/mcp/registry_test.go b/go/adk/pkg/mcp/registry_test.go index 2931a94bd3..c3ac9cee14 100644 --- a/go/adk/pkg/mcp/registry_test.go +++ b/go/adk/pkg/mcp/registry_test.go @@ -396,3 +396,295 @@ func TestStaticHeaders_OverrideDynamic(t *testing.T) { t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer static") } } + +// TestOverrideStatic_PropagatedTokenWinsAsSubject verifies that with +// ForwardedTokenWins, a token forwarded via propagateToken +// beats the static Authorization (becoming the OBO subject), the displaced +// static token is carried as the actor (X-Actor-Token), and other static +// headers still apply. +func TestOverrideStatic_PropagatedTokenWinsAsSubject(t *testing.T) { + t.Parallel() + var capturedAuth, capturedActor, capturedStatic string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + capturedActor = r.Header.Get("X-Actor-Token") + capturedStatic = r.Header.Get("X-Static") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ctx := a2aCtx(map[string][]string{ + "Authorization": {"Bearer user-dex"}, + }) + + rt := &headerRoundTripper{ + base: http.DefaultTransport, + propagateToken: true, + tokenPrecedence: ForwardedTokenWins, + headers: map[string]string{ + "Authorization": "Bearer m2m-sa", + "X-Static": "keep-me", + }, + } + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + resp.Body.Close() + + if capturedAuth != "Bearer user-dex" { + t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer user-dex") + } + if capturedActor != "Bearer m2m-sa" { + t.Errorf("X-Actor-Token: got %q, want %q", capturedActor, "Bearer m2m-sa") + } + if capturedStatic != "keep-me" { + t.Errorf("X-Static: got %q, want %q", capturedStatic, "keep-me") + } +} + +// TestOverrideStatic_DisplacedStaticBecomesActor verifies the displacement also +// holds when the winning Authorization comes from the STS headerProvider rather +// than a propagated A2A token. +func TestOverrideStatic_DisplacedStaticBecomesActor(t *testing.T) { + t.Parallel() + var capturedAuth, capturedActor string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + capturedActor = r.Header.Get("X-Actor-Token") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + rt := &headerRoundTripper{ + base: http.DefaultTransport, + tokenPrecedence: ForwardedTokenWins, + headers: map[string]string{"Authorization": "Bearer m2m-sa"}, + headerProvider: func(context.Context) map[string]string { + return map[string]string{"Authorization": "Bearer sts-exchanged"} + }, + } + + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + resp.Body.Close() + + if capturedAuth != "Bearer sts-exchanged" { + t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer sts-exchanged") + } + if capturedActor != "Bearer m2m-sa" { + t.Errorf("X-Actor-Token: got %q, want %q", capturedActor, "Bearer m2m-sa") + } +} + +// TestOverrideStatic_NoForwardedToken_StaticStaysNoActor verifies that with no +// forwarded or STS token the static Authorization is preserved and no actor +// token is added, so autonomous runs stay pure M2M. +func TestOverrideStatic_NoForwardedToken_StaticStaysNoActor(t *testing.T) { + t.Parallel() + var capturedAuth, capturedActor string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + capturedActor = r.Header.Get("X-Actor-Token") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + // propagateToken is on but the incoming request carries no Authorization. + ctx := a2aCtx(map[string][]string{ + "X-Trace-Id": {"abc"}, + }) + + rt := &headerRoundTripper{ + base: http.DefaultTransport, + propagateToken: true, + tokenPrecedence: ForwardedTokenWins, + headers: map[string]string{"Authorization": "Bearer m2m-sa"}, + } + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + resp.Body.Close() + + if capturedAuth != "Bearer m2m-sa" { + t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer m2m-sa") + } + if capturedActor != "" { + t.Errorf("X-Actor-Token: got %q, want empty", capturedActor) + } +} + +// TestOverrideStatic_NonAuthStaticHeaderWins verifies that in ForwardedTokenWins +// mode the override is scoped to Authorization: a non-Authorization static header +// still overrides a forwarded header of the same name. +func TestOverrideStatic_NonAuthStaticHeaderWins(t *testing.T) { + t.Parallel() + var capturedTenant string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedTenant = r.Header.Get("X-Tenant") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + // X-Tenant is both forwarded via allowedHeaders and configured statically. + ctx := a2aCtx(map[string][]string{ + "X-Tenant": {"forwarded-tenant"}, + }) + + rt := &headerRoundTripper{ + base: http.DefaultTransport, + allowedHeaders: []string{"X-Tenant"}, + tokenPrecedence: ForwardedTokenWins, + headers: map[string]string{"X-Tenant": "static-tenant"}, + } + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + resp.Body.Close() + + if capturedTenant != "static-tenant" { + t.Errorf("X-Tenant: got %q, want %q", capturedTenant, "static-tenant") + } +} + +// TestOverrideStatic_ForwardedEqualsStatic_NoActor verifies that when the +// forwarded Authorization equals the static one the call is treated as M2M: no +// actor token is added. +func TestOverrideStatic_ForwardedEqualsStatic_NoActor(t *testing.T) { + t.Parallel() + var capturedAuth, capturedActor string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + capturedActor = r.Header.Get("X-Actor-Token") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ctx := a2aCtx(map[string][]string{ + "Authorization": {"Bearer m2m-sa"}, + }) + + rt := &headerRoundTripper{ + base: http.DefaultTransport, + propagateToken: true, + tokenPrecedence: ForwardedTokenWins, + headers: map[string]string{"Authorization": "Bearer m2m-sa"}, + } + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + resp.Body.Close() + + if capturedAuth != "Bearer m2m-sa" { + t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer m2m-sa") + } + if capturedActor != "" { + t.Errorf("X-Actor-Token: got %q, want empty", capturedActor) + } +} + +// TestOverrideStatic_PreexistingActorTokenPreserved verifies that an actor token +// already forwarded via allowedHeaders is not overwritten by the displaced +// static token. +func TestOverrideStatic_PreexistingActorTokenPreserved(t *testing.T) { + t.Parallel() + var capturedAuth, capturedActor string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + capturedActor = r.Header.Get("X-Actor-Token") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ctx := a2aCtx(map[string][]string{ + "Authorization": {"Bearer user-dex"}, + "X-Actor-Token": {"Bearer forwarded-actor"}, + }) + + rt := &headerRoundTripper{ + base: http.DefaultTransport, + propagateToken: true, + allowedHeaders: []string{"X-Actor-Token"}, + tokenPrecedence: ForwardedTokenWins, + headers: map[string]string{"Authorization": "Bearer m2m-sa"}, + } + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + resp.Body.Close() + + if capturedAuth != "Bearer user-dex" { + t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer user-dex") + } + if capturedActor != "Bearer forwarded-actor" { + t.Errorf("X-Actor-Token: got %q, want %q", capturedActor, "Bearer forwarded-actor") + } +} + +// TestOverrideStatic_NoStaticAuthorization_ForwardedPassesThrough verifies that +// in ForwardedTokenWins mode with no static Authorization configured (only a +// non-Authorization static header), the forwarded Authorization passes through +// untouched and no actor token is synthesized. +func TestOverrideStatic_NoStaticAuthorization_ForwardedPassesThrough(t *testing.T) { + t.Parallel() + var capturedAuth, capturedActor, capturedStatic string + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + capturedAuth = r.Header.Get("Authorization") + capturedActor = r.Header.Get("X-Actor-Token") + capturedStatic = r.Header.Get("X-Static") + w.WriteHeader(http.StatusOK) + })) + defer srv.Close() + + ctx := a2aCtx(map[string][]string{ + "Authorization": {"Bearer user-dex"}, + }) + + rt := &headerRoundTripper{ + base: http.DefaultTransport, + propagateToken: true, + tokenPrecedence: ForwardedTokenWins, + headers: map[string]string{"X-Static": "keep-me"}, + } + + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, srv.URL, nil) + resp, err := rt.RoundTrip(req) + if err != nil { + t.Fatalf("RoundTrip failed: %v", err) + } + resp.Body.Close() + + if capturedAuth != "Bearer user-dex" { + t.Errorf("Authorization: got %q, want %q", capturedAuth, "Bearer user-dex") + } + if capturedActor != "" { + t.Errorf("X-Actor-Token: got %q, want empty", capturedActor) + } + if capturedStatic != "keep-me" { + t.Errorf("X-Static: got %q, want %q", capturedStatic, "keep-me") + } +} diff --git a/go/adk/pkg/runner/adapter.go b/go/adk/pkg/runner/adapter.go index 0441f778c0..c7d3baac25 100644 --- a/go/adk/pkg/runner/adapter.go +++ b/go/adk/pkg/runner/adapter.go @@ -3,7 +3,6 @@ package runner import ( "context" "fmt" - "os" "strings" "github.com/go-logr/logr" @@ -12,6 +11,7 @@ import ( "github.com/kagent-dev/kagent/go/adk/pkg/session" "github.com/kagent-dev/kagent/go/adk/pkg/sts" "github.com/kagent-dev/kagent/go/api/adk" + "github.com/kagent-dev/kagent/go/core/pkg/env" adkmemory "google.golang.org/adk/memory" adkplugin "google.golang.org/adk/plugin" "google.golang.org/adk/runner" @@ -97,8 +97,8 @@ func CreateRunnerConfig( } func buildTokenPropagationPlugin(ctx context.Context, log logr.Logger) (*sts.TokenPropagationPlugin, error) { - propagateToken := strings.EqualFold(strings.TrimSpace(os.Getenv("KAGENT_PROPAGATE_TOKEN")), "true") - stsWellKnownURI := strings.TrimSpace(os.Getenv("STS_WELL_KNOWN_URI")) + propagateToken := env.KagentPropagateToken.Get() + stsWellKnownURI := strings.TrimSpace(env.StsWellKnownURI.Get()) if !propagateToken && stsWellKnownURI == "" { return nil, nil } diff --git a/go/core/pkg/env/kagent.go b/go/core/pkg/env/kagent.go index 5d158b2060..56fe5cdfc5 100644 --- a/go/core/pkg/env/kagent.go +++ b/go/core/pkg/env/kagent.go @@ -63,10 +63,17 @@ var ( ComponentAgentRuntime, ) - KagentPropagateToken = RegisterStringVar( + KagentPropagateToken = RegisterBoolVar( "KAGENT_PROPAGATE_TOKEN", - "", - "When set, propagates the authentication token to downstream services.", + false, + "When true, the incoming Authorization token is propagated to downstream MCP servers and A2A agents.", + ComponentAgentRuntime, + ) + + KagentPropagateTokenOverridesStatic = RegisterBoolVar( + "KAGENT_PROPAGATE_TOKEN_OVERRIDES_STATIC", + false, + "When true, a forwarded or STS-exchanged Authorization takes precedence over a static Authorization on an MCP server, and the displaced static token is sent as the X-Actor-Token actor token for downstream RFC 8693 delegation.", ComponentAgentRuntime, )