This commit is contained in:
jdl
2026-06-14 19:27:10 +02:00
parent 61b9a0a287
commit 8f74535b1f
2 changed files with 5 additions and 40 deletions

View File

@@ -53,17 +53,16 @@ func New(conf Config) *Limiter {
return lim
}
func (lim *Limiter) getWaitTime(count int64) (time.Duration, error) {
func (lim *Limiter) getWaitTime() (time.Duration, error) {
lim.lock.Lock()
defer lim.lock.Unlock()
dt := time.Since(lim.lastRequest)
waitTime := lim.waitTime - dt + time.Duration(count)*lim.fillPeriod
waitTime := lim.waitTime - dt
waitTime = max(waitTime, lim.minWaitTime) + lim.fillPeriod
if waitTime < lim.minWaitTime {
waitTime = lim.minWaitTime
} else if waitTime > lim.maxWaitTime {
if waitTime > lim.maxWaitTime {
return 0, ErrBackoff
}
@@ -77,14 +76,7 @@ func (lim *Limiter) getWaitTime(count int64) (time.Duration, error) {
// maxWaitTime before returning. If the timeout would need to be more than
// maxWaitTime to enforce the rate limit, ErrBackoff is returned.
func (lim *Limiter) Limit() error {
dt, err := lim.getWaitTime(1)
dt, err := lim.getWaitTime()
time.Sleep(dt) // Will return immediately for dt <= 0.
return err
}
// Apply the limiter for multiple items at once.
func (lim *Limiter) LimitMultiple(count int64) error {
dt, err := lim.getWaitTime(count)
time.Sleep(dt)
return err
}

View File

@@ -162,30 +162,3 @@ func TestLimit_BurstCap(t *testing.T) {
}
})
}
// TestLimitMultiple: multiple tokens are consumed per call.
//
// Config: BurstLimit=4, FillPeriod=1s, MaxWaitCount=2 → minWaitTime=-4s, maxWaitTime=2s.
//
// Call 1: LimitMultiple(3) → waitTime = -4s+3s = -1s → immediate
// Call 2: LimitMultiple(3) → waitTime = -1s+3s = 2s → sleeps 2s
// Call 3: LimitMultiple(3) → waitTime = 2s+3s = 5s > 2s → ErrBackoff
func TestLimitMultiple(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
t0 := time.Now()
l := New(Config{BurstLimit: 4, FillPeriod: time.Second, MaxWaitCount: 2})
if err := l.LimitMultiple(3); err != nil {
t.Fatalf("call 1: %v", err)
}
if err := l.LimitMultiple(3); err != nil {
t.Fatalf("call 2: %v", err)
}
if err := l.LimitMultiple(3); !errors.Is(err, ErrBackoff) {
t.Fatalf("call 3: want ErrBackoff, got %v", err)
}
if elapsed := time.Since(t0); elapsed != 2*time.Second {
t.Errorf("elapsed: want 2s, got %v", elapsed)
}
})
}