diff --git a/internal/cache/api.go b/internal/cache/api.go
index c5c2d72..a4c98ba 100644
--- a/internal/cache/api.go
+++ b/internal/cache/api.go
@@ -117,21 +117,6 @@ func (k *Key) MarshalText() ([]byte, error) {
 	return []byte(k.String()), nil
 }
 
-// FilterTransportHeaders returns a copy of the given headers with standard HTTP transport headers removed.
-// These headers are typically added by HTTP clients/servers and should not be cached.
-func FilterTransportHeaders(headers http.Header) http.Header {
-	filtered := make(http.Header)
-	for key, values := range headers {
-		// Skip standard HTTP headers added by transport layer or that shouldn't be cached
-		if key == "Content-Length" || key == "Date" || key == "Accept-Encoding" ||
-			key == "User-Agent" || key == "Transfer-Encoding" || key == "Time-To-Live" {
-			continue
-		}
-		filtered[key] = values
-	}
-	return filtered
-}
-
 // Stats contains health and usage statistics for a cache.
 type Stats struct {
 	// Objects is the number of objects currently in the cache.
diff --git a/internal/cache/remote.go b/internal/cache/remote.go
index bcd5650..f3507fe 100644
--- a/internal/cache/remote.go
+++ b/internal/cache/remote.go
@@ -11,6 +11,8 @@ import (
 	"time"
 
 	"github.com/alecthomas/errors"
+
+	"github.com/block/cachew/internal/httputil"
 )
 
 const defaultNamespace = "-"
@@ -66,7 +68,7 @@ func (c *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header,
 	}
 
 	// Filter out HTTP transport headers
-	headers := FilterTransportHeaders(resp.Header)
+	headers := httputil.FilterHeaders(resp.Header, httputil.TransportHeaders...)
 	return resp.Body, headers, nil
 }
 
@@ -98,7 +100,7 @@ func (c *Remote) Stat(ctx context.Context, key Key) (http.Header, error) {
 	}
 
 	// Filter out HTTP transport headers
-	headers := FilterTransportHeaders(resp.Header)
+	headers := httputil.FilterHeaders(resp.Header, httputil.TransportHeaders...)
 	return headers, nil
 }
 
diff --git a/internal/httputil/headers.go b/internal/httputil/headers.go
new file mode 100644
index 0000000..cae1603
--- /dev/null
+++ b/internal/httputil/headers.go
@@ -0,0 +1,70 @@
+package httputil
+
+import (
+	"net/http"
+	"slices"
+	"strings"
+)
+
+// TransportHeaders are headers added by the HTTP transport layer that should not be cached.
+var TransportHeaders = []string{ //nolint:gochecknoglobals
+	"Content-Length",
+	"Date",
+	"Accept-Encoding",
+	"User-Agent",
+	"Transfer-Encoding",
+	"Time-To-Live",
+}
+
+// HopByHopHeaders are hop-by-hop headers that should not be forwarded by proxies (RFC 7230).
+var HopByHopHeaders = []string{ //nolint:gochecknoglobals
+	"Connection",
+	"Keep-Alive",
+	"Proxy-Authenticate",
+	"Proxy-Authorization",
+	"Te",
+	// The header field is named "Trailer", not "Trailers" (RFC 7230 section 4.4).
+	"Trailer",
+	"Transfer-Encoding",
+	"Upgrade",
+	"Host",
+}
+
+// FilterHeaders returns a copy of headers with the specified header keys removed.
+//
+// Keys in skip are matched case-insensitively. Value slices are cloned, so the result shares no
+// memory with headers and either side can be modified without affecting the other.
+func FilterHeaders(headers http.Header, skip ...string) http.Header {
+	skipSet := make(map[string]bool, len(skip))
+	for _, s := range skip {
+		skipSet[http.CanonicalHeaderKey(s)] = true
+	}
+	filtered := make(http.Header, len(headers))
+	for key, values := range headers {
+		if skipSet[http.CanonicalHeaderKey(key)] {
+			continue
+		}
+		filtered[key] = slices.Clone(values)
+	}
+	return filtered
+}
+
+// FilterHopByHopHeaders returns a copy of headers without [HopByHopHeaders] and without any
+// header named as a connection option in a Connection header, which RFC 7230 section 6.1 also
+// requires intermediaries to remove before forwarding a message.
+func FilterHopByHopHeaders(headers http.Header) http.Header {
+	skip := slices.Clone(HopByHopHeaders)
+	for key, values := range headers {
+		if http.CanonicalHeaderKey(key) != "Connection" {
+			continue
+		}
+		for _, value := range values {
+			for _, option := range strings.Split(value, ",") {
+				if option = strings.TrimSpace(option); option != "" {
+					skip = append(skip, option)
+				}
+			}
+		}
+	}
+	return FilterHeaders(headers, skip...)
+}
diff --git a/internal/httputil/headers_test.go b/internal/httputil/headers_test.go
new file mode 100644
index 0000000..23bd794
--- /dev/null
+++ b/internal/httputil/headers_test.go
@@ -0,0 +1,93 @@
+package httputil_test
+
+import (
+	"net/http"
+	"testing"
+
+	"github.com/alecthomas/assert/v2"
+
+	"github.com/block/cachew/internal/httputil"
+)
+
+func TestFilterHeaders(t *testing.T) {
+	tests := []struct {
+		name     string
+		headers  http.Header
+		skip     []string
+		expected http.Header
+	}{
+		{
+			name:     "Empty",
+			headers:  http.Header{},
+			skip:     httputil.TransportHeaders,
+			expected: http.Header{},
+		},
+		{
+			name: "TransportHeaders",
+			headers: http.Header{
+				"Content-Type":      {"application/json"},
+				"Content-Length":    {"42"},
+				"Date":              {"Mon, 01 Jan 2024 00:00:00 GMT"},
+				"Transfer-Encoding": {"chunked"},
+				"X-Custom":          {"value"},
+			},
+			skip: httputil.TransportHeaders,
+			expected: http.Header{
+				"Content-Type": {"application/json"},
+				"X-Custom":     {"value"},
+			},
+		},
+		{
+			name: "HopByHopHeaders",
+			headers: http.Header{
+				"Accept":        {"text/html"},
+				"Authorization": {"Bearer token"},
+				"Connection":    {"keep-alive"},
+				"Keep-Alive":    {"timeout=5"},
+				"Host":          {"example.com"},
+				"Trailer":       {"Expires"},
+				"Upgrade":       {"websocket"},
+			},
+			skip: httputil.HopByHopHeaders,
+			expected: http.Header{
+				"Accept":        {"text/html"},
+				"Authorization": {"Bearer token"},
+			},
+		},
+		{
+			name: "CaseInsensitive",
+			headers: http.Header{
+				"content-length": {"42"},
+				"X-Custom":       {"value"},
+			},
+			skip: []string{"Content-Length"},
+			expected: http.Header{
+				"X-Custom": {"value"},
+			},
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := httputil.FilterHeaders(tt.headers, tt.skip...)
+			assert.Equal(t, tt.expected, result)
+		})
+	}
+}
+
+func TestFilterHeadersCopiesValues(t *testing.T) {
+	headers := http.Header{"X-Custom": {"value"}}
+	result := httputil.FilterHeaders(headers)
+	result["X-Custom"][0] = "changed"
+	assert.Equal(t, "value", headers.Get("X-Custom"))
+}
+
+func TestFilterHopByHopHeaders(t *testing.T) {
+	headers := http.Header{
+		"Accept":     {"text/html"},
+		"Connection": {"keep-alive, X-Internal"},
+		"Keep-Alive": {"timeout=5"},
+		"X-Internal": {"secret"},
+	}
+	result := httputil.FilterHopByHopHeaders(headers)
+	assert.Equal(t, http.Header{"Accept": {"text/html"}}, result)
+}
diff --git a/internal/strategy/apiv1.go b/internal/strategy/apiv1.go
index d22701c..c1e075e 100644
--- a/internal/strategy/apiv1.go
+++ b/internal/strategy/apiv1.go
@@ -13,6 +13,7 @@ import (
 	"github.com/alecthomas/errors"
 
 	"github.com/block/cachew/internal/cache"
+	"github.com/block/cachew/internal/httputil"
 	"github.com/block/cachew/internal/logging"
 )
 
@@ -116,7 +117,7 @@ func (d *APIV1) putObject(w http.ResponseWriter, r *http.Request) {
 	}
 
 	// Extract and filter headers from request
-	headers := cache.FilterTransportHeaders(r.Header)
+	headers := httputil.FilterHeaders(r.Header, httputil.TransportHeaders...)
 
 	namespacedCache := d.cache.Namespace(namespace)
 	cw, err := namespacedCache.Create(r.Context(), key, headers, ttl)
diff --git a/internal/strategy/handler/handler.go b/internal/strategy/handler/handler.go
index 432c4cc..365c037 100644
--- a/internal/strategy/handler/handler.go
+++ b/internal/strategy/handler/handler.go
@@ -142,6 +142,15 @@ func (h *Handler) fetchAndCache(w http.ResponseWriter, r *http.Request, key cach
 		return
 	}
 
+	// Forward end-to-end headers from the original request, without overwriting headers set by transform.
+	// NOTE(review): this forwards Authorization, Cookie, Range and conditional headers as well; confirm
+	// the cache key covers whatever can vary the response before it is stored for other clients.
+	for name, values := range httputil.FilterHopByHopHeaders(r.Header) {
+		if _, set := upstreamReq.Header[http.CanonicalHeaderKey(name)]; !set {
+			upstreamReq.Header[name] = values
+		}
+	}
+
 	resp, err := h.client.Do(upstreamReq)
 	if err != nil {
 		h.errorHandler(httputil.Errorf(http.StatusBadGateway, "failed to fetch: %w", err), w, r)
diff --git a/internal/strategy/handler/handler_test.go b/internal/strategy/handler/handler_test.go
index 52d25f7..cc8f462 100644
--- a/internal/strategy/handler/handler_test.go
+++ b/internal/strategy/handler/handler_test.go
@@ -281,6 +281,79 @@ func TestBuilder(t *testing.T) {
 	}
 }
 
+func TestHeaderForwarding(t *testing.T) {
+	var receivedHeaders http.Header
+	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		receivedHeaders = r.Header.Clone()
+		_, _ = fmt.Fprint(w, "ok")
+	}))
+	defer upstream.Close()
+
+	c := mustNewMemoryCache()
+	ctx := logging.ContextWithLogger(context.Background(), slog.Default())
+
+	t.Run("ForwardsOriginalHeaders", func(t *testing.T) {
+		h := handler.New(http.DefaultClient, c).
+			CacheKey(func(_ *http.Request) string { return "fwd-test-1" }).
+			Transform(func(r *http.Request) (*http.Request, error) {
+				return http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/test", nil)
+			})
+		r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil)
+		r = r.WithContext(ctx)
+		r.Header.Set("Accept", "application/json")
+		r.Header.Set("X-Custom", "forwarded")
+		w := httptest.NewRecorder()
+		h.ServeHTTP(w, r)
+
+		assert.Equal(t, http.StatusOK, w.Code)
+		assert.Equal(t, "application/json", receivedHeaders.Get("Accept"))
+		assert.Equal(t, "forwarded", receivedHeaders.Get("X-Custom"))
+	})
+
+	t.Run("StripsHopByHopHeaders", func(t *testing.T) {
+		h := handler.New(http.DefaultClient, c).
+			CacheKey(func(_ *http.Request) string { return "fwd-test-2" }).
+			Transform(func(r *http.Request) (*http.Request, error) {
+				return http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/test", nil)
+			})
+		r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil)
+		r = r.WithContext(ctx)
+		r.Header.Set("Connection", "keep-alive")
+		r.Header.Set("Keep-Alive", "timeout=5")
+		r.Header.Set("Accept", "text/html")
+		w := httptest.NewRecorder()
+		h.ServeHTTP(w, r)
+
+		assert.Equal(t, http.StatusOK, w.Code)
+		assert.Equal(t, "text/html", receivedHeaders.Get("Accept"))
+		assert.Equal(t, "", receivedHeaders.Get("Connection"))
+		assert.Equal(t, "", receivedHeaders.Get("Keep-Alive"))
+	})
+
+	t.Run("TransformHeadersTakePrecedence", func(t *testing.T) {
+		h := handler.New(http.DefaultClient, c).
+			CacheKey(func(_ *http.Request) string { return "fwd-test-3" }).
+			Transform(func(r *http.Request) (*http.Request, error) {
+				req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/test", nil)
+				if err != nil {
+					return nil, err
+				}
+				req.Header.Set("Authorization", "Bearer override")
+				return req, nil
+			})
+		r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil)
+		r = r.WithContext(ctx)
+		r.Header.Set("Authorization", "Bearer original")
+		r.Header.Set("X-Custom", "forwarded")
+		w := httptest.NewRecorder()
+		h.ServeHTTP(w, r)
+
+		assert.Equal(t, http.StatusOK, w.Code)
+		assert.Equal(t, "Bearer override", receivedHeaders.Get("Authorization"))
+		assert.Equal(t, "forwarded", receivedHeaders.Get("X-Custom"))
+	})
+}
+
 func TestHandlerMethodChaining(t *testing.T) {
 	c := mustNewMemoryCache()
 	client := &http.Client{}
diff --git a/internal/strategy/host.go b/internal/strategy/host.go
index 1c2c284..1887cf5 100644
--- a/internal/strategy/host.go
+++ b/internal/strategy/host.go
@@ -27,16 +27,18 @@ func RegisterHost(r *Registry) {
 //
 // In this example, the strategy will be mounted under "/github.com".
 type HostConfig struct {
-	Target string `hcl:"target,label" help:"The target URL to proxy requests to."`
+	Target  string            `hcl:"target,label" help:"The target URL to proxy requests to."`
+	Headers map[string]string `hcl:"headers,optional" help:"Headers to add to upstream requests."`
 }
 
 // The Host [Strategy] forwards all GET requests to the specified host, caching the response payloads.
 type Host struct {
-	target *url.URL
-	cache  cache.Cache
-	client *http.Client
-	logger *slog.Logger
-	prefix string
+	target  *url.URL
+	cache   cache.Cache
+	client  *http.Client
+	logger  *slog.Logger
+	prefix  string
+	headers map[string]string
 }
 
 var _ Strategy = (*Host)(nil)
@@ -48,11 +50,12 @@ func NewHost(ctx context.Context, config HostConfig, cache cache.Cache, mux Mux)
 	}
 	prefix := "/" + u.Host + u.EscapedPath()
 	h := &Host{
-		target: u,
-		cache:  cache,
-		client: &http.Client{},
-		logger: logging.FromContext(ctx),
-		prefix: prefix,
+		target:  u,
+		cache:   cache,
+		client:  &http.Client{},
+		logger:  logging.FromContext(ctx),
+		prefix:  prefix,
+		headers: config.Headers,
 	}
 
 	hdlr := handler.New(h.client, cache).
@@ -61,7 +64,14 @@ func NewHost(ctx context.Context, config HostConfig, cache cache.Cache, mux Mux)
 		}).
 		Transform(func(r *http.Request) (*http.Request, error) {
 			targetURL := h.buildTargetURL(r)
-			return http.NewRequestWithContext(r.Context(), http.MethodGet, targetURL.String(), nil)
+			req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, targetURL.String(), nil)
+			if err != nil {
+				return nil, errors.Wrap(err, "creating upstream request")
+			}
+			for k, v := range h.headers {
+				req.Header.Set(k, v)
+			}
+			return req, nil
 		})
 
 	mux.Handle("GET "+prefix+"/", hdlr)
diff --git a/internal/strategy/host_test.go b/internal/strategy/host_test.go
index 1f8442b..a3bb912 100644
--- a/internal/strategy/host_test.go
+++ b/internal/strategy/host_test.go
@@ -98,6 +98,40 @@ func TestHostInvalidTargetURL(t *testing.T) {
 	assert.Error(t, err)
 }
 
+func TestHostHeaders(t *testing.T) {
+	var receivedHeaders http.Header
+	backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+		receivedHeaders = r.Header.Clone()
+		w.WriteHeader(http.StatusOK)
+		_, _ = w.Write([]byte("ok"))
+	}))
+	defer backend.Close()
+
+	_, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError})
+	memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour})
+	assert.NoError(t, err)
+	defer memCache.Close()
+
+	mux := http.NewServeMux()
+	_, err = strategy.NewHost(ctx, strategy.HostConfig{
+		Target:  backend.URL,
+		Headers: map[string]string{"Authorization": "Bearer QQ==", "X-Custom": "value"},
+	}, memCache, mux)
+	assert.NoError(t, err)
+
+	u, err := url.Parse(backend.URL)
+	assert.NoError(t, err)
+	reqPath := "/" + u.Host + "/test"
+
+	req := httptest.NewRequestWithContext(ctx, http.MethodGet, reqPath, nil)
+	w := httptest.NewRecorder()
+	mux.ServeHTTP(w, req)
+
+	assert.Equal(t, http.StatusOK, w.Code)
+	assert.Equal(t, "Bearer QQ==", receivedHeaders.Get("Authorization"))
+	assert.Equal(t, "value", receivedHeaders.Get("X-Custom"))
+}
+
 func TestHostString(t *testing.T) {
 	_, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError})
 	memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour})