diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go new file mode 100644 index 0000000..42ed25a --- /dev/null +++ b/ratelimiter/ratelimiter.go @@ -0,0 +1,73 @@ +package ratelimiter + +import ( + "errors" + "sync" + "time" +) + +var ErrBackoff = errors.New("Backoff") + +type Config struct { + BurstLimit int64 // Number of requests to allow to burst. + FillPeriod time.Duration // Add one call per period. + MaxWaitCount int64 // Max number of waiting requests. 0 disables. +} + +type Limiter struct { + lock sync.Mutex + + fillPeriod time.Duration + minWaitTime time.Duration + maxWaitTime time.Duration + + waitTime time.Duration + lastRequest time.Time +} + +func New(conf Config) *Limiter { + if conf.BurstLimit < 0 { + panic(conf.BurstLimit) + } + if conf.FillPeriod <= 0 { + panic(conf.FillPeriod) + } + if conf.MaxWaitCount < 0 { + panic(conf.MaxWaitCount) + } + + return &Limiter{ + fillPeriod: conf.FillPeriod, + waitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit), + minWaitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit), + maxWaitTime: conf.FillPeriod * time.Duration(conf.MaxWaitCount-1), + lastRequest: time.Now(), + } +} + +func (lim *Limiter) limit() (time.Duration, error) { + lim.lock.Lock() + defer lim.lock.Unlock() + + dt := time.Since(lim.lastRequest) + waitTime := lim.waitTime - dt + if waitTime < lim.minWaitTime { + waitTime = lim.minWaitTime + } else if waitTime >= lim.maxWaitTime { + return 0, ErrBackoff + } + + lim.waitTime = waitTime + lim.fillPeriod + lim.lastRequest = lim.lastRequest.Add(dt) + + return lim.waitTime, nil +} + +// Apply the limiter to the calling thread. The function may sleep for up to +// maxWaitTime before returning. If the timeout would need to be more than +// maxWaitTime to enforce the rate limit, ErrBackoff is returned. +func (lim *Limiter) Limit() error { + dt, err := lim.limit() + time.Sleep(dt) // Will return immediately for dt <= 0. + return err +} diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go new file mode 100644 index 0000000..008f116 --- /dev/null +++ b/ratelimiter/ratelimiter_test.go @@ -0,0 +1,101 @@ +package ratelimiter + +import ( + "sync" + "testing" + "time" +) + +func TestRateLimiter_Limit_Errors(t *testing.T) { + type TestCase struct { + Name string + Conf Config + N int + ErrCount int + DT time.Duration + } + + cases := []TestCase{ + { + Name: "no burst, no wait", + Conf: Config{ + BurstLimit: 0, + FillPeriod: 100 * time.Millisecond, + MaxWaitCount: 0, + }, + N: 32, + ErrCount: 31, + DT: 100 * time.Millisecond, + }, { + Name: "no wait", + Conf: Config{ + BurstLimit: 10, + FillPeriod: 100 * time.Millisecond, + MaxWaitCount: 0, + }, + N: 32, + ErrCount: 22, + DT: 0, + }, { + Name: "no burst", + Conf: Config{ + BurstLimit: 0, + FillPeriod: 10 * time.Millisecond, + MaxWaitCount: 10, + }, + N: 32, + ErrCount: 22, + DT: 100 * time.Millisecond, + }, { + Name: "burst and wait", + Conf: Config{ + BurstLimit: 10, + FillPeriod: 10 * time.Millisecond, + MaxWaitCount: 10, + }, + N: 32, + ErrCount: 12, + DT: 100 * time.Millisecond, + }, + } + + for _, tc := range cases { + wg := sync.WaitGroup{} + l := New(tc.Conf) + errs := make([]error, tc.N) + + t0 := time.Now() + + for i := 0; i < tc.N; i++ { + wg.Add(1) + go func(i int) { + errs[i] = l.Limit() + wg.Done() + }(i) + } + + wg.Wait() + + dt := time.Since(t0) + + errCount := 0 + for _, err := range errs { + if err != nil { + errCount++ + } + } + + if errCount != tc.ErrCount { + t.Fatalf("%s: Expected %d errors but got %d.", + tc.Name, tc.ErrCount, errCount) + } + + if dt < tc.DT { + t.Fatal(tc.Name, dt, tc.DT) + } + + if dt > tc.DT+10*time.Millisecond { + t.Fatal(tc.Name, dt, tc.DT) + } + } +}