Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 0 additions & 15 deletions internal/cache/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,21 +117,6 @@ func (k *Key) MarshalText() ([]byte, error) {
return []byte(k.String()), nil
}

// FilterTransportHeaders returns a copy of the given headers with standard HTTP transport headers removed.
// These headers are typically added by HTTP clients/servers and should not be cached.
func FilterTransportHeaders(headers http.Header) http.Header {
filtered := make(http.Header)
for key, values := range headers {
// Skip standard HTTP headers added by transport layer or that shouldn't be cached
if key == "Content-Length" || key == "Date" || key == "Accept-Encoding" ||
key == "User-Agent" || key == "Transfer-Encoding" || key == "Time-To-Live" {
continue
}
filtered[key] = values
}
return filtered
}

// Stats contains health and usage statistics for a cache.
type Stats struct {
// Objects is the number of objects currently in the cache.
Expand Down
6 changes: 4 additions & 2 deletions internal/cache/remote.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"time"

"github.com/alecthomas/errors"

"github.com/block/cachew/internal/httputil"
)

const defaultNamespace = "-"
Expand Down Expand Up @@ -66,7 +68,7 @@ func (c *Remote) Open(ctx context.Context, key Key) (io.ReadCloser, http.Header,
}

// Filter out HTTP transport headers
headers := FilterTransportHeaders(resp.Header)
headers := httputil.FilterHeaders(resp.Header, httputil.TransportHeaders...)

return resp.Body, headers, nil
}
Expand Down Expand Up @@ -98,7 +100,7 @@ func (c *Remote) Stat(ctx context.Context, key Key) (http.Header, error) {
}

// Filter out HTTP transport headers
headers := FilterTransportHeaders(resp.Header)
headers := httputil.FilterHeaders(resp.Header, httputil.TransportHeaders...)

return headers, nil
}
Expand Down
42 changes: 42 additions & 0 deletions internal/httputil/headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package httputil

import "net/http"

// TransportHeaders are headers added by the HTTP transport layer that should not be cached.
var TransportHeaders = []string{ //nolint:gochecknoglobals
"Content-Length",
"Date",
"Accept-Encoding",
"User-Agent",
"Transfer-Encoding",
"Time-To-Live",
}

// HopByHopHeaders are hop-by-hop headers that should not be forwarded by proxies (RFC 7230).
var HopByHopHeaders = []string{ //nolint:gochecknoglobals
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"Te",
"Trailers",
"Transfer-Encoding",
"Upgrade",
"Host",
}

// FilterHeaders returns a copy of headers with the specified header keys removed.
func FilterHeaders(headers http.Header, skip ...string) http.Header {
skipSet := make(map[string]bool, len(skip))
for _, s := range skip {
skipSet[http.CanonicalHeaderKey(s)] = true
}
filtered := make(http.Header, len(headers))
for key, values := range headers {
if skipSet[http.CanonicalHeaderKey(key)] {
continue
}
filtered[key] = values
}
return filtered
}
74 changes: 74 additions & 0 deletions internal/httputil/headers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package httputil_test

import (
"net/http"
"testing"

"github.com/alecthomas/assert/v2"

"github.com/block/cachew/internal/httputil"
)

func TestFilterHeaders(t *testing.T) {
tests := []struct {
name string
headers http.Header
skip []string
expected http.Header
}{
{
name: "Empty",
headers: http.Header{},
skip: httputil.TransportHeaders,
expected: http.Header{},
},
{
name: "TransportHeaders",
headers: http.Header{
"Content-Type": {"application/json"},
"Content-Length": {"42"},
"Date": {"Mon, 01 Jan 2024 00:00:00 GMT"},
"Transfer-Encoding": {"chunked"},
"X-Custom": {"value"},
},
skip: httputil.TransportHeaders,
expected: http.Header{
"Content-Type": {"application/json"},
"X-Custom": {"value"},
},
},
{
name: "HopByHopHeaders",
headers: http.Header{
"Accept": {"text/html"},
"Authorization": {"Bearer token"},
"Connection": {"keep-alive"},
"Keep-Alive": {"timeout=5"},
"Host": {"example.com"},
"Upgrade": {"websocket"},
},
skip: httputil.HopByHopHeaders,
expected: http.Header{
"Accept": {"text/html"},
"Authorization": {"Bearer token"},
},
},
{
name: "CaseInsensitive",
headers: http.Header{
"content-length": {"42"},
"X-Custom": {"value"},
},
skip: []string{"Content-Length"},
expected: http.Header{
"X-Custom": {"value"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := httputil.FilterHeaders(tt.headers, tt.skip...)
assert.Equal(t, tt.expected, result)
})
}
}
3 changes: 2 additions & 1 deletion internal/strategy/apiv1.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"github.com/alecthomas/errors"

"github.com/block/cachew/internal/cache"
"github.com/block/cachew/internal/httputil"
"github.com/block/cachew/internal/logging"
)

