commit
6b6ecff1fd
@ -0,0 +1,15 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
stdhttp "net/http"
|
||||
|
||||
"go.wperron.io/toolkit/http"
|
||||
)
|
||||
|
||||
func main() {
|
||||
stdclient := stdhttp.Client{}
|
||||
stdclient.Transport = http.NewFollowRedirect(stdhttp.DefaultTransport)
|
||||
|
||||
fmt.Println(stdclient.Get("https://google.com"))
|
||||
}
|
@ -0,0 +1,46 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
stdhttp "net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
var _ stdhttp.RoundTripper = &FollowRedirect{}
|
||||
|
||||
// FollowRedirect implements RoundTripper will follow the `location` header in
|
||||
// the case of a 300-399 status code
|
||||
type FollowRedirect struct {
|
||||
next stdhttp.RoundTripper
|
||||
}
|
||||
|
||||
func NewFollowRedirect(next stdhttp.RoundTripper) *FollowRedirect {
|
||||
return &FollowRedirect{
|
||||
next: next,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip implements http.RoundTripper.
|
||||
func (f *FollowRedirect) RoundTrip(req *stdhttp.Request) (*stdhttp.Response, error) {
|
||||
res, err := f.next.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res.StatusCode >= 300 && res.StatusCode <= 399 {
|
||||
loc := res.Header.Get("location")
|
||||
|
||||
fmt.Println("==> got a redirect, following...")
|
||||
if loc != "" {
|
||||
u, err := url.Parse(loc)
|
||||
if err != nil {
|
||||
panic("got an invalid URL from the `location` header")
|
||||
}
|
||||
clone := req.Clone(req.Context())
|
||||
clone.URL = u
|
||||
|
||||
return f.next.RoundTrip(clone)
|
||||
}
|
||||
}
|
||||
return res, nil
|
||||
}
|
Loading…
Reference in new issue