commit ef64204f780bc7fc52be17e09d78cd9624da81dc Author: jdl Date: Fri Oct 13 13:14:40 2023 +0200 Initial commit diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..157bbe8 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.crumpington.com/public/ratelimiter + +go 1.21.1 diff --git a/ratelimiter.go b/ratelimiter.go new file mode 100644 index 0000000..f17b473 --- /dev/null +++ b/ratelimiter.go @@ -0,0 +1,86 @@ +package ratelimiter + +import ( + "errors" + "sync" + "time" +) + +var ErrBackoff = errors.New("Backoff") + +type Config struct { + BurstLimit int64 // Number of requests to allow to burst. + FillPeriod time.Duration // Add one per period. + MaxWaitCount int64 // Max number of waiting requests. 0 disables. +} + +type Limiter struct { + lock sync.Mutex + + fillPeriod time.Duration + minWaitTime time.Duration + maxWaitTime time.Duration + + waitTime time.Duration // If waitTime < 0, no waiting occurs. + lastRequest time.Time +} + +func New(conf Config) *Limiter { + if conf.BurstLimit < 0 { + panic(conf.BurstLimit) + } + if conf.FillPeriod <= 0 { + panic(conf.FillPeriod) + } + if conf.MaxWaitCount < 0 { + panic(conf.MaxWaitCount) + } + + lim := &Limiter{ + lastRequest: time.Now(), + fillPeriod: conf.FillPeriod, + waitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit), + minWaitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit), + maxWaitTime: conf.FillPeriod * time.Duration(conf.MaxWaitCount), + } + + lim.waitTime = lim.minWaitTime + + return lim +} + +func (lim *Limiter) limit(count int64) (time.Duration, error) { + lim.lock.Lock() + defer lim.lock.Unlock() + + dt := time.Since(lim.lastRequest) + + waitTime := lim.waitTime - dt + time.Duration(count)*lim.fillPeriod + + if waitTime < lim.minWaitTime { + waitTime = lim.minWaitTime + } else if waitTime > lim.maxWaitTime { + return 0, ErrBackoff + } + + lim.waitTime = waitTime + lim.lastRequest = lim.lastRequest.Add(dt) + + return lim.waitTime, nil +} + +// Apply the limiter to the calling thread. The function may sleep for up to +// maxWaitTime before returning. If the timeout would need to be more than +// maxWaitTime to enforce the rate limit, ErrBackoff is returned. +func (lim *Limiter) Limit() error { + dt, err := lim.limit(1) + time.Sleep(dt) // Will return immediately for dt <= 0. + return err +} + +// Apply the limiter for multiple items at once. +func (lim *Limiter) LimitMultiple(count int64) error { + dt, err := lim.limit(count) + time.Sleep(dt) + return err +} diff --git a/ratelimiter_test.go b/ratelimiter_test.go new file mode 100644 index 0000000..d3f0ef6 --- /dev/null +++ b/ratelimiter_test.go @@ -0,0 +1,101 @@ +package ratelimiter + +import ( + "sync" + "testing" + "time" +) + +func TestRateLimiter_Limit_Errors(t *testing.T) { + type TestCase struct { + Name string + Conf Config + N int + ErrCount int + DT time.Duration + } + + cases := []TestCase{ + { + Name: "no burst, no wait", + Conf: Config{ + BurstLimit: 0, + FillPeriod: 100 * time.Millisecond, + MaxWaitCount: 0, + }, + N: 32, + ErrCount: 32, + DT: 0, + }, { + Name: "no wait", + Conf: Config{ + BurstLimit: 10, + FillPeriod: 100 * time.Millisecond, + MaxWaitCount: 0, + }, + N: 32, + ErrCount: 22, + DT: 0, + }, { + Name: "no burst", + Conf: Config{ + BurstLimit: 0, + FillPeriod: 10 * time.Millisecond, + MaxWaitCount: 10, + }, + N: 32, + ErrCount: 22, + DT: 100 * time.Millisecond, + }, { + Name: "burst and wait", + Conf: Config{ + BurstLimit: 10, + FillPeriod: 10 * time.Millisecond, + MaxWaitCount: 10, + }, + N: 32, + ErrCount: 12, + DT: 100 * time.Millisecond, + }, + } + + for _, tc := range cases { + wg := sync.WaitGroup{} + l := New(tc.Conf) + errs := make([]error, tc.N) + + t0 := time.Now() + + for i := 0; i < tc.N; i++ { + wg.Add(1) + go func(i int) { + errs[i] = l.Limit() + wg.Done() + }(i) + } + + wg.Wait() + + dt := time.Since(t0) + + errCount := 0 + for _, err := range errs { + if err != nil { + errCount++ + } + } + + if errCount != tc.ErrCount { + t.Fatalf("%s: Expected %d errors but got %d.", + tc.Name, tc.ErrCount, errCount) + } + + if dt < tc.DT { + t.Fatal(tc.Name, dt, tc.DT) + } + + if dt > tc.DT+10*time.Millisecond { + t.Fatal(tc.Name, dt, tc.DT) + } + } +}