Expand Down Expand Up @@ -116,7 +117,7 @@ func (d *APIV1) putObject(w http.ResponseWriter, r *http.Request) {
}

// Extract and filter headers from request
headers := cache.FilterTransportHeaders(r.Header)
headers := httputil.FilterHeaders(r.Header, httputil.TransportHeaders...)

namespacedCache := d.cache.Namespace(namespace)
cw, err := namespacedCache.Create(r.Context(), key, headers, ttl)
Expand Down
8 changes: 8 additions & 0 deletions internal/strategy/handler/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ func (h *Handler) fetchAndCache(w http.ResponseWriter, r *http.Request, key cach
return
}

// Forward safe headers from the original request, without overwriting headers set by transform.
forwardable := httputil.FilterHeaders(r.Header, httputil.HopByHopHeaders...)
for key, values := range forwardable {
if upstreamReq.Header.Get(key) == "" {
upstreamReq.Header[key] = values
}
}

resp, err := h.client.Do(upstreamReq)
if err != nil {
h.errorHandler(httputil.Errorf(http.StatusBadGateway, "failed to fetch: %w", err), w, r)
Expand Down
73 changes: 73 additions & 0 deletions internal/strategy/handler/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,79 @@ func TestBuilder(t *testing.T) {
}
}

func TestHeaderForwarding(t *testing.T) {
var receivedHeaders http.Header
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeaders = r.Header.Clone()
_, _ = fmt.Fprint(w, "ok")
}))
defer upstream.Close()

c := mustNewMemoryCache()
ctx := logging.ContextWithLogger(context.Background(), slog.Default())

t.Run("ForwardsOriginalHeaders", func(t *testing.T) {
h := handler.New(http.DefaultClient, c).
CacheKey(func(_ *http.Request) string { return "fwd-test-1" }).
Transform(func(r *http.Request) (*http.Request, error) {
return http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/test", nil)
})
r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
r.Header.Set("X-Custom", "forwarded")
w := httptest.NewRecorder()
h.ServeHTTP(w, r)

assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "application/json", receivedHeaders.Get("Accept"))
assert.Equal(t, "forwarded", receivedHeaders.Get("X-Custom"))
})

t.Run("StripsHopByHopHeaders", func(t *testing.T) {
h := handler.New(http.DefaultClient, c).
CacheKey(func(_ *http.Request) string { return "fwd-test-2" }).
Transform(func(r *http.Request) (*http.Request, error) {
return http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/test", nil)
})
r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil)
r = r.WithContext(ctx)
r.Header.Set("Connection", "keep-alive")
r.Header.Set("Keep-Alive", "timeout=5")
r.Header.Set("Accept", "text/html")
w := httptest.NewRecorder()
h.ServeHTTP(w, r)

assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "text/html", receivedHeaders.Get("Accept"))
assert.Equal(t, "", receivedHeaders.Get("Connection"))
assert.Equal(t, "", receivedHeaders.Get("Keep-Alive"))
})

