diff --git a/config/headers_test.go b/config/headers_test.go index c807fbc3..daac9808 100644 --- a/config/headers_test.go +++ b/config/headers_test.go @@ -17,8 +17,15 @@ package config import ( + "fmt" + "io" + "net" "net/http" + "net/http/httptest" + "strings" "testing" + + "github.com/stretchr/testify/require" ) func TestReservedHeaders(t *testing.T) { @@ -29,3 +36,111 @@ func TestReservedHeaders(t *testing.T) { } } } + +func TestHeadersRoundTripperSameHost(t *testing.T) { + // All headers, including sensitive ones, must be forwarded on same-host requests. + for _, header := range []string{"Cookie", "X-Custom-Header"} { + t.Run(header, func(t *testing.T) { + received := "" + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + received = r.Header.Get(header) + fmt.Fprint(w, "ok") + })) + t.Cleanup(server.Close) + + headers := &Headers{ + Headers: map[string]Header{ + header: {Values: []string{"testvalue"}}, + }, + } + rt := NewHeadersRoundTripper(headers, http.DefaultTransport) + + req, err := http.NewRequest(http.MethodGet, server.URL, nil) + require.NoError(t, err) + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, "ok", strings.TrimSpace(string(body))) + require.Equalf(t, "testvalue", received, "header %q must be forwarded on same-host request", header) + }) + } +} + +func TestHeadersRoundTripperCrossHostRedirect(t *testing.T) { + // Cookie must be set on the initial request but stripped on cross-host redirects. + cookieOnRedirect := "" + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookieOnRedirect = r.Header.Get("Cookie") + fmt.Fprint(w, "ok") + })) + t.Cleanup(target.Close) + + // Use "localhost" as the redirect target hostname so that it differs from + // "127.0.0.1" used by the origin server, making it a cross-host redirect. + targetPort := target.Listener.Addr().(*net.TCPAddr).Port + targetURL := fmt.Sprintf("http://localhost:%d", targetPort) + + cookieOnOrigin := "" + origin := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cookieOnOrigin = r.Header.Get("Cookie") + http.Redirect(w, r, targetURL, http.StatusFound) + })) + t.Cleanup(origin.Close) + + cfg := HTTPClientConfig{ + FollowRedirects: true, + HTTPHeaders: &Headers{ + Headers: map[string]Header{ + "Cookie": {Values: []string{"session=abc"}}, + }, + }, + } + client, err := NewClientFromConfig(cfg, "test") + require.NoError(t, err) + + resp, err := client.Get(origin.URL) + require.NoError(t, err) + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + require.Equalf(t, "session=abc", cookieOnOrigin, "Cookie must be set on the initial request.") + require.Emptyf(t, cookieOnRedirect, "Cookie must not be forwarded on a cross-host redirect.") +} + +func TestHeadersRoundTripperSameHostRedirect(t *testing.T) { + // Cookie must be forwarded on same-host redirects. + mux := http.NewServeMux() + cookieOnRedirect := "" + mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/end", http.StatusFound) + }) + mux.HandleFunc("/end", func(w http.ResponseWriter, r *http.Request) { + cookieOnRedirect = r.Header.Get("Cookie") + fmt.Fprint(w, "ok") + }) + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + cfg := HTTPClientConfig{ + FollowRedirects: true, + HTTPHeaders: &Headers{ + Headers: map[string]Header{ + "Cookie": {Values: []string{"session=abc"}}, + }, + }, + } + client, err := NewClientFromConfig(cfg, "test") + require.NoError(t, err) + + resp, err := client.Get(server.URL + "/start") + require.NoError(t, err) + defer resp.Body.Close() + _, err = io.ReadAll(resp.Body) + require.NoError(t, err) + + require.Equalf(t, "session=abc", cookieOnRedirect, "Cookie must be forwarded on a same-host redirect.") +} diff --git a/config/http_config.go b/config/http_config.go index 55cc5b07..c2259c73 100644 --- a/config/http_config.go +++ b/config/http_config.go @@ -721,6 +721,14 @@ func NewRoundTripperFromConfigWithContext(ctx context.Context, cfg HTTPClientCon } if cfg.HTTPHeaders != nil { + // Strip sensitive headers added by headersRoundTripper on cross-host + // redirects before they reach the transport. Only needed when + // redirects are actually followed; when FollowRedirects is false + // CheckRedirect returns ErrUseLastResponse immediately so there are + // no subsequent requests. + if cfg.FollowRedirects { + rt = &sensitiveHeadersStripRT{next: rt} + } rt = NewHeadersRoundTripper(cfg.HTTPHeaders, rt) } @@ -862,7 +870,7 @@ func NewAuthorizationCredentialsRoundTripper(authType string, authCredentials Se } func (rt *authorizationCredentialsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if len(req.Header.Get("Authorization")) != 0 { + if len(req.Header.Get("Authorization")) != 0 || isCrossHostRedirect(req) { return rt.rt.RoundTrip(req) } @@ -900,7 +908,7 @@ func NewBasicAuthRoundTripper(username, password SecretReader, rt http.RoundTrip } func (rt *basicAuthRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { - if len(req.Header.Get("Authorization")) != 0 { + if len(req.Header.Get("Authorization")) != 0 || isCrossHostRedirect(req) { return rt.rt.RoundTrip(req) } var username string @@ -1085,6 +1093,9 @@ func (rt *oauth2RoundTripper) RoundTrip(req *http.Request) (*http.Response, erro rt.mtx.RLock() currentRT := rt.lastRT rt.mtx.RUnlock() + if isCrossHostRedirect(req) { + return currentRT.Base.RoundTrip(req) + } return currentRT.RoundTrip(req) } @@ -1106,6 +1117,85 @@ func mapToValues(m map[string]string) url.Values { return v } +// isCrossHostRedirect reports whether req is a redirect to a different host +// than the original request. It detects this by walking the req.Response chain +// (which Go's HTTP client populates on every redirect hop) to find the original +// request's hostname, then comparing it to the current destination. +// This works regardless of whether the caller uses NewClientFromConfig or a +// custom http.Client built from NewRoundTripperFromConfigWithContext directly. +func isCrossHostRedirect(req *http.Request) bool { + if req.Response == nil { + return false + } + originalHost := strings.ToLower(originalRequestHost(req)) + return !isDomainOrSubdomain(strings.ToLower(req.URL.Hostname()), originalHost) +} + +func originalRequestHost(req *http.Request) string { + r := req + for r.Response != nil && r.Response.Request != nil { + r = r.Response.Request + } + return r.URL.Hostname() +} + +// sensitiveHeadersOnRedirect lists the headers that must not be forwarded when +// following a redirect to a different host, mirroring the list in +// makeHeadersCopier in net/http/client.go. +var sensitiveHeadersOnRedirect = map[string]struct{}{ + "Authorization": {}, + // "Www-Authenticate" is the canonical form produced by + // textproto.CanonicalMIMEHeaderKey; it is not a typo of "WWW-Authenticate". + "Www-Authenticate": {}, + "Cookie": {}, + "Cookie2": {}, + "Proxy-Authorization": {}, + "Proxy-Authenticate": {}, +} + +// sensitiveHeadersStripRT strips sensitive headers from requests marked as +// cross-host redirects before passing them to the underlying transport. +type sensitiveHeadersStripRT struct { + next http.RoundTripper +} + +func (rt *sensitiveHeadersStripRT) RoundTrip(req *http.Request) (*http.Response, error) { + if isCrossHostRedirect(req) { + req = cloneRequest(req) + for h := range sensitiveHeadersOnRedirect { + req.Header.Del(h) + } + } + return rt.next.RoundTrip(req) +} + +func (rt *sensitiveHeadersStripRT) CloseIdleConnections() { + if ci, ok := rt.next.(closeIdler); ok { + ci.CloseIdleConnections() + } +} + +// isDomainOrSubdomain reports whether sub is a subdomain (or exact match) of +// parent. It mirrors isDomainOrSubdomain from net/http/client.go. +func isDomainOrSubdomain(sub, parent string) bool { + if parent == "" { + return false + } + if sub == parent { + return true + } + // A colon means sub is an IPv6 address; a percent sign introduces an IPv6 + // zone ID. Neither can be a hostname, and both could otherwise pass the + // suffix check below (e.g. "::1%.www.example.com" ends with "example.com"). + if strings.ContainsAny(sub, ":%") { + return false + } + if !strings.HasSuffix(sub, parent) { + return false + } + return sub[len(sub)-len(parent)-1] == '.' +} + // cloneRequest returns a clone of the provided *http.Request. // The clone is a shallow copy of the struct and its Header map. func cloneRequest(r *http.Request) *http.Request { diff --git a/config/http_config_test.go b/config/http_config_test.go index 9968d37a..ad0a6055 100644 --- a/config/http_config_test.go +++ b/config/http_config_test.go @@ -1380,6 +1380,324 @@ func TestDefaultFollowRedirect(t *testing.T) { } } +func TestCrossHostRedirectDropsCredentials(t *testing.T) { + for _, tc := range []struct { + name string + config HTTPClientConfig + }{ + { + name: "bearer token", + config: HTTPClientConfig{ + FollowRedirects: true, + Authorization: &Authorization{ + Type: "Bearer", + Credentials: "secret-token", + }, + }, + }, + { + name: "basic auth", + config: HTTPClientConfig{ + FollowRedirects: true, + BasicAuth: &BasicAuth{ + Username: "user", + Password: "pass", + }, + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + // target listens on 127.0.0.1 but origin redirects using "localhost" + // as the hostname. "127.0.0.1" and "localhost" are different hostname + // strings, so Go's redirect rules strip credentials on the redirect. + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "" { + http.Error(w, "credentials leaked to cross-host redirect target", http.StatusForbidden) + return + } + fmt.Fprint(w, ExpectedMessage) + })) + t.Cleanup(target.Close) + + // Build a redirect URL that uses "localhost" instead of "127.0.0.1". + targetPort := target.Listener.Addr().(*net.TCPAddr).Port + targetLocalhostURL := fmt.Sprintf("http://localhost:%d", targetPort) + + origin := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, targetLocalhostURL+r.URL.Path, http.StatusFound) + })) + t.Cleanup(origin.Close) + + client, err := NewClientFromConfig(tc.config, "test") + require.NoError(t, err) + + resp, err := client.Get(origin.URL) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, ExpectedMessage, strings.TrimSpace(string(body))) + }) + } +} + +func TestIsDomainOrSubdomain(t *testing.T) { + for _, tc := range []struct { + sub, parent string + want bool + }{ + {"example.com", "example.com", true}, + {"sub.example.com", "example.com", true}, + {"deep.sub.example.com", "example.com", true}, + {"notexample.com", "example.com", false}, + {"example.com", "sub.example.com", false}, + {"bar.com", "foo.com", false}, + {"127.0.0.1", "127.0.0.1", true}, + {"localhost", "127.0.0.1", false}, + {"127.0.0.1", "localhost", false}, + {"::1", "::1", true}, + {"::2", "::1", false}, + {"::1", "example.com", false}, + // Zone ID containing a hostname must not match as a subdomain. + {"::1%.www.example.com", "example.com", false}, + {"fe80::1%eth0", "eth0", false}, + // Empty parent must never match. + {"example.com", "", false}, + {"", "", false}, + // Trailing-dot FQDN: "sub.example.com." vs "example.com" are not equal + // because isDomainOrSubdomain operates on raw strings; callers normalise + // via Hostname() which strips trailing dots in practice. + {"sub.example.com.", "example.com.", true}, + {"sub.example.com.", "example.com", false}, + // Case folding is the caller's responsibility (shouldSendCredentialsOnRedirect + // lowercases before calling); isDomainOrSubdomain itself is case-sensitive. + {"Sub.Example.Com", "example.com", false}, + } { + t.Run(tc.sub+"→"+tc.parent, func(t *testing.T) { + require.Equal(t, tc.want, isDomainOrSubdomain(tc.sub, tc.parent)) + }) + } +} + +func TestSameHostRedirectKeepsCredentials(t *testing.T) { + credsSeen := false + mux := http.NewServeMux() + mux.HandleFunc("/start", func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, "/end", http.StatusFound) + }) + mux.HandleFunc("/end", func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "" { + credsSeen = true + } + fmt.Fprint(w, ExpectedMessage) + }) + server := httptest.NewServer(mux) + t.Cleanup(server.Close) + + cfg := HTTPClientConfig{ + FollowRedirects: true, + Authorization: &Authorization{ + Type: "Bearer", + Credentials: "secret-token", + }, + } + client, err := NewClientFromConfig(cfg, "test") + require.NoError(t, err) + + resp, err := client.Get(server.URL + "/start") + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, ExpectedMessage, strings.TrimSpace(string(body))) + require.Truef(t, credsSeen, "credentials should be forwarded on same-host redirect") +} + +func TestRoundTripperCrossHostRedirectDropsCredentials(t *testing.T) { + // Verify that a custom http.Client built from NewRoundTripperFromConfig + // (not NewClientFromConfig) also strips credentials on cross-host redirects, + // because isCrossHostRedirect uses req.Response and requires no CheckRedirect hook. + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "" { + http.Error(w, "credentials leaked to cross-host redirect target", http.StatusForbidden) + return + } + fmt.Fprint(w, ExpectedMessage) + })) + t.Cleanup(target.Close) + + targetPort := target.Listener.Addr().(*net.TCPAddr).Port + targetLocalhostURL := fmt.Sprintf("http://localhost:%d", targetPort) + + origin := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, targetLocalhostURL+r.URL.Path, http.StatusFound) + })) + t.Cleanup(origin.Close) + + cfg := HTTPClientConfig{ + FollowRedirects: true, + Authorization: &Authorization{ + Type: "Bearer", + Credentials: "secret-token", + }, + } + rt, err := NewRoundTripperFromConfig(cfg, "test") + require.NoError(t, err) + + client := &http.Client{Transport: rt} + resp, err := client.Get(origin.URL) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, ExpectedMessage, strings.TrimSpace(string(body))) +} + +func TestOAuth2CrossHostRedirectDropsCredentials(t *testing.T) { + // target checks that no OAuth2 Bearer token arrives on a cross-host redirect. + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "" { + http.Error(w, "credentials leaked to cross-host redirect target", http.StatusForbidden) + return + } + fmt.Fprint(w, ExpectedMessage) + })) + t.Cleanup(target.Close) + + targetPort := target.Listener.Addr().(*net.TCPAddr).Port + targetLocalhostURL := fmt.Sprintf("http://localhost:%d", targetPort) + + originAuthSeen := "" + origin := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + originAuthSeen = r.Header.Get("Authorization") + http.Redirect(w, r, targetLocalhostURL+r.URL.Path, http.StatusFound) + })) + t.Cleanup(origin.Close) + + // tokenServer issues a static Bearer token used by the OAuth2 round-tripper. + tokenServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + res, _ := json.Marshal(oauth2TestServerResponse{ + AccessToken: "oauth2-secret-token", + TokenType: "Bearer", + }) + w.Header().Add("Content-Type", "application/json") + _, _ = w.Write(res) + })) + t.Cleanup(tokenServer.Close) + + cfg := HTTPClientConfig{ + FollowRedirects: true, + OAuth2: &OAuth2{ + ClientID: "testclient", + TokenURL: tokenServer.URL + "/token", + }, + } + client, err := NewClientFromConfig(cfg, "test") + require.NoError(t, err) + + resp, err := client.Get(origin.URL) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, ExpectedMessage, strings.TrimSpace(string(body))) + require.NotEmptyf(t, originAuthSeen, "OAuth2 Bearer token must be present on the initial request to origin") +} + +func TestMultiHopCrossHostRedirectDropsCredentials(t *testing.T) { + // Chain: origin (127.0.0.1) → hop (localhost) → final (localhost). + // Both hop and final differ from the original host (127.0.0.1 ≠ localhost), + // so isCrossHostRedirect must strip credentials from both hops. + // Note that hop→final is same-hostname (localhost→localhost), which confirms + // that the cross-host check compares against the original host, not the + // immediately preceding hop. + finalAuthSeen := "" + final := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + finalAuthSeen = r.Header.Get("Authorization") + fmt.Fprint(w, ExpectedMessage) + })) + t.Cleanup(final.Close) + + finalPort := final.Listener.Addr().(*net.TCPAddr).Port + finalURL := fmt.Sprintf("http://localhost:%d/", finalPort) + + hopAuthSeen := "" + hop := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hopAuthSeen = r.Header.Get("Authorization") + http.Redirect(w, r, finalURL, http.StatusFound) + })) + t.Cleanup(hop.Close) + + hopPort := hop.Listener.Addr().(*net.TCPAddr).Port + hopURL := fmt.Sprintf("http://localhost:%d/", hopPort) + + origin := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, hopURL, http.StatusFound) + })) + t.Cleanup(origin.Close) + + cfg := HTTPClientConfig{ + FollowRedirects: true, + Authorization: &Authorization{ + Type: "Bearer", + Credentials: "secret-token", + }, + } + client, err := NewClientFromConfig(cfg, "test") + require.NoError(t, err) + + resp, err := client.Get(origin.URL) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, ExpectedMessage, strings.TrimSpace(string(body))) + require.Emptyf(t, hopAuthSeen, "credentials must not reach hop: original host 127.0.0.1 differs from localhost") + require.Emptyf(t, finalAuthSeen, "credentials must not reach final: original host 127.0.0.1 still differs from localhost") +} + +func TestPortChangeRedirectKeepsCredentials(t *testing.T) { + // Both servers bind to 127.0.0.1; only the port differs. Credentials must + // be forwarded because the hostname is the same. + credsSeen := false + target := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Authorization") != "" { + credsSeen = true + } + fmt.Fprint(w, ExpectedMessage) + })) + t.Cleanup(target.Close) + + origin := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, target.URL, http.StatusFound) + })) + t.Cleanup(origin.Close) + + cfg := HTTPClientConfig{ + FollowRedirects: true, + Authorization: &Authorization{ + Type: "Bearer", + Credentials: "secret-token", + }, + } + client, err := NewClientFromConfig(cfg, "test") + require.NoError(t, err) + + resp, err := client.Get(origin.URL) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, ExpectedMessage, strings.TrimSpace(string(body))) + require.Truef(t, credsSeen, "credentials must be forwarded when only the port changes") +} + func TestValidateHTTPConfig(t *testing.T) { cfg, _, err := LoadHTTPConfigFile("testdata/http.conf.good.yml") if err != nil {