Skip to content

Commit 49ec251

Browse files
committed
Add handling for non-idempotent methods in redirect handler
1 parent a0000f0 commit 49ec251

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

redirecthandler/redirecthandler.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,13 @@ func (r *RedirectHandler) WithRedirectHandling(client *http.Client) {
4949
func (r *RedirectHandler) checkRedirect(req *http.Request, via []*http.Request) error {
5050
defer r.clearRedirectHistory(req) // Ensure redirect history is always cleared to prevent memory leaks
5151

52+
// Non-idempotent methods handling
53+
if req.Method == http.MethodPost || req.Method == http.MethodPatch {
54+
r.Logger.Warn("Redirect attempted on non-idempotent method, not following", zap.String("method", req.Method))
55+
// Stop redirection and return the response as is
56+
return http.ErrUseLastResponse
57+
}
58+
5259
// Check for cached permanent redirect
5360
if urlString, ok := r.checkPermanentRedirect(req.URL.String()); ok && (req.Method == http.MethodGet || req.Method == http.MethodHead) {
5461
parsedURL, err := url.Parse(urlString)
@@ -111,6 +118,13 @@ func (r *RedirectHandler) checkRedirect(req *http.Request, via []*http.Request)
111118
return nil
112119
}
113120

121+
// Clear redirect history if redirect is successful
122+
if len(via) > 0 && lastResponse.StatusCode >= 200 && lastResponse.StatusCode < 400 {
123+
// Clear history for the redirected request
124+
redirectedReq := via[len(via)-1]
125+
r.clearRedirectHistory(redirectedReq)
126+
}
127+
114128
return http.ErrUseLastResponse // No further action required if not a redirect status code
115129
}
116130

@@ -193,3 +207,11 @@ func (r *RedirectHandler) clearRedirectHistory(req *http.Request) {
193207
delete(r.RedirectHistories, req)
194208
r.VisitedURLsMutex.Unlock()
195209
}
210+
211+
// GetRedirectHistory returns the redirect history for a given request.
212+
func (r *RedirectHandler) GetRedirectHistory(req *http.Request) []*url.URL {
213+
r.VisitedURLsMutex.RLock()
214+
defer r.VisitedURLsMutex.RUnlock()
215+
216+
return r.RedirectHistories[req]
217+
}

0 commit comments

Comments
 (0)