You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
47 lines
986 B
47 lines
986 B
2 years ago
|
package http
|
||
|
|
||
|
import (
|
||
|
"fmt"
|
||
|
stdhttp "net/http"
|
||
|
"net/url"
|
||
|
)
|
||
|
|
||
|
// Compile-time check that *FollowRedirect satisfies http.RoundTripper.
var _ stdhttp.RoundTripper = (*FollowRedirect)(nil)
|
||
|
|
||
|
// FollowRedirect is an http.RoundTripper that follows the Location header of
// responses whose status code is in the 300-399 range.
type FollowRedirect struct {
	// next is the wrapped RoundTripper that performs the actual requests.
	next stdhttp.RoundTripper
}
|
||
|
|
||
|
func NewFollowRedirect(next stdhttp.RoundTripper) *FollowRedirect {
|
||
|
return &FollowRedirect{
|
||
|
next: next,
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// RoundTrip implements http.RoundTripper.
|
||
|
func (f *FollowRedirect) RoundTrip(req *stdhttp.Request) (*stdhttp.Response, error) {
|
||
|
res, err := f.next.RoundTrip(req)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
if res.StatusCode >= 300 && res.StatusCode <= 399 {
|
||
|
loc := res.Header.Get("location")
|
||
|
|
||
|
fmt.Println("==> got a redirect, following...")
|
||
|
if loc != "" {
|
||
|
u, err := url.Parse(loc)
|
||
|
if err != nil {
|
||
|
panic("got an invalid URL from the `location` header")
|
||
|
}
|
||
|
clone := req.Clone(req.Context())
|
||
|
clone.URL = u
|
||
|
|
||
|
return f.next.RoundTrip(clone)
|
||
|
}
|
||
|
}
|
||
|
return res, nil
|
||
|
}
|