Skip to content

Commit a0000f0

Browse files
committed
Add redirect loop detection and history cleanup tests
1 parent d24c7f1 commit a0000f0

File tree

3 files changed

+197
-66
lines changed

3 files changed

+197
-66
lines changed

redirecthandler/redirecthandler.go

Lines changed: 127 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package redirecthandler
22

33
import (
4+
"fmt"
45
"net/http"
56
"net/url"
67
"sync"
@@ -10,52 +11,70 @@ import (
1011
"go.uber.org/zap"
1112
)
1213

13-
// RedirectHandler handles HTTP redirects within an http.Client.
14-
// It provides features such as redirect loop detection, security enhancements,
15-
// and integration with client settings for fine-grained control over redirect behavior.
14+
// RedirectHandler contains configurations for handling HTTP redirects.
1615
type RedirectHandler struct {
17-
Logger logger.Logger
18-
MaxRedirects int
19-
VisitedURLs map[string]int
20-
VisitedURLsMutex sync.Mutex
21-
SensitiveHeaders []string
16+
Logger logger.Logger // Logger instance for logging.
17+
MaxRedirects int // Maximum allowed redirects to prevent infinite loops.
18+
VisitedURLs map[string]int // Tracks visited URLs to detect loops.
19+
VisitedURLsMutex sync.RWMutex // Mutex for safe concurrent access to VisitedURLs.
20+
SensitiveHeaders []string // Headers to be removed on cross-domain redirects.
21+
PermanentRedirects map[string]string // Cache for permanent redirects
22+
PermRedirectsMutex sync.RWMutex // Mutex for safe concurrent access to PermanentRedirects
23+
RedirectHistories map[*http.Request][]*url.URL // Map to track redirect history for each request
2224
}
2325

24-
// NewRedirectHandler creates a new instance of RedirectHandler with the provided logger
25-
// and maximum number of redirects. It initializes internal structures and is ready to use.
26+
// NewRedirectHandler creates a new instance of RedirectHandler.
2627
func NewRedirectHandler(logger logger.Logger, maxRedirects int) *RedirectHandler {
2728
return &RedirectHandler{
28-
Logger: logger,
29-
MaxRedirects: maxRedirects,
30-
VisitedURLs: make(map[string]int),
31-
SensitiveHeaders: []string{"Authorization", "Cookie"}, // Add other sensitive headers if needed
29+
Logger: logger,
30+
MaxRedirects: maxRedirects,
31+
VisitedURLs: make(map[string]int),
32+
SensitiveHeaders: []string{"Authorization", "Cookie"},
33+
PermanentRedirects: make(map[string]string),
34+
RedirectHistories: make(map[*http.Request][]*url.URL),
3235
}
3336
}
3437

38+
// AddSensitiveHeader allows adding configurable sensitive headers.
39+
func (r *RedirectHandler) AddSensitiveHeader(header string) {
40+
r.SensitiveHeaders = append(r.SensitiveHeaders, header)
41+
}
42+
3543
// WithRedirectHandling applies the redirect handling policy to an http.Client.
36-
// It sets the CheckRedirect function on the client to use the handler's logic.
3744
func (r *RedirectHandler) WithRedirectHandling(client *http.Client) {
3845
client.CheckRedirect = r.checkRedirect
3946
}
4047

41-
// checkRedirect is the core function that implements the redirect handling logic.
42-
// It is set as the CheckRedirect function on an http.Client and is called whenever
43-
// the client encounters a 3XX response. It enforces the max redirects limit,
44-
// detects redirect loops, applies security measures for cross-domain redirects,
45-
// resolves relative redirects, and optimizes performance.
48+
// checkRedirect implements the redirect handling logic.
4649
func (r *RedirectHandler) checkRedirect(req *http.Request, via []*http.Request) error {
47-
// Redirect Loop Detection
48-
r.VisitedURLsMutex.Lock()
49-
defer r.VisitedURLsMutex.Unlock()
50-
if _, exists := r.VisitedURLs[req.URL.String()]; exists {
51-
r.Logger.Warn("Detected redirect loop", zap.String("url", req.URL.String()))
52-
return http.ErrUseLastResponse
50+
defer r.clearRedirectHistory(req) // Ensure redirect history is always cleared to prevent memory leaks
51+
52+
// Check for cached permanent redirect
53+
if urlString, ok := r.checkPermanentRedirect(req.URL.String()); ok && (req.Method == http.MethodGet || req.Method == http.MethodHead) {
54+
parsedURL, err := url.Parse(urlString)
55+
if err != nil {
56+
r.Logger.Error("Failed to parse URL from cache", zap.String("url", urlString), zap.Error(err))
57+
// Continue with the original URL since the cached URL is invalid
58+
} else {
59+
req.URL = parsedURL // Use cached redirect location
60+
r.Logger.Info("Using cached permanent redirect", zap.String("originalURL", urlString), zap.String("redirectURL", parsedURL.String()))
61+
return nil
62+
}
5363
}
54-
r.VisitedURLs[req.URL.String()]++
5564

65+
// Track redirect history for the current request
66+
r.RedirectHistories[req] = append(r.RedirectHistories[req], req.URL)
67+
68+
// Check for redirect loops by analyzing the history
69+
if hasLoop(r.RedirectHistories[req]) {
70+
r.Logger.Error("Redirect loop detected", zap.Any("redirectHistory", r.RedirectHistories[req]))
71+
return fmt.Errorf("redirect loop detected: %v", r.RedirectHistories[req])
72+
}
73+
74+
// Enforce max redirects
5675
if len(via) >= r.MaxRedirects {
57-
r.Logger.Warn("Stopped after maximum redirects", zap.Int("maxRedirects", r.MaxRedirects))
58-
return http.ErrUseLastResponse
76+
r.Logger.Warn("Maximum redirects reached", zap.Int("maxRedirects", r.MaxRedirects))
77+
return &MaxRedirectsError{MaxRedirects: r.MaxRedirects}
5978
}
6079

6180
lastResponse := via[len(via)-1].Response
@@ -66,69 +85,111 @@ func (r *RedirectHandler) checkRedirect(req *http.Request, via []*http.Request)
6685
return err
6786
}
6887

69-
// Resolve relative redirects against the current request URL
7088
newReqURL, err := r.resolveRedirectURL(req.URL, location)
7189
if err != nil {
7290
r.Logger.Error("Failed to resolve redirect URL", zap.Error(err))
7391
return err
7492
}
7593

76-
// Security Measures
94+
// Apply security measures for cross-domain redirects
7795
if newReqURL.Host != req.URL.Host {
7896
r.secureRequest(req)
7997
}
8098

81-
// Handling 303 See Other
82-
if lastResponse.StatusCode == http.StatusSeeOther {
83-
req.Method = http.MethodGet
84-
req.Body = nil
85-
req.GetBody = nil
86-
req.ContentLength = 0
87-
req.Header.Del("Content-Type")
88-
r.Logger.Info("Changed request method to GET for 303 See Other response")
99+
// Cache permanent redirects
100+
if status.IsPermanentRedirect(lastResponse.StatusCode) {
101+
r.cachePermanentRedirect(req.URL.String(), newReqURL.String())
89102
}
90103

91-
// Logging enhancements
92-
r.Logger.Info("Redirecting request",
93-
zap.String("originalURL", req.URL.String()),
94-
zap.String("newURL", newReqURL.String()),
95-
zap.String("method", req.Method),
96-
zap.Int("redirectCount", len(via)),
97-
)
98-
99-
// Log removed sensitive headers
100-
for _, header := range r.SensitiveHeaders {
101-
r.Logger.Info("Removed sensitive header due to domain change",
102-
zap.String("header", header),
103-
)
104+
// Special handling for 303 See Other
105+
if lastResponse.StatusCode == http.StatusSeeOther {
106+
r.adjustForSeeOther(req)
104107
}
105108

106-
req.URL = newReqURL
109+
r.Logger.Info("Redirecting request", zap.String("originalURL", req.URL.String()), zap.String("newURL", newReqURL.String()), zap.Int("redirectCount", len(via)))
110+
req.URL = newReqURL // Update request URL to follow the redirect
107111
return nil
108112
}
109113

110-
return http.ErrUseLastResponse
114+
return http.ErrUseLastResponse // No further action required if not a redirect status code
111115
}
112116

113-
// resolveRedirectURL resolves the redirect location URL against the current request URL
114-
// to handle relative redirects accurately.
117+
// resolveRedirectURL resolves the redirect location URL against the current request URL.
115118
func (r *RedirectHandler) resolveRedirectURL(reqURL *url.URL, redirectURL *url.URL) (*url.URL, error) {
116-
if redirectURL.IsAbs() {
117-
return redirectURL, nil // Absolute URL, no need to resolve
119+
if !redirectURL.IsAbs() {
120+
redirectURL.Scheme = reqURL.Scheme // Preserve the scheme
118121
}
119-
120-
// Relative URL, resolve against the current request URL
121-
absoluteURL := *reqURL
122-
absoluteURL.Path = redirectURL.Path
123-
absoluteURL.RawQuery = redirectURL.RawQuery
124-
absoluteURL.Fragment = redirectURL.Fragment
125-
return &absoluteURL, nil
122+
return redirectURL, nil
126123
}
127124

128125
// secureRequest removes sensitive headers from the request if the new destination is a different domain.
129126
func (r *RedirectHandler) secureRequest(req *http.Request) {
130127
for _, header := range r.SensitiveHeaders {
131128
req.Header.Del(header)
132-
r.Logger.Info("Removed sensitive header due to domain change", zap.String("header", header))
133129
}
134130
}
131+
132+
// adjustForSeeOther adjusts the request for "303 See Other" responses.
133+
func (r *RedirectHandler) adjustForSeeOther(req *http.Request) {
134+
req.Method = http.MethodGet
135+
req.Body = nil
136+
req.GetBody = nil
137+
req.ContentLength = 0
138+
req.Header.Del("Content-Type")
139+
}
140+
141+
// RedirectLoopError represents an error when a redirect loop is detected.
142+
type RedirectLoopError struct {
143+
URL string
144+
}
145+
146+
// RedirectLoopError defines an error for when a redirect loop is detected.
147+
func (e *RedirectLoopError) Error() string {
148+
return fmt.Sprintf("redirect loop detected at %s", e.URL)
149+
}
150+
151+
// MaxRedirectsError represents an error when the maximum number of redirects is reached.
152+
type MaxRedirectsError struct {
153+
MaxRedirects int
154+
}
155+
156+
// MaxRedirectsError defines an error for when the maximum number of redirects is reached.
157+
func (e *MaxRedirectsError) Error() string {
158+
return fmt.Sprintf("maximum redirects reached: %d", e.MaxRedirects)
159+
}
160+
161+
// cachePermanentRedirect caches the permanent redirect location.
162+
func (r *RedirectHandler) cachePermanentRedirect(originalURL, redirectURL string) {
163+
r.PermRedirectsMutex.Lock()
164+
defer r.PermRedirectsMutex.Unlock()
165+
166+
r.PermanentRedirects[originalURL] = redirectURL
167+
}
168+
169+
// checkPermanentRedirect checks if there's a cached redirect for the given URL.
170+
func (r *RedirectHandler) checkPermanentRedirect(originalURL string) (string, bool) {
171+
r.PermRedirectsMutex.RLock()
172+
defer r.PermRedirectsMutex.RUnlock()
173+
174+
url, exists := r.PermanentRedirects[originalURL]
175+
return url, exists
176+
}
177+
178+
// hasLoop checks if there's a loop in the redirect history.
179+
func hasLoop(history []*url.URL) bool {
180+
urlSet := make(map[string]struct{})
181+
for _, url := range history {
182+
if _, exists := urlSet[url.String()]; exists {
183+
return true // Loop detected
184+
}
185+
urlSet[url.String()] = struct{}{}
186+
}
187+
return false
188+
}
189+
190+
// clearRedirectHistory clears the redirect history for a given request to prevent memory leaks.
191+
func (r *RedirectHandler) clearRedirectHistory(req *http.Request) {
192+
r.VisitedURLsMutex.Lock() // Use the appropriate mutex to synchronize access to RedirectHistories
193+
delete(r.RedirectHistories, req)
194+
r.VisitedURLsMutex.Unlock()
195+
}

redirecthandler/redirecthandler_test.go

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,62 @@ func TestRedirectHandler_SecureRequest(t *testing.T) {
127127
assert.Contains(t, mockLogger.Calls[0].Arguments.String(0), "Removed sensitive header")
128128
})
129129
}
130+
131+
// Test for Redirect Loop Detection - This test ensures that the redirect handler correctly identifies and stops redirect loops.
132+
func TestRedirectLoopDetection(t *testing.T) {
133+
// Setup
134+
mockLogger := mocklogger.NewMockLogger()
135+
handler := NewRedirectHandler(mockLogger, 5)
136+
loopURL, _ := url.Parse("http://example.com/loop")
137+
req := &http.Request{URL: loopURL}
138+
139+
// Simulate a redirect loop by adding the same URL to the history multiple times
140+
handler.RedirectHistories[req] = []*url.URL{loopURL, loopURL}
141+
142+
// Test
143+
err := handler.checkRedirect(req, []*http.Request{req, req})
144+
145+
// Assertions
146+
assert.NotNil(t, err)
147+
assert.Contains(t, err.Error(), "redirect loop detected")
148+
// Verify log message for loop detection
149+
assert.Contains(t, mockLogger.Calls[0].Arguments.String(0), "Redirect loop detected")
150+
}
151+
152+
// TestRedirectHistoryCleanup - This test ensures that the redirect history for each request is properly cleaned up to prevent memory leaks.
153+
func TestRedirectHistoryCleanup(t *testing.T) {
154+
// Setup
155+
mockLogger := mocklogger.NewMockLogger()
156+
handler := NewRedirectHandler(mockLogger, 5)
157+
req := &http.Request{URL: &url.URL{Path: "/test"}}
158+
159+
// Simulate adding some history
160+
handler.RedirectHistories[req] = []*url.URL{{Path: "/redirect1"}, {Path: "/redirect2"}}
161+
162+
// Perform a redirect that will trigger the cleanup
163+
handler.checkRedirect(req, []*http.Request{req})
164+
165+
// Assertions
166+
_, exists := handler.RedirectHistories[req]
167+
assert.False(t, exists)
168+
}
169+
170+
// TestMaxRedirectsReached - This test checks that the handler stops redirects after reaching the maximum limit.
171+
func TestMaxRedirectsReached(t *testing.T) {
172+
// Setup
173+
mockLogger := mocklogger.NewMockLogger()
174+
handler := NewRedirectHandler(mockLogger, 1) // Set max redirects to 1
175+
req := &http.Request{URL: &url.URL{Path: "/start"}}
176+
via := []*http.Request{{}, {}} // Simulate one redirect has already occurred
177+
178+
// Test
179+
err := handler.checkRedirect(req, via)
180+
181+
// Assertions
182+
assert.NotNil(t, err)
183+
assert.IsType(t, &MaxRedirectsError{}, err)
184+
maxRedirectsError := err.(*MaxRedirectsError)
185+
assert.Equal(t, 1, maxRedirectsError.MaxRedirects)
186+
// Verify log message for max redirects reached
187+
assert.Contains(t, mockLogger.Calls[0].Arguments.String(0), "Maximum redirects reached")
188+
}

status/status.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,17 @@ func IsRedirectStatusCode(statusCode int) bool {
9696
}
9797
}
9898

99+
// IsPermanentRedirect checks if the provided HTTP status code is one of the permanent redirect codes.
100+
func IsPermanentRedirect(statusCode int) bool {
101+
switch statusCode {
102+
case http.StatusMovedPermanently, // 301
103+
http.StatusPermanentRedirect: // 308
104+
return true
105+
default:
106+
return false
107+
}
108+
}
109+
99110
// IsNonRetryableStatusCode checks if the provided response indicates a non-retryable error.
100111
func IsNonRetryableStatusCode(resp *http.Response) bool {
101112
// Expanded list of non-retryable HTTP status codes

0 commit comments

Comments
 (0)