ratelimit/limiter.go

105 lines
3.5 KiB

package ratelimit
import (
"net/http"
"sync/atomic"
"time"
)
// Limiter is an approximate sliding-window rate limiter. It counts requests in
// the current fixed window, remembers the count from the previous window, and
// estimates the rate over the trailing window by weighting the previous count
// by how much of it still overlaps that trailing window (see Allow).
//
// Use NewLimiter to create a Limiter; the zero value never rolls its counters
// over to a new window.
type Limiter struct {
	current      atomic.Uint64 // requests counted in the current window
	last         atomic.Uint64 // requests counted in the previous window
	limit        int           // number of requests per evaluation window
	reset        atomic.Int64  // time of last reset in UnixMicro
	windowMicros int64         // evaluation window length in microseconds

	done    chan struct{} // closed by Stop to end the rollover goroutine
	stopped atomic.Bool   // set by the first Stop call so done is closed once
}

// NewLimiter returns a Limiter that allows up to limit requests per window.
//
// It starts a background goroutine that, on every tick of window, records the
// tick time as the start of a new window, moves the current count into the
// previous-window count, and zeroes the current count. The goroutine and its
// ticker run until Stop is called.
//
// NewLimiter panics if window is not positive (a requirement of
// time.NewTicker).
func NewLimiter(limit int, window time.Duration) *Limiter {
	// Create the ticker here rather than inside the goroutine so an invalid
	// window panics in the caller, where it is visible and recoverable, instead
	// of crashing the whole process from a background goroutine.
	ticker := time.NewTicker(window)

	l := &Limiter{
		limit:        limit,
		windowMicros: window.Microseconds(),
		done:         make(chan struct{}),
	}
	// The first evaluation window starts now.
	l.reset.Store(time.Now().UnixMicro())

	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-l.done:
				return
			case now := <-ticker.C:
				// Roll over to a new window. These stores are not atomic as a
				// group, so an Allow call racing with the rollover may briefly
				// observe a mix of old and new values; the estimate is
				// approximate by design.
				l.reset.Store(now.UnixMicro())
				l.last.Store(l.current.Swap(0))
			}
		}
	}()
	return l
}

// Stop ends the background rollover goroutine started by NewLimiter and
// releases its ticker. After Stop the counters are no longer rolled over, so
// the Limiter should not be used to admit further requests. Stop may be called
// more than once and from multiple goroutines; calls after the first, and calls
// on a Limiter not created by NewLimiter, do nothing.
func (l *Limiter) Stop() {
	if l.done != nil && l.stopped.CompareAndSwap(false, true) {
		close(l.done)
	}
}

// Allow reports whether the estimated request rate over the trailing
// evaluation window is at or below the Limiter's limit. It only reads the
// counters and does not record the request; callers that admit the request
// must count it separately (RateLimitedHandler increments the current count).
//
// The estimate is a sliding-window approximation: the previous window's count
// is weighted by the fraction of the evaluation window that has not yet
// elapsed since the last reset, and the current window's count is added:
//
//	rate = lastValue * (remainingTime / evaluationWindow) + currentValue
//
// For example, with a limit of 50 requests per 60-second window, 42 requests
// recorded in the previous window, and 18 requests recorded 15 seconds into
// the current window:
//
//	rate = 42 * ((60-15)/60) + 18
//	     = 42 * 0.75 + 18
//	     = 49.5 requests
//
// which is within the limit.
//
//   - lastValue: the count recorded during the previous evaluation window.
//   - remainingTime: the time remaining in the current evaluation window.
//   - evaluationWindow: the length of the evaluation window.
//   - currentValue: the count recorded so far in the current window.
//
// The weight remainingTime/evaluationWindow is clamped to [0, 1], so a delayed
// rollover or a wall-clock jump can neither subtract the previous count from
// the estimate nor count it more than once. If the window is shorter than one
// microsecond, the previous count is ignored.
//
// Credit to Cloudflare for the idea.
// See: https://blog.cloudflare.com/counting-things-a-lot-of-different-things/
func (l *Limiter) Allow() bool {
	last := float64(l.last.Load())
	current := float64(l.current.Load())
	elapsed := time.Now().UnixMicro() - l.reset.Load()

	weight := 0.0
	if l.windowMicros > 0 {
		weight = float64(l.windowMicros-elapsed) / float64(l.windowMicros)
		// Unclamped, elapsed > window (late tick, forward clock jump) yields a
		// negative weight that lets traffic far above the limit through, and
		// elapsed < 0 (backward clock jump) over-counts the previous window.
		switch {
		case weight < 0:
			weight = 0
		case weight > 1:
			weight = 1
		}
	}

	rate := last*weight + current
	return rate <= float64(l.limit)
}
// RateLimitedHandler is an http.Handler that rejects requests while Limiter
// reports that the request rate is over its limit, and otherwise passes them
// on to Next.
type RateLimitedHandler struct {
	Next    http.Handler
	Limiter *Limiter
}

// ServeHTTP implements the http.Handler interface for RateLimitedHandler. It
// checks the current rate of requests against the limit set by the Limiter.
// If the rate exceeds the limit, it responds with 429 Too Many Requests and
// does not count the request. Otherwise, it counts the request and calls the
// Next handler in the chain.
//
// The check and the increment are separate atomic operations, so concurrent
// requests that pass the check together can overshoot the limit slightly.
func (rlh *RateLimitedHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	l := rlh.Limiter
	if !l.Allow() {
		http.Error(w, "Rate limit exceeded", http.StatusTooManyRequests)
		return
	}
	// Count the request before handing it to Next. An atomic add never blocks,
	// so there is nothing to gain from deferring it to a goroutine. Counting up
	// front also means requests that are still in flight, or that panic in
	// Next, are visible to concurrent Allow checks.
	l.current.Add(1)
	rlh.Next.ServeHTTP(w, r)
}