From 3f2ef4ce963faefed28fbaaf3be6432ee1f383f3 Mon Sep 17 00:00:00 2001 From: jdl Date: Fri, 13 Oct 2023 13:25:59 +0200 Subject: [PATCH] Initial commit --- cache.go | 134 +++++++++++++++++++++++++++ cache_test.go | 249 ++++++++++++++++++++++++++++++++++++++++++++++++++ go.mod | 3 + stats.go | 6 ++ 4 files changed, 392 insertions(+) create mode 100644 cache.go create mode 100644 cache_test.go create mode 100644 go.mod create mode 100644 stats.go diff --git a/cache.go b/cache.go new file mode 100644 index 0000000..3d6fa68 --- /dev/null +++ b/cache.go @@ -0,0 +1,134 @@ +package kvmemcache + +import ( + "container/list" + "sync" + "time" + + "git.crumpington.com/public/keyedmutex" +) + +type Cache[K comparable, V any] struct { + updateLock keyedmutex.KeyedMutex[K] + src func(K) (V, error) + ttl time.Duration + maxSize int + + // Lock protects variables below. + lock sync.Mutex + cache map[K]*list.Element + ll *list.List + stats Stats +} + +type lruItem[K comparable, V any] struct { + key K + createdAt time.Time + value V + err error +} + +type Config[K comparable, V any] struct { + MaxSize int + TTL time.Duration // Zero to ignore. + Src func(K) (V, error) +} + +func New[K comparable, V any](conf Config[K, V]) *Cache[K, V] { + return &Cache[K, V]{ + updateLock: keyedmutex.New[K](), + src: conf.Src, + ttl: conf.TTL, + maxSize: conf.MaxSize, + lock: sync.Mutex{}, + cache: make(map[K]*list.Element, conf.MaxSize+1), + ll: list.New(), + } +} + +func (c *Cache[K, V]) Get(key K) (V, error) { + ok, val, err := c.get(key) + if ok { + return val, err + } + + return c.load(key) +} + +func (c *Cache[K, V]) Evict(key K) { + c.lock.Lock() + defer c.lock.Unlock() + c.evict(key) +} + +func (c *Cache[K, V]) Stats() Stats { + c.lock.Lock() + defer c.lock.Unlock() + return c.stats +} + +func (c *Cache[K, V]) put(key K, value V, err error) { + c.lock.Lock() + defer c.lock.Unlock() + + c.stats.Misses++ + + c.cache[key] = c.ll.PushFront(lruItem[K, V]{ + key: key, + createdAt: time.Now(), + value: value, + err: err, + }) + + if c.maxSize != 0 && len(c.cache) > c.maxSize { + li := c.ll.Back() + c.ll.Remove(li) + delete(c.cache, li.Value.(lruItem[K, V]).key) + } +} + +func (c *Cache[K, V]) evict(key K) { + elem := c.cache[key] + if elem != nil { + delete(c.cache, key) + c.ll.Remove(elem) + } +} + +func (c *Cache[K, V]) get(key K) (ok bool, val V, err error) { + c.lock.Lock() + defer c.lock.Unlock() + + li := c.cache[key] + if li == nil { + return false, val, nil + } + + item := li.Value.(lruItem[K, V]) + // Maybe evict. + if c.ttl != 0 && time.Since(item.createdAt) > c.ttl { + c.evict(key) + return false, val, nil + } + + c.stats.Hits++ + + c.ll.MoveToFront(li) + return true, item.value, item.err +} + +func (c *Cache[K, V]) load(key K) (V, error) { + c.updateLock.Lock(key) + defer c.updateLock.Unlock(key) + + // Check again in case we lost the update race. + ok, val, err := c.get(key) + if ok { + return val, err + } + + // Won the update race. + val, err = c.src(key) + c.put(key, val, err) + return val, err +} diff --git a/cache_test.go b/cache_test.go new file mode 100644 index 0000000..2216c12 --- /dev/null +++ b/cache_test.go @@ -0,0 +1,249 @@ +package kvmemcache + +import ( + "errors" + "fmt" + "sync" + "testing" + "time" +) + +type State[K comparable] struct { + Keys []K + Stats Stats +} + +func (c *Cache[K,V]) assert(state State[K]) error { + c.lock.Lock() + defer c.lock.Unlock() + + if len(c.cache) != len(state.Keys) { + return fmt.Errorf( + "Expected %d keys but found %d.", + len(state.Keys), + len(c.cache)) + } + + for _, k := range state.Keys { + if _, ok := c.cache[k]; !ok { + return fmt.Errorf( + "Expected key %s not found.", + k) + } + } + + if c.stats.Hits != state.Stats.Hits { + return fmt.Errorf( + "Expected %d hits, but found %d.", + state.Stats.Hits, + c.stats.Hits) + } + + if c.stats.Misses != state.Stats.Misses { + return fmt.Errorf( + "Expected %d misses, but found %d.", + state.Stats.Misses, + c.stats.Misses) + } + + return nil +} + +var ErrTest = errors.New("Hello") + +func TestCache_basic(t *testing.T) { + c := New(Config[string, string]{ + MaxSize: 4, + TTL: 50 * time.Millisecond, + Src: func(key string) (string, error) { + if key == "err" { + return "", ErrTest + } + return key, nil + }, + }) + + type testCase struct { + name string + sleep time.Duration + key string + evict bool + state State[string] + } + + cases := []testCase{ + { + name: "get a", + key: "a", + state: State[string]{ + Keys: []string{"a"}, + Stats: Stats{Hits: 0, Misses: 1}, + }, + }, { + name: "get a again", + key: "a", + state: State[string]{ + Keys: []string{"a"}, + Stats: Stats{Hits: 1, Misses: 1}, + }, + }, { + name: "sleep, then get a again", + sleep: 55 * time.Millisecond, + key: "a", + state: State[string]{ + Keys: []string{"a"}, + Stats: Stats{Hits: 1, Misses: 2}, + }, + }, { + name: "get b", + key: "b", + state: State[string]{ + Keys: []string{"a", "b"}, + Stats: Stats{Hits: 1, Misses: 3}, + }, + }, { + name: "get c", + key: "c", + state: State[string]{ + Keys: []string{"a", "b", "c"}, + Stats: Stats{Hits: 1, Misses: 4}, + }, + }, { + name: "get d", + key: "d", + state: State[string]{ + Keys: []string{"a", "b", "c", "d"}, + Stats: Stats{Hits: 1, Misses: 5}, + }, + }, { + name: "get e", + key: "e", + state: State[string]{ + Keys: []string{"b", "c", "d", "e"}, + Stats: Stats{Hits: 1, Misses: 6}, + }, + }, { + name: "get c again", + key: "c", + state: State[string]{ + Keys: []string{"b", "c", "d", "e"}, + Stats: Stats{Hits: 2, Misses: 6}, + }, + }, { + name: "get err", + key: "err", + state: State[string]{ + Keys: []string{"c", "d", "e", "err"}, + Stats: Stats{Hits: 2, Misses: 7}, + }, + }, { + name: "get err again", + key: "err", + state: State[string]{ + Keys: []string{"c", "d", "e", "err"}, + Stats: Stats{Hits: 3, Misses: 7}, + }, + }, { + name: "evict c", + key: "c", + evict: true, + state: State[string]{ + Keys: []string{"d", "e", "err"}, + Stats: Stats{Hits: 3, Misses: 7}, + }, + }, { + name: "reload-all a", + key: "a", + state: State[string]{ + Keys: []string{"a", "d", "e", "err"}, + Stats: Stats{Hits: 3, Misses: 8}, + }, + }, { + name: "reload-all b", + key: "b", + state: State[string]{ + Keys: []string{"a", "b", "e", "err"}, + Stats: Stats{Hits: 3, Misses: 9}, + }, + }, { + name: "reload-all c", + key: "c", + state: State[string]{ + Keys: []string{"a", "b", "c", "err"}, + Stats: Stats{Hits: 3, Misses: 10}, + }, + }, { + name: "reload-all d", + key: "d", + state: State[string]{ + Keys: []string{"a", "b", "c", "d"}, + Stats: Stats{Hits: 3, Misses: 11}, + }, + }, { + name: "read a again", + key: "a", + state: State[string]{ + Keys: []string{"b", "c", "d", "a"}, + Stats: Stats{Hits: 4, Misses: 11}, + }, + }, { + name: "read e, evicting b", + key: "e", + state: State[string]{ + Keys: []string{"c", "d", "a", "e"}, + Stats: Stats{Hits: 4, Misses: 12}, + }, + }, + } + + for _, tc := range cases { + time.Sleep(tc.sleep) + if !tc.evict { + val, err := c.Get(tc.key) + if tc.key == "err" && err != ErrTest { + t.Fatal(tc.name, val) + } + if tc.key != "err" && val != tc.key { + t.Fatal(tc.name, tc.key, val) + } + } else { + c.Evict(tc.key) + } + + if err := c.assert(tc.state); err != nil { + t.Fatal(err) + } + } +} + +func TestCache_thunderingHerd(t *testing.T) { + c := New(Config[string,string]{ + MaxSize: 4, + Src: func(key string) (string, error) { + time.Sleep(time.Second) + return key, nil + }, + }) + + wg := sync.WaitGroup{} + for i := 0; i < 1024; i++ { + wg.Add(1) + go func() { + defer wg.Done() + val, err := c.Get("a") + if err != nil { + panic(err) + } + if val != "a" { + panic(err) + } + }() + } + + wg.Wait() + + stats := c.Stats() + if stats.Hits != 1023 || stats.Misses != 1 { + t.Fatal(stats) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..9036b46 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.crumpington.com/public/kvmemcache + +go 1.21.1 diff --git a/stats.go b/stats.go new file mode 100644 index 0000000..b6415f9 --- /dev/null +++ b/stats.go @@ -0,0 +1,6 @@ +package kvmemcache + +type Stats struct { + Hits uint64 + Misses uint64 +}