From 190fb9576f3ab35197f4ba05c6d4463c92f39920 Mon Sep 17 00:00:00 2001 From: jdl Date: Fri, 13 Oct 2023 12:07:24 +0200 Subject: [PATCH] Initial commit --- go.mod | 3 ++ keyedmutex.go | 67 ++++++++++++++++++++++++ keyedmutex_test.go | 123 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 193 insertions(+) create mode 100644 go.mod create mode 100644 keyedmutex.go create mode 100644 keyedmutex_test.go diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..da66878 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.crumpington.com/public/keyedmutex + +go 1.21.1 diff --git a/keyedmutex.go b/keyedmutex.go new file mode 100644 index 0000000..699e4b1 --- /dev/null +++ b/keyedmutex.go @@ -0,0 +1,67 @@ +package keyedmutex + +import ( + "container/list" + "sync" +) + +type KeyedMutex[K comparable] struct { + mu *sync.Mutex + waitList map[K]*list.List +} + +func New[K comparable]() KeyedMutex[K] { + return KeyedMutex[K]{ + mu: new(sync.Mutex), + waitList: map[K]*list.List{}, + } +} + +func (m KeyedMutex[K]) Lock(key K) { + if ch := m.lock(key); ch != nil { + <-ch + } +} + +func (m KeyedMutex[K]) lock(key K) chan struct{} { + m.mu.Lock() + defer m.mu.Unlock() + + if waitList, ok := m.waitList[key]; ok { + ch := make(chan struct{}) + waitList.PushBack(ch) + return ch + } + + m.waitList[key] = list.New() + return nil +} + +func (m KeyedMutex[K]) TryLock(key K) bool { + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.waitList[key]; ok { + return false + } + + m.waitList[key] = list.New() + return true +} + +func (m KeyedMutex[K]) Unlock(key K) { + m.mu.Lock() + defer m.mu.Unlock() + + waitList, ok := m.waitList[key] + if !ok { + panic("unlock of unlocked mutex") + } + + if waitList.Len() == 0 { + delete(m.waitList, key) + } else { + ch := waitList.Remove(waitList.Front()).(chan struct{}) + ch <- struct{}{} + } +} diff --git a/keyedmutex_test.go b/keyedmutex_test.go new file mode 100644 index 0000000..14fdaf0 --- /dev/null +++ b/keyedmutex_test.go @@ -0,0 +1,123 @@ +package keyedmutex + + +import ( + "sync" + "testing" + "time" +) + +func TestKeyedMutex(t *testing.T) { + checkState := func(t *testing.T, m KeyedMutex[string], keys ...string) { + if len(m.waitList) != len(keys) { + t.Fatal(m.waitList, keys) + } + + for _, key := range keys { + if _, ok := m.waitList[key]; !ok { + t.Fatal(key) + } + } + } + + m := New[string]() + checkState(t, m) + + m.Lock("a") + checkState(t, m, "a") + m.Lock("b") + checkState(t, m, "a", "b") + m.Lock("c") + checkState(t, m, "a", "b", "c") + + if m.TryLock("a") { + t.Fatal("a") + } + if m.TryLock("b") { + t.Fatal("b") + } + if m.TryLock("c") { + t.Fatal("c") + } + + if !m.TryLock("d") { + t.Fatal("d") + } + + checkState(t, m, "a", "b", "c", "d") + + if !m.TryLock("e") { + t.Fatal("e") + } + checkState(t, m, "a", "b", "c", "d", "e") + + m.Unlock("c") + checkState(t, m, "a", "b", "d", "e") + m.Unlock("a") + checkState(t, m, "b", "d", "e") + m.Unlock("e") + checkState(t, m, "b", "d") + + wg := sync.WaitGroup{} + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + defer wg.Done() + m.Lock("b") + m.Unlock("b") + }() + } + + time.Sleep(100 * time.Millisecond) + m.Unlock("b") + wg.Wait() + + checkState(t, m, "d") + + m.Unlock("d") + checkState(t, m) +} + +func TestKeyedMutex_unlockUnlocked(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal(r) + } + }() + + m := New[string]() + m.Unlock("aldkfj") +} + +func BenchmarkUncontendedMutex(b *testing.B) { + m := New[string]() + key := "xyz" + + for i := 0; i < b.N; i++ { + m.Lock(key) + m.Unlock(key) + } +} + +func BenchmarkContendedMutex(b *testing.B) { + m := New[string]() + key := "xyz" + + m.Lock(key) + + wg := sync.WaitGroup{} + for i := 0; i < b.N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + m.Lock(key) + m.Unlock(key) + }() + } + + time.Sleep(time.Second) + + b.ResetTimer() + m.Unlock(key) + wg.Wait() +}