package http

import (
	"io"
	stdhttp "net/http"
	"net/url"
)

// Compile-time check that *FollowRedirect satisfies http.RoundTripper.
var _ stdhttp.RoundTripper = (*FollowRedirect)(nil)

// FollowRedirect is a RoundTripper that follows the `Location` header of a
// response whose status code is in the 300-399 range. It follows at most one
// redirect per RoundTrip call; the second response is returned as-is even if
// it is itself a redirect.
type FollowRedirect struct {
	// next performs the actual HTTP exchanges (original and redirected).
	next stdhttp.RoundTripper
}

// NewFollowRedirect returns a FollowRedirect that sends both the original
// request and any followed redirect through next.
func NewFollowRedirect(next stdhttp.RoundTripper) *FollowRedirect {
	return &FollowRedirect{
		next: next,
	}
}

// RoundTrip implements http.RoundTripper.
//
// The request is sent through the wrapped RoundTripper. If the response has a
// 3xx status and a non-empty, parsable Location header, a clone of the
// request is re-sent to that location and the result of that second exchange
// is returned. In every other case the first response is returned unchanged.
//
// A relative Location is resolved against the original request URL. When the
// original request can replay its body (GetBody is set), the redirected
// request gets a fresh body. The body of the intermediate redirect response
// is drained (bounded) and closed before the redirect is followed.
func (f *FollowRedirect) RoundTrip(req *stdhttp.Request) (*stdhttp.Response, error) {
	res, err := f.next.RoundTrip(req)
	if err != nil {
		return nil, err
	}
	if res.StatusCode < 300 || res.StatusCode > 399 {
		return res, nil
	}

	loc := res.Header.Get("location")
	if loc == "" {
		return res, nil
	}

	u, err := url.Parse(loc)
	if err != nil {
		// If the location header can't be parsed as a URL, exit early
		// and return the original response without attempting to follow
		// the redirect.
		return res, nil
	}
	// Location may be a relative reference (e.g. "/login" or "//host/p");
	// resolve it against the URL that produced the redirect. Absolute
	// locations are used exactly as parsed.
	if !u.IsAbs() && req.URL != nil {
		u = req.URL.ResolveReference(u)
	}

	clone := req.Clone(req.Context())
	clone.URL = u

	// The first exchange consumed req.Body, and Clone shares that same body.
	// Obtain a fresh copy when the request knows how to produce one.
	if req.GetBody != nil {
		body, err := req.GetBody()
		if err != nil {
			discardResponseBody(res)
			return nil, err
		}
		clone.Body = body
	}

	// The redirect response is dropped from here on, so its body must be
	// released or the underlying connection leaks.
	discardResponseBody(res)

	return f.next.RoundTrip(clone)
}

// discardResponseBody drains a bounded amount of res.Body and closes it so the
// underlying connection can be reused by the transport.
func discardResponseBody(res *stdhttp.Response) {
	if res == nil || res.Body == nil {
		return
	}
	// Redirect bodies are normally tiny; cap the drain so a hostile or
	// misbehaving server cannot make us read an unbounded stream.
	const maxDrain = 2 << 10
	// Errors are deliberately ignored: the body is being thrown away and
	// there is nothing useful the caller could do with a drain/close failure.
	_, _ = io.CopyN(io.Discard, res.Body, maxDrain)
	_ = res.Body.Close()
}