package http

import (
	"fmt"
	"io"
	stdhttp "net/http"
	"net/url"
	"strings"
)

var _ stdhttp.RoundTripper = (*FollowRedirect)(nil)

// FollowRedirect is an http.RoundTripper that follows the Location header of
// a 3xx (300-399) response by sending one follow-up request through the
// wrapped RoundTripper. Only a single hop is taken: the response to the
// follow-up request is returned as-is, even if it is itself a redirect.
//
// Where it matters, following mirrors net/http.Client:
//   - a relative Location is resolved against the request URL;
//   - 307 and 308 keep the method and re-send the body, which requires
//     Request.GetBody when the request has a body (otherwise the redirect
//     response is returned unfollowed);
//   - every other 3xx drops the body and turns non-GET/HEAD methods into GET;
//   - Authorization, WWW-Authenticate and Cookie headers are not forwarded to
//     a different host.
type FollowRedirect struct {
	next stdhttp.RoundTripper
}

// NewFollowRedirect returns a FollowRedirect that sends every request
// (original and redirected) through next. A nil next falls back to
// http.DefaultTransport.
func NewFollowRedirect(next stdhttp.RoundTripper) *FollowRedirect {
	if next == nil {
		next = stdhttp.DefaultTransport
	}
	return &FollowRedirect{next: next}
}

// RoundTrip implements http.RoundTripper.
//
// Non-3xx responses, and 3xx responses without a Location header, are
// returned unchanged. It returns an error if the wrapped RoundTripper fails,
// if the Location header is not a valid URL, or if a request body cannot be
// rewound for a 307/308 redirect. The caller's request is never modified.
func (f *FollowRedirect) RoundTrip(req *stdhttp.Request) (*stdhttp.Response, error) {
	res, err := f.next.RoundTrip(req)
	if err != nil {
		return nil, err
	}
	if res.StatusCode < 300 || res.StatusCode > 399 {
		return res, nil
	}
	loc := res.Header.Get("Location")
	if loc == "" {
		return res, nil
	}

	ref, err := url.Parse(loc)
	if err != nil {
		// The Location header comes from the server: report it, don't panic.
		discardResponseBody(res)
		return nil, fmt.Errorf("following redirect: invalid Location header %q: %w", loc, err)
	}
	// Location may be relative ("/next", "../x"); resolve it against the URL
	// that produced the redirect.
	target := req.URL.ResolveReference(ref)

	redirected, ok, err := newRedirectRequest(req, res.StatusCode, target)
	if err != nil {
		discardResponseBody(res)
		return nil, err
	}
	if !ok {
		// The body cannot be replayed; hand back the redirect itself rather
		// than sending a request with an already-consumed body.
		return res, nil
	}

	// The redirect response is superseded by the follow-up; release it so
	// its connection is not leaked.
	discardResponseBody(res)
	return f.next.RoundTrip(redirected)
}

// newRedirectRequest builds the follow-up request that redirects req to
// target for the given 3xx status. ok is false when the redirect must not be
// followed because req's body cannot be re-sent.
func newRedirectRequest(req *stdhttp.Request, status int, target *url.URL) (*stdhttp.Request, bool, error) {
	out := req.Clone(req.Context())
	out.URL = target

	// A Host set for the old server (NewRequest always sets it) must not be
	// sent to a different one; clearing it makes the transport use URL.Host.
	if target.Host != req.URL.Host {
		out.Host = ""
	}
	// Don't leak credentials to a host the server told us to go to.
	if !strings.EqualFold(target.Hostname(), req.URL.Hostname()) {
		for _, h := range []string{"Authorization", "Www-Authenticate", "Cookie", "Cookie2"} {
			out.Header.Del(h)
		}
	}

	switch status {
	case stdhttp.StatusTemporaryRedirect, stdhttp.StatusPermanentRedirect:
		// Method and body are preserved; the original body has been
		// consumed by the first round trip, so it must be rewound.
		if req.Body != nil && req.Body != stdhttp.NoBody {
			if req.GetBody == nil {
				return nil, false, nil
			}
			body, err := req.GetBody()
			if err != nil {
				return nil, false, fmt.Errorf("following redirect: rewinding request body: %w", err)
			}
			out.Body = body
		}
	default:
		// As net/http.Client does, drop the body and downgrade to GET.
		if out.Method != stdhttp.MethodGet && out.Method != stdhttp.MethodHead {
			out.Method = stdhttp.MethodGet
		}
		out.Body = nil
		out.GetBody = nil
		out.ContentLength = 0
	}
	return out, true, nil
}

// discardResponseBody reads a bounded prefix of res.Body and closes it so the
// underlying connection may be reused. Errors are ignored: the response is
// being abandoned either way.
func discardResponseBody(res *stdhttp.Response) {
	if res.Body == nil {
		return
	}
	const maxDrain = 2 << 10
	_, _ = io.CopyN(io.Discard, res.Body, maxDrain)
	_ = res.Body.Close()
}