You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
47 lines
1.0 KiB
47 lines
1.0 KiB
package http
|
|
|
|
import (
	"io"
	stdhttp "net/http"
	"net/url"
)
|
|
|
|
// Compile-time assertion that *FollowRedirect satisfies stdhttp.RoundTripper.
var _ stdhttp.RoundTripper = (*FollowRedirect)(nil)
|
|
|
|
// FollowRedirect is a RoundTripper that follows the Location header when the
// wrapped RoundTripper returns a response with a 300-399 status code.
//
// The zero value has no wrapped RoundTripper; construct it with
// NewFollowRedirect.
type FollowRedirect struct {
	// next performs the actual round trips, both for the original request
	// and for the redirected one.
	next stdhttp.RoundTripper
}
|
|
|
|
func NewFollowRedirect(next stdhttp.RoundTripper) *FollowRedirect {
|
|
return &FollowRedirect{
|
|
next: next,
|
|
}
|
|
}
|
|
|
|
// RoundTrip implements http.RoundTripper.
|
|
func (f *FollowRedirect) RoundTrip(req *stdhttp.Request) (*stdhttp.Response, error) {
|
|
res, err := f.next.RoundTrip(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if res.StatusCode >= 300 && res.StatusCode <= 399 {
|
|
loc := res.Header.Get("location")
|
|
if loc != "" {
|
|
u, err := url.Parse(loc)
|
|
if err != nil {
|
|
// If the location header can't be parsed as a URL, exit early
|
|
// and return the original response without attempting to follow
|
|
// the redirect.
|
|
return res, nil
|
|
}
|
|
clone := req.Clone(req.Context())
|
|
clone.URL = u
|
|
|
|
return f.next.RoundTrip(clone)
|
|
}
|
|
}
|
|
return res, nil
|
|
}
|