From 6b6ecff1fd48be5cb5d917cef59e51661523a3a8 Mon Sep 17 00:00:00 2001 From: William Perron Date: Wed, 15 Feb 2023 15:45:42 -0500 Subject: [PATCH] initial commit --- .gitignore | 0 example/http_client/main.go | 15 ++++++++++++ go.mod | 3 +++ http/round_tripper.go | 46 +++++++++++++++++++++++++++++++++++++ 4 files changed, 64 insertions(+) create mode 100644 .gitignore create mode 100644 example/http_client/main.go create mode 100644 go.mod create mode 100644 http/round_tripper.go diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e69de29 diff --git a/example/http_client/main.go b/example/http_client/main.go new file mode 100644 index 0000000..31a9a69 --- /dev/null +++ b/example/http_client/main.go @@ -0,0 +1,15 @@ +package main + +import ( + "fmt" + stdhttp "net/http" + + "go.wperron.io/toolkit/http" +) + +func main() { + stdclient := stdhttp.Client{} + stdclient.Transport = http.NewFollowRedirect(stdhttp.DefaultTransport) + + fmt.Println(stdclient.Get("https://google.com")) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..7d4f557 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module go.wperron.io/toolkit + +go 1.20 diff --git a/http/round_tripper.go b/http/round_tripper.go new file mode 100644 index 0000000..7845af2 --- /dev/null +++ b/http/round_tripper.go @@ -0,0 +1,46 @@ +package http + +import ( + "fmt" + stdhttp "net/http" + "net/url" +) + +var _ stdhttp.RoundTripper = &FollowRedirect{} + +// FollowRedirect implements RoundTripper will follow the `location` header in +// the case of a 300-399 status code +type FollowRedirect struct { + next stdhttp.RoundTripper +} + +func NewFollowRedirect(next stdhttp.RoundTripper) *FollowRedirect { + return &FollowRedirect{ + next: next, + } +} + +// RoundTrip implements http.RoundTripper. +func (f *FollowRedirect) RoundTrip(req *stdhttp.Request) (*stdhttp.Response, error) { + res, err := f.next.RoundTrip(req) + if err != nil { + return nil, err + } + + if res.StatusCode >= 300 && res.StatusCode <= 399 { + loc := res.Header.Get("location") + + fmt.Println("==> got a redirect, following...") + if loc != "" { + u, err := url.Parse(loc) + if err != nil { + panic("got an invalid URL from the `location` header") + } + clone := req.Clone(req.Context()) + clone.URL = u + + return f.next.RoundTrip(clone) + } + } + return res, nil +}