commit eba33751ed99518edec5edead8e05da50f9c4797 Author: William Perron <hey@wperron.io> Date: Wed Feb 19 09:40:16 2025 -0500 initial commit diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..e6afc77 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 William Perron + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..0dfedd2 --- /dev/null +++ b/README.md @@ -0,0 +1,3 @@ +# HTTP Rate Limiter + +An HTTP Handler that implements rate limiting. diff --git a/cmd/example/main.go b/cmd/example/main.go new file mode 100644 index 0000000..31ca448 --- /dev/null +++ b/cmd/example/main.go @@ -0,0 +1,29 @@ +package main + +import ( + "fmt" + "log" + "net/http" + "time" + + "go.wperron.io/ratelimit" +) + +func main() { + hello := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintf(w, "Hello, World!") + }) + + // A simple Hello World handler wrapped in a RateLimitedHandler + l := ratelimit.NewLimiter(1_000, time.Second) + handler := &ratelimit.RateLimitedHandler{ + Next: hello, + Limiter: l, + } + + log.Printf("Starting server on :8080") + if err := http.ListenAndServe(":8080", handler); err != nil { + log.Println(err) + } + log.Println("Server stopped") +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..5db8c5f --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module go.wperron.io/ratelimit + +go 1.24.0 diff --git a/limiter.go b/limiter.go new file mode 100644 index 0000000..5b357eb --- /dev/null +++ b/limiter.go @@ -0,0 +1,104 @@ +package ratelimit + +import ( + "net/http" + "sync/atomic" + "time" +) + +type Limiter struct { + current atomic.Uint64 + last atomic.Uint64 + limit int // number of requests per evaluation window + reset atomic.Int64 // time of last reset in UnixMicro + windowMicros int64 +} + +func NewLimiter(limit int, window time.Duration) *Limiter { + l := &Limiter{ + limit: limit, + reset: atomic.Int64{}, + windowMicros: window.Microseconds(), + } + + go func() { + // If the current time in microseconds exceeds the last reset time by more + // than the evaluation window, attempt to reset the limiter. This is done by + // atomically comparing and swapping the reset time. If successful, store + // the current count in the last count and reset the current count to 0. + t := time.NewTicker(window) + for now := range t.C { + l.reset.Store(now.UnixMicro()) + l.last.Store(l.current.Swap(0)) + } + }() + + return l +} + +// Allow returns true if the current rate of requests is below the limit set by the +// Limiter. +// +// rate calculates the current rate of some process based on the last recorded rate, +// the time elapsed since the last reset, and the current value. It uses a weighted +// average approach where the weight is determined by the proportion of the time +// remaining in the current minute. +// +// The formula is: +// rate = (lastValue * (remainingTime / evaluationWindow)) + currentValue +// +// For example, imagine we set a limit of 50 requests per minute and the evaluation +// window is 60 seconds. We recorded 42 requests during the last evaluation window, +// and 18 requests have been recorded after 15 seconds in the current evaluation window. +// The rate would be calculated as follows: +// +// rate = 42 * ((60-15)/60) + 18 +// +// = 42 * 0.75 + 18 +// = 49.5 requests +// +// - lastValue: The count recorded during the last evaluation window. +// - remainingTime: The time remaining in the current evaluation window (in microseconds). +// - evalutationWindow: The duration of the evaluation window (in microseconds). +// - currentValue: The value of the current evaluation window. +// +// This approach ensures that the rate smoothly transitions from the last recorded +// rate to the current value as time progresses within the evaluation window. +// +// Credit to CloudFlare for the idea. +// see: https://blog.cloudflare.com/counting-things-a-lot-of-different-things/ +func (l *Limiter) Allow() bool { + last := float64(l.last.Load()) + current := float64(l.current.Load()) + nowMicros := time.Now().UnixMicro() + resetMicros := l.reset.Load() + elapsed := nowMicros - resetMicros + + rate := (last * (float64(l.windowMicros-(elapsed)) / float64(l.windowMicros))) + current + + return rate <= float64(l.limit) +} + +type RateLimitedHandler struct { + Next http.Handler + Limiter *Limiter +} + +// ServeHTTP implements the http.Handler interface for RateLimitedHandler. It +// checks the current rate of requests against the limit set by the Limiter. If +// the rate exceeds the limit, it returns a 429 Too Many Requests status code. +// Otherwise, it calls the Next handler in the chain. +func (rlh *RateLimitedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + l := rlh.Limiter + + if !l.Allow() { + http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests) + return + } + + // Asynchronously increment the current count to avoid blocking the request. + // This is done after the request has been allowed to proceed to ensure that + // the rate is calculated correctly. + rlh.Next.ServeHTTP(w, r) + go l.current.Add(1) +}