diff --git a/README.md b/README.md index b652aba..7b9d92f 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,3 @@ # keyedmutex -Keyed mutex library. \ No newline at end of file +Keyed mutex library. diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..5cb6856 --- /dev/null +++ b/go.mod @@ -0,0 +1,3 @@ +module git.crumpington.com/lib/keyedmutex + +go 1.25.1 diff --git a/keyedmutex.go b/keyedmutex.go new file mode 100644 index 0000000..7bc8494 --- /dev/null +++ b/keyedmutex.go @@ -0,0 +1,69 @@ +package keyedmutex + +import ( + "sync" +) + +type keyedLock struct { + lock sync.Mutex + count int +} + +type KeyedMutex[K comparable] struct { + mu sync.Mutex + byKey map[K]*keyedLock +} + +func New[K comparable]() *KeyedMutex[K] { + return &KeyedMutex[K]{ + byKey: map[K]*keyedLock{}, + } +} + +func (m *KeyedMutex[K]) getLock(key K) *keyedLock { + m.mu.Lock() + defer m.mu.Unlock() + + item, ok := m.byKey[key] + if !ok { + item = &keyedLock{} + m.byKey[key] = item + } + item.count++ + return item +} + +func (m *KeyedMutex[K]) release(key K, unlock bool) { + m.mu.Lock() + defer m.mu.Unlock() + + item, ok := m.byKey[key] + if !ok { + panic("unlock of unlocked mutex") + } + + item.count-- + if unlock { + item.lock.Unlock() + } + + if item.count == 0 { + delete(m.byKey, key) + } +} + +func (m *KeyedMutex[K]) Lock(key K) { + m.getLock(key).lock.Lock() +} + +func (m *KeyedMutex[K]) TryLock(key K) bool { + if ok := m.getLock(key).lock.TryLock(); !ok { + m.release(key, false) + return false + } + return true +} + +func (m *KeyedMutex[K]) Unlock(key K) { + m.release(key, true) +} diff --git a/keyedmutex_test.go b/keyedmutex_test.go new file mode 100644 index 0000000..da0ee2f --- /dev/null +++ b/keyedmutex_test.go @@ -0,0 +1,111 @@ +package keyedmutex + +import ( + "sync" + "testing" +) + +func TestLock_CleansUpMap(t *testing.T) { + m := New[string]() + m.Lock("a") + m.Unlock("a") + if len(m.byKey) != 0 { + t.Fatalf("expected empty byKey, got %v", m.byKey) + } +} + +func TestTryLock_SucceedsOnFreeKey(t *testing.T) { + m := New[string]() + if !m.TryLock("a") { + t.Fatal("expected TryLock to succeed on free key") + } + m.Unlock("a") +} + +func TestTryLock_FailsOnLockedKey(t *testing.T) { + m := New[string]() + m.Lock("a") + if m.TryLock("a") { + t.Fatal("expected TryLock to fail on locked key") + } + m.Unlock("a") +} + +func TestTryLock_FailureDoesNotCorruptLock(t *testing.T) { + // A failed TryLock must not call Unlock on the key's mutex. + // The old bug did this unconditionally, so the Unlock below would + // double-unlock and panic. + m := New[string]() + m.Lock("a") + m.TryLock("a") // must return false and leave the lock intact + m.Unlock("a") // must not panic +} + +func TestTryLock_FailureDecrementsCount(t *testing.T) { + // A failed TryLock must undo its getLock increment so that the + // original holder's Unlock cleans up the map entry. + m := New[string]() + m.Lock("a") + m.TryLock("a") // fails; a leaked count would leave a stale map entry + m.Unlock("a") + if len(m.byKey) != 0 { + t.Fatalf("expected empty byKey after unlock, got %v", m.byKey) + } +} + +func TestMultipleKeys_AreIndependent(t *testing.T) { + m := New[string]() + m.Lock("a") + if !m.TryLock("b") { + t.Fatal("expected TryLock on different key to succeed while 'a' is locked") + } + m.Unlock("b") + m.Unlock("a") +} + +func TestConcurrentLock_MutualExclusion(t *testing.T) { + m := New[string]() + const N = 100 + var wg sync.WaitGroup + var shared int // intentionally non-atomic: race detector catches improper access + + for range N { + wg.Add(1) + go func() { + defer wg.Done() + m.Lock("a") + shared++ + m.Unlock("a") + }() + } + + wg.Wait() + + if shared != N { + t.Fatalf("expected %d, got %d", N, shared) + } + if len(m.byKey) != 0 { + t.Fatalf("expected empty byKey after all goroutines done, got %v", m.byKey) + } +} + +func TestKeyedMutex_unlockUnlocked(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal(r) + } + }() + + m := New[string]() + m.Unlock("aldkfj") +} + +func BenchmarkUncontendedMutex(b *testing.B) { + m := New[string]() + key := "xyz" + + for b.Loop() { + m.Lock(key) + m.Unlock(key) + } +}