diff --git a/ratelimiter.go b/ratelimiter.go index 81fbd0f..b15c1c1 100644 --- a/ratelimiter.go +++ b/ratelimiter.go @@ -53,17 +53,16 @@ func New(conf Config) *Limiter { return lim } -func (lim *Limiter) getWaitTime(count int64) (time.Duration, error) { +func (lim *Limiter) getWaitTime() (time.Duration, error) { lim.lock.Lock() defer lim.lock.Unlock() dt := time.Since(lim.lastRequest) - waitTime := lim.waitTime - dt + time.Duration(count)*lim.fillPeriod + waitTime := lim.waitTime - dt + waitTime = max(waitTime, lim.minWaitTime) + lim.fillPeriod - if waitTime < lim.minWaitTime { - waitTime = lim.minWaitTime - } else if waitTime > lim.maxWaitTime { + if waitTime > lim.maxWaitTime { return 0, ErrBackoff } @@ -77,14 +76,7 @@ func (lim *Limiter) getWaitTime(count int64) (time.Duration, error) { // maxWaitTime before returning. If the timeout would need to be more than // maxWaitTime to enforce the rate limit, ErrBackoff is returned. func (lim *Limiter) Limit() error { - dt, err := lim.getWaitTime(1) + dt, err := lim.getWaitTime() time.Sleep(dt) // Will return immediately for dt <= 0. return err } - -// Apply the limiter for multiple items at once. -func (lim *Limiter) LimitMultiple(count int64) error { - dt, err := lim.getWaitTime(count) - time.Sleep(dt) - return err -} diff --git a/ratelimiter_test.go b/ratelimiter_test.go index aff4854..f12ba5c 100644 --- a/ratelimiter_test.go +++ b/ratelimiter_test.go @@ -162,30 +162,3 @@ func TestLimit_BurstCap(t *testing.T) { } }) } - -// TestLimitMultiple: multiple tokens are consumed per call. -// -// Config: BurstLimit=4, FillPeriod=1s, MaxWaitCount=2 → minWaitTime=-4s, maxWaitTime=2s. -// -// Call 1: LimitMultiple(3) → waitTime = -4s+3s = -1s → immediate -// Call 2: LimitMultiple(3) → waitTime = -1s+3s = 2s → sleeps 2s -// Call 3: LimitMultiple(3) → waitTime = 2s+3s = 5s > 2s → ErrBackoff -func TestLimitMultiple(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - t0 := time.Now() - l := New(Config{BurstLimit: 4, FillPeriod: time.Second, MaxWaitCount: 2}) - - if err := l.LimitMultiple(3); err != nil { - t.Fatalf("call 1: %v", err) - } - if err := l.LimitMultiple(3); err != nil { - t.Fatalf("call 2: %v", err) - } - if err := l.LimitMultiple(3); !errors.Is(err, ErrBackoff) { - t.Fatalf("call 3: want ErrBackoff, got %v", err) - } - if elapsed := time.Since(t0); elapsed != 2*time.Second { - t.Errorf("elapsed: want 2s, got %v", elapsed) - } - }) -}