package ratelimit

import (
	"net/http"
	"sync/atomic"
	"time"
)

type Limiter struct {
	current      atomic.Uint64 // requests observed in the current evaluation window
	last         atomic.Uint64 // requests observed in the previous evaluation window
	limit        int           // number of requests per evaluation window
	reset        atomic.Int64  // time of last reset in UnixMicro
	windowMicros int64         // length of the evaluation window in microseconds
}

func NewLimiter(limit int, window time.Duration) *Limiter {
	l := &Limiter{
		limit:        limit,
		windowMicros: window.Microseconds(),
	}
	l.reset.Store(time.Now().UnixMicro())

	go func() {
		// Once per window, roll the counters over: record the tick time as the
		// new reset point, move the current count into the last count, and
		// zero the current count for the next window.
		t := time.NewTicker(window)
		for now := range t.C {
			l.reset.Store(now.UnixMicro())
			l.last.Store(l.current.Swap(0))
		}
	}()

	return l
}

// Allow reports whether the current estimated rate of requests is within the
// limit set by the Limiter.
//
// The rate is estimated from the count recorded during the last evaluation
// window, the time elapsed since the last reset, and the current count. It is
// a weighted average in which the weight on the last window's count is the
// proportion of time remaining in the current evaluation window.
//
// The formula is:
//
//	rate = (lastValue * (remainingTime / evaluationWindow)) + currentValue
//
// For example, imagine we set a limit of 50 requests per minute and the
// evaluation window is 60 seconds. We recorded 42 requests during the last
// evaluation window, and 18 requests have been recorded after 15 seconds in
// the current evaluation window. The rate would be calculated as follows:
//
//	rate = 42 * ((60-15)/60) + 18
//	     = 42 * 0.75 + 18
//	     = 49.5 requests
//
//   - lastValue: the count recorded during the last evaluation window.
//   - remainingTime: the time remaining in the current evaluation window (in microseconds).
//   - evaluationWindow: the duration of the evaluation window (in microseconds).
//   - currentValue: the count recorded so far in the current evaluation window.
//
// This approach ensures that the rate smoothly transitions from the last
// recorded count to the current count as time progresses within the
// evaluation window.
//
// Credit to Cloudflare for the idea.
// See: https://blog.cloudflare.com/counting-things-a-lot-of-different-things/
func (l *Limiter) Allow() bool {
	last := float64(l.last.Load())
	current := float64(l.current.Load())
	nowMicros := time.Now().UnixMicro()
	resetMicros := l.reset.Load()
	elapsed := nowMicros - resetMicros

	rate := last*(float64(l.windowMicros-elapsed)/float64(l.windowMicros)) + current

	return rate <= float64(l.limit)
}

type RateLimitedHandler struct {
	Next    http.Handler
	Limiter *Limiter
}

// ServeHTTP implements the http.Handler interface for RateLimitedHandler. It
// checks the current rate of requests against the limit set by the Limiter. If
// the rate exceeds the limit, it responds with a 429 Too Many Requests status
// code. Otherwise, it records the request and calls the Next handler in the
// chain.
func (rlh *RateLimitedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	l := rlh.Limiter

	if !l.Allow() {
		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
		return
	}

	// Count the request as soon as it has been allowed to proceed so the rate
	// reflects in-flight requests; the atomic increment does not block.
	l.current.Add(1)
	rlh.Next.ServeHTTP(w, r)
}
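
// A minimal usage sketch from a caller's perspective (not part of the package
// API; the route, handler body, and port below are illustrative): wrap any
// http.Handler in a RateLimitedHandler backed by a Limiter, here allowing
// roughly 50 requests per minute.
//
//	limiter := ratelimit.NewLimiter(50, time.Minute)
//	mux := http.NewServeMux()
//	mux.Handle("/", &ratelimit.RateLimitedHandler{
//		Next: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
//			w.Write([]byte("hello"))
//		}),
//		Limiter: limiter,
//	})
//	_ = http.ListenAndServe(":8080", mux)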