t.Run("TransformHeadersTakePrecedence", func(t *testing.T) {
h := handler.New(http.DefaultClient, c).
CacheKey(func(_ *http.Request) string { return "fwd-test-3" }).
Transform(func(r *http.Request) (*http.Request, error) {
req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, upstream.URL+"/test", nil)
if err != nil {
return nil, err
}
req.Header.Set("Authorization", "Bearer override")
return req, nil
})
r := httptest.NewRequest(http.MethodGet, "http://example.com/test", nil)
r = r.WithContext(ctx)
r.Header.Set("Authorization", "Bearer original")
r.Header.Set("X-Custom", "forwarded")
w := httptest.NewRecorder()
h.ServeHTTP(w, r)

assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "Bearer override", receivedHeaders.Get("Authorization"))
assert.Equal(t, "forwarded", receivedHeaders.Get("X-Custom"))
})
}

func TestHandlerMethodChaining(t *testing.T) {
c := mustNewMemoryCache()
client := &http.Client{}
Expand Down
34 changes: 22 additions & 12 deletions internal/strategy/host.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,18 @@ func RegisterHost(r *Registry) {
//
// In this example, the strategy will be mounted under "/github.com".
type HostConfig struct {
Target string `hcl:"target,label" help:"The target URL to proxy requests to."`
Target string `hcl:"target,label" help:"The target URL to proxy requests to."`
Headers map[string]string `hcl:"headers,optional" help:"Headers to add to upstream requests."`
}

// The Host [Strategy] forwards all GET requests to the specified host, caching the response payloads.
type Host struct {
target *url.URL
cache cache.Cache
client *http.Client
logger *slog.Logger
prefix string
target *url.URL
cache cache.Cache
client *http.Client
logger *slog.Logger
prefix string
headers map[string]string
}

var _ Strategy = (*Host)(nil)
Expand All @@ -48,11 +50,12 @@ func NewHost(ctx context.Context, config HostConfig, cache cache.Cache, mux Mux)
}
prefix := "/" + u.Host + u.EscapedPath()
h := &Host{
target: u,
cache: cache,
client: &http.Client{},
logger: logging.FromContext(ctx),
prefix: prefix,
target: u,
cache: cache,
client: &http.Client{},
logger: logging.FromContext(ctx),
prefix: prefix,
headers: config.Headers,
}

hdlr := handler.New(h.client, cache).
Expand All @@ -61,7 +64,14 @@ func NewHost(ctx context.Context, config HostConfig, cache cache.Cache, mux Mux)
}).
Transform(func(r *http.Request) (*http.Request, error) {
targetURL := h.buildTargetURL(r)
return http.NewRequestWithContext(r.Context(), http.MethodGet, targetURL.String(), nil)
req, err := http.NewRequestWithContext(r.Context(), http.MethodGet, targetURL.String(), nil)
if err != nil {
return nil, errors.Wrap(err, "creating upstream request")
}
for k, v := range h.headers {
req.Header.Set(k, v)
}
return req, nil
})

mux.Handle("GET "+prefix+"/", hdlr)
Expand Down
33 changes: 33 additions & 0 deletions internal/strategy/host_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,39 @@ func TestHostInvalidTargetURL(t *testing.T) {
assert.Error(t, err)
}

func TestHostHeaders(t *testing.T) {
var receivedHeaders http.Header
backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
receivedHeaders = r.Header.Clone()
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
defer backend.Close()

_, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError})
memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour})
assert.NoError(t, err)
defer memCache.Close()

mux := http.NewServeMux()
_, err = strategy.NewHost(ctx, strategy.HostConfig{
Target: backend.URL,
Headers: map[string]string{"Authorization": "Bearer QQ==", "X-Custom": "value"},
}, memCache, mux)
assert.NoError(t, err)

u, _ := url.Parse(backend.URL)
reqPath := "/" + u.Host + "/test"

req := httptest.NewRequestWithContext(ctx, http.MethodGet, reqPath, nil)
w := httptest.NewRecorder()
mux.ServeHTTP(w, req)

assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, "Bearer QQ==", receivedHeaders.Get("Authorization"))
assert.Equal(t, "value", receivedHeaders.Get("X-Custom"))
}

func TestHostString(t *testing.T) {
_, ctx := logging.Configure(context.Background(), logging.Config{Level: slog.LevelError})
memCache, err := cache.NewMemory(ctx, cache.MemoryConfig{MaxTTL: time.Hour})
Expand Down