package ratelimiter

import (
	"sync"
	"testing"
	"time"
)

// TestRateLimiter_Limit_Errors checks a fresh limiter for each configuration.
// For each one it fires N concurrent Limit calls, then checks how many calls
// returned a non-nil error and how long the whole batch took to complete.
func TestRateLimiter_Limit_Errors(t *testing.T) {
	// timingSlack is the extra wall-clock time tolerated beyond a case's
	// expected duration before the case is reported as too slow.
	const timingSlack = 10 * time.Millisecond

	type testCase struct {
		Name     string
		Conf     Config
		N        int           // number of concurrent Limit calls to issue
		ErrCount int           // expected number of calls returning a non-nil error
		DT       time.Duration // expected elapsed time; must fall in [DT, DT+timingSlack]
	}

	cases := []testCase{
		{
			Name: "no burst, no wait",
			Conf: Config{
				BurstLimit:   0,
				FillPeriod:   100 * time.Millisecond,
				MaxWaitCount: 0,
			},
			N:        32,
			ErrCount: 32,
			DT:       0,
		}, {
			Name: "no wait",
			Conf: Config{
				BurstLimit:   10,
				FillPeriod:   100 * time.Millisecond,
				MaxWaitCount: 0,
			},
			N:        32,
			ErrCount: 22,
			DT:       0,
		}, {
			Name: "no burst",
			Conf: Config{
				BurstLimit:   0,
				FillPeriod:   10 * time.Millisecond,
				MaxWaitCount: 10,
			},
			N:        32,
			ErrCount: 22,
			DT:       100 * time.Millisecond,
		}, {
			Name: "burst and wait",
			Conf: Config{
				BurstLimit:   10,
				FillPeriod:   10 * time.Millisecond,
				MaxWaitCount: 10,
			},
			N:        32,
			ErrCount: 12,
			DT:       100 * time.Millisecond,
		},
	}

	for _, tc := range cases {
		// Subtests isolate failures per case and name them in the output.
		// They are deliberately not run in parallel: the timing bounds
		// assume each case has the machine to itself.
		t.Run(tc.Name, func(t *testing.T) {
			l := New(tc.Conf)
			errs := make([]error, tc.N)

			var wg sync.WaitGroup
			t0 := time.Now()
			for i := 0; i < tc.N; i++ {
				wg.Add(1)
				go func(i int) {
					// Deferred so a panicking Limit still releases wg.Wait.
					defer wg.Done()
					errs[i] = l.Limit()
				}(i)
			}
			wg.Wait()
			dt := time.Since(t0)

			errCount := 0
			for _, err := range errs {
				if err != nil {
					errCount++
				}
			}
			if errCount != tc.ErrCount {
				t.Errorf("expected %d errors but got %d", tc.ErrCount, errCount)
			}

			if dt < tc.DT {
				t.Errorf("finished too fast: elapsed %v, want at least %v", dt, tc.DT)
			}
			if maxDT := tc.DT + timingSlack; dt > maxDT {
				t.Errorf("finished too slow: elapsed %v, want at most %v", dt, maxDT)
			}
		})
	}
}