From 1a1fb13aa71153c6c3727bd6bd32fa4741d0c179 Mon Sep 17 00:00:00 2001 From: Alec Thomas Date: Tue, 17 Mar 2026 14:59:41 +1100 Subject: [PATCH] feat: forward request headers in handler and add configurable headers to host strategy The handler now automatically forwards safe request headers to upstream, stripping hop-by-hop headers per RFC 7230. This fixes issues where upstream servers (e.g. ghcr.io) require headers like Accept from the original request. The host strategy also supports configurable static headers, useful for injecting auth headers (e.g. Bearer token for ghcr.io/Homebrew bottles). Header filtering is unified in httputil.FilterHeaders, replacing the previous cache.FilterTransportHeaders. Co-Authored-By: Claude Opus 4.6 (1M context) Co-authored-by: Claude Code Ai-assisted: true --- internal/cache/api.go | 15 ----- internal/cache/remote.go | 6 +- internal/httputil/headers.go | 42 +++++++++++++ internal/httputil/headers_test.go | 74 +++++++++++++++++++++++ internal/strategy/apiv1.go | 3 +- internal/strategy/handler/handler.go | 8 +++ internal/strategy/handler/handler_test.go | 73 ++++++++++++++++++++++ internal/strategy/host.go | 34 +++++++---- internal/strategy/host_test.go | 33 ++++++++++ 9 files changed, 258 insertions(+), 30 deletions(-) create mode 100644 internal/httputil/headers.go create mode 100644 internal/httputil/headers_test.go diff --git a/internal/cache/api.go b/internal/cache/api.go index c5c2d72..a4c98ba 100644 --- a/internal/cache/api.go +++ b/internal/cache/api.go @@ -117,21 +117,6 @@ func (k *Key) MarshalText() ([]byte, error) { return []byte(k.String()), nil } -// FilterTransportHeaders returns a copy of the given headers with standard HTTP transport headers removed. -// These headers are typically added by HTTP clients/servers and should not be cached. -func FilterTransportHeaders(headers http.Header) http.Header { - filtered := make(http.Header) - for key, values := range headers { - // Skip standard HTTP headers added by transport layer or that shouldn't be cached - if key == "Content-Length" || key == "Date" || key == "Accept-Encoding" || - key == "User-Agent" || key == "Transfer-Encoding" || key == "Time-To-Live" { - continue - } - filtered[key] = values - } - return filtered -} - // Stats contains health and usage statistics for a cache. type Stats struct { // Objects is the number of objects currently in the cache. diff --git a/internal/cache/remote.go b/internal/cache/remote.go index bcd5650..f3507fe 100644 --- a/internal/cache/remote.go +++ b/internal/cache/remote.go @@ -11,6 +11,8 @@ import ( "time" "github.com/alecthomas/errors" + + "github.com/block/cachew/internal/httputil" ) const defaultNamespace = "-" @@ -66,7 +68,7 @@ func (c *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header, } // Filter out HTTP transport headers - headers := FilterTransportHeaders(resp.Header) + headers := httputil.FilterHeaders(resp.Header, httputil.TransportHeaders...) return resp.Body, headers, nil } @@ -98,7 +100,7 @@ func (c *Remote) Stat(ctx context.Context, key Key) (http.Header, error) { } // Filter out HTTP transport headers - headers := FilterTransportHeaders(resp.Header) + headers := httputil.FilterHeaders(resp.Header, httputil.TransportHeaders...) return headers, nil } diff --git a/internal/httputil/headers.go b/internal/httputil/headers.go new file mode 100644 index 0000000..cae1603 --- /dev/null +++ b/internal/httputil/headers.go @@ -0,0 +1,42 @@ +package httputil + +import "net/http" + +// TransportHeaders are headers added by the HTTP transport layer that should not be cached. +var TransportHeaders = []string{ //nolint:gochecknoglobals + "Content-Length", + "Date", + "Accept-Encoding", + "User-Agent", + "Transfer-Encoding", + "Time-To-Live", +} + +// HopByHopHeaders are hop-by-hop headers that should not be forwarded by proxies (RFC 7230). +var HopByHopHeaders = []string{ //nolint:gochecknoglobals + "Connection", + "Keep-Alive", + "Proxy-Authenticate", + "Proxy-Authorization", + "Te", + "Trailers", + "Transfer-Encoding", + "Upgrade", + "Host", +} + +// FilterHeaders returns a copy of headers with the specified header keys removed. +func FilterHeaders(headers http.Header, skip ...string) http.Header { + skipSet := make(map[string]bool, len(skip)) + for _, s := range skip { + skipSet[http.CanonicalHeaderKey(s)] = true + } + filtered := make(http.Header, len(headers)) + for key, values := range headers { + if skipSet[http.CanonicalHeaderKey(key)] { + continue + } + filtered[key] = values + } + return filtered +} diff --git a/internal/httputil/headers_test.go b/internal/httputil/headers_test.go new file mode 100644 index 0000000..23bd794 --- /dev/null +++ b/internal/httputil/headers_test.go @@ -0,0 +1,74 @@ +package httputil_test + +import ( + "net/http" + "testing" + + "github.com/alecthomas/assert/v2" + + "github.com/block/cachew/internal/httputil" +) + +func TestFilterHeaders(t *testing.T) { + tests := []struct { + name string + headers http.Header + skip []string + expected http.Header + }{ + { + name: "Empty", + headers: http.Header{}, + skip: httputil.TransportHeaders, + expected: http.Header{}, + }, + { + name: "TransportHeaders", + headers: http.Header{ + "Content-Type": {"application/json"}, + "Content-Length": {"42"}, + "Date": {"Mon, 01 Jan 2024 00:00:00 GMT"}, + "Transfer-Encoding": {"chunked"}, + "X-Custom": {"value"}, + }, + skip: httputil.TransportHeaders, + expected: http.Header{ + "Content-Type": {"application/json"}, + "X-Custom": {"value"}, + }, + }, + { + name: "HopByHopHeaders", + headers: http.Header{ + "Accept": {"text/html"}, + "Authorization": {"Bearer token"}, + "Connection": {"keep-alive"}, + "Keep-Alive": {"timeout=5"}, + "Host": {"example.com"}, + "Upgrade": {"websocket"}, + }, + skip: httputil.HopByHopHeaders, + expected: http.Header{ + "Accept": {"text/html"}, + "Authorization": {"Bearer token"}, + }, + }, + { + name: "CaseInsensitive", + headers: http.Header{ + "content-length": {"42"}, + "X-Custom": {"value"}, + }, + skip: []string{"Content-Length"}, + expected: http.Header{ + "X-Custom": {"value"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := httputil.FilterHeaders(tt.headers, tt.skip...) + assert.Equal(t, tt.expected, result) + }) + } +} diff --git a/internal/strategy/apiv1.go b/internal/strategy/apiv1.go index d22701c..c1e075e 100644 --- a/internal/strategy/apiv1.go +++ b/internal/strategy/apiv1.go @@ -13,6 +13,7 @@ import ( "github.com/alecthomas/errors" "github.com/block/cachew/internal/cache" + "github.com/block/cachew/internal/httputil" "github.com/block/cachew/internal/logging" ) @@ -116,7 +117,7 @@ func (d *APIV1) putObject(w http.ResponseWriter, r *http.Request) { } // Extract and filter headers from request - headers := cache.FilterTransportHeaders(r.Header) + headers := httputil.FilterHeaders(r.Header, httputil.TransportHeaders...) namespacedCache := d.cache.Namespace(namespace) cw, err := namespacedCache.Create(r.Context(), key, headers, ttl) diff --git a/internal/strategy/handler/handler.go b/internal/strategy/handler/handler.go index 432c4cc..365c037 100644 --- a/internal/strategy/handler/handler.go +++ b/internal/strategy/handler/handler.go @@ -142,6 +142,14 @@ func (h *Handler) fetchAndCache(w http.ResponseWriter, r *http.Request, key cach return } + // Forward safe headers from the original request, without overwriting headers set by transform. + forwardable := httputil.FilterHeaders(r.Header, httputil.HopByHopHeaders...) + for key, values := range forwardable { + if upstreamReq.Header.Get(key) == "" { + upstreamReq.Header[key] = values + } + } + resp, err := h.client.Do(upstreamReq) if err != nil { h.errorHandler(httputil.Errorf(http.StatusBadGateway, "failed to fetch: %w", err), w, r) diff --git a/internal/strategy/handler/handler_test.go b/internal/strategy/handler/handler_test.go index 52d25f7..cc8f462 100644 --- a/internal/strategy/handler/handler_test.go +++ b/internal/strategy/handler/handler_test.go @@ -281,6 +281,79 @@ func TestBuilder(t *testing.T) { } } +func TestHeaderForwarding(t *testing.T) { + var receivedHeaders http.Header + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + _, _ = fmt.Fprint(w, "ok") + })) + defer upstream.Close() + + c := mustNewMemoryCache() + ctx := logging.ContextWithLogger(context.Background(), slog.Default()) + + t.Run("ForwardsOriginalHeaders", func(t *testing.T) { + h := handler.New(http.DefaultClient, c). + CacheKey(func(_ *http.Request) string { return "fwd-test-1" }). + Transform(func(r *http.Request) (*http.Request, error) { + return http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/test", nil) + }) + r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + r = r.WithContext(ctx) + r.Header.Set("Accept", "application/json") + r.Header.Set("X-Custom", "forwarded") + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "application/json", receivedHeaders.Get("Accept")) + assert.Equal(t, "forwarded", receivedHeaders.Get("X-Custom")) + }) + + t.Run("StripsHopByHopHeaders", func(t *testing.T) { + h := handler.New(http.DefaultClient, c). + CacheKey(func(_ *http.Request) string { return "fwd-test-2" }). + Transform(func(r *http.Request) (*http.Request, error) { + return http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/test", nil) + }) + r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + r = r.WithContext(ctx) + r.Header.Set("Connection", "keep-alive") + r.Header.Set("Keep-Alive", "timeout=5") + r.Header.Set("Accept", "text/html") + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "text/html", receivedHeaders.Get("Accept")) + assert.Equal(t, "", receivedHeaders.Get("Connection")) + assert.Equal(t, "", receivedHeaders.Get("Keep-Alive")) + }) + + t.Run("TransformHeadersTakePrecedence", func(t *testing.T) { + h := handler.New(http.DefaultClient, c). + CacheKey(func(_ *http.Request) string { return "fwd-test-3" }). + Transform(func(r *http.Request) (*http.Request, error) { + req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/test", nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer override") + return req, nil + }) + r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil) + r = r.WithContext(ctx) + r.Header.Set("Authorization", "Bearer original") + r.Header.Set("X-Custom", "forwarded") + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "Bearer override", receivedHeaders.Get("Authorization")) + assert.Equal(t, "forwarded", receivedHeaders.Get("X-Custom")) + }) +} + func TestHandlerMethodChaining(t *testing.T) { c := mustNewMemoryCache() client := &http.Client{} diff --git a/internal/strategy/host.go b/internal/strategy/host.go index 1c2c284..1887cf5 100644 --- a/internal/strategy/host.go +++ b/internal/strategy/host.go @@ -27,16 +27,18 @@ func RegisterHost(r *Registry) { // // In this example, the strategy will be mounted under "/github.com". type HostConfig struct { - Target string `hcl:"target,label" help:"The target URL to proxy requests to."` + Target string `hcl:"target,label" help:"The target URL to proxy requests to."` + Headers map[string]string `hcl:"headers,optional" help:"Headers to add to upstream requests."` } // The Host [Strategy] forwards all GET requests to the specified host, caching the response payloads. type Host struct { - target *url.URL - cache cache.Cache - client *http.Client - logger *slog.Logger - prefix string + target *url.URL + cache cache.Cache + client *http.Client + logger *slog.Logger + prefix string + headers map[string]string } var _ Strategy = (*Host)(nil) @@ -48,11 +50,12 @@ func NewHost(ctx context.Context, config HostConfig, cache cache.Cache, mux Mux) } prefix := "/" + u.Host + u.EscapedPath() h := &Host{ - target: u, - cache: cache, - client: &http.Client{}, - logger: logging.FromContext(ctx), - prefix: prefix, + target: u, + cache: cache, + client: &http.Client{}, + logger: logging.FromContext(ctx), + prefix: prefix, + headers: config.Headers, } hdlr := handler.New(h.client, cache). @@ -61,7 +64,14 @@ func NewHost(ctx context.Context, config HostConfig, cache cache.Cache, mux Mux) }). Transform(func(r *http.Request) (*http.Request, error) { targetURL := h.buildTargetURL(r) - return http.NewRequestWithContext(r.Context(), http.MethodGet, targetURL.String(), nil) + req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, targetURL.String(), nil) + if err != nil { + return nil, errors.Wrap(err, "creating upstream request") + } + for k, v := range h.headers { + req.Header.Set(k, v) + } + return req, nil }) mux.Handle("GET "+prefix+"/", hdlr) diff --git a/internal/strategy/host_test.go b/internal/strategy/host_test.go index 1f8442b..a3bb912 100644 --- a/internal/strategy/host_test.go +++ b/internal/strategy/host_test.go @@ -98,6 +98,39 @@ func TestHostInvalidTargetURL(t *testing.T) { assert.Error(t, err) } +func TestHostHeaders(t *testing.T) { + var receivedHeaders http.Header + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + })) + defer backend.Close() + + _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) + memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour}) + assert.NoError(t, err) + defer memCache.Close() + + mux := http.NewServeMux() + _, err = strategy.NewHost(ctx, strategy.HostConfig{ + Target: backend.URL, + Headers: map[string]string{"Authorization": "Bearer QQ==", "X-Custom": "value"}, + }, memCache, mux) + assert.NoError(t, err) + + u, _ := url.Parse(backend.URL) + reqPath := "/" + u.Host + "/test" + + req := httptest.NewRequestWithContext(ctx, http.MethodGet, reqPath, nil) + w := httptest.NewRecorder() + mux.ServeHTTP(w, req) + + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "Bearer QQ==", receivedHeaders.Get("Authorization")) + assert.Equal(t, "value", receivedHeaders.Get("X-Custom")) +} + func TestHostString(t *testing.T) { _, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError}) memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour})