diff --git a/keyedmutex/keyedmutex.go b/keyedmutex/keyedmutex.go new file mode 100644 index 0000000..2430c01 --- /dev/null +++ b/keyedmutex/keyedmutex.go @@ -0,0 +1,67 @@ +package keyedmutex + +import ( + "container/list" + "sync" +) + +type KeyedMutex struct { + mu *sync.Mutex + waitList map[string]*list.List +} + +func New() KeyedMutex { + return KeyedMutex{ + mu: new(sync.Mutex), + waitList: map[string]*list.List{}, + } +} + +func (m KeyedMutex) Lock(key string) { + if ch := m.lock(key); ch != nil { + <-ch + } +} + +func (m KeyedMutex) lock(key string) chan struct{} { + m.mu.Lock() + defer m.mu.Unlock() + + if waitList, ok := m.waitList[key]; ok { + ch := make(chan struct{}) + waitList.PushBack(ch) + return ch + } + + m.waitList[key] = list.New() + return nil +} + +func (m KeyedMutex) TryLock(key string) bool { + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.waitList[key]; ok { + return false + } + + m.waitList[key] = list.New() + return true +} + +func (m KeyedMutex) Unlock(key string) { + m.mu.Lock() + defer m.mu.Unlock() + + waitList, ok := m.waitList[key] + if !ok { + panic("unlock of unlocked mutex") + } + + if waitList.Len() == 0 { + delete(m.waitList, key) + } else { + ch := waitList.Remove(waitList.Front()).(chan struct{}) + ch <- struct{}{} + } +} diff --git a/keyedmutex/keyedmutex_test.go b/keyedmutex/keyedmutex_test.go new file mode 100644 index 0000000..cd5d02b --- /dev/null +++ b/keyedmutex/keyedmutex_test.go @@ -0,0 +1,122 @@ +package keyedmutex + +import ( + "sync" + "testing" + "time" +) + +func TestKeyedMutex(t *testing.T) { + checkState := func(t *testing.T, m KeyedMutex, keys ...string) { + if len(m.waitList) != len(keys) { + t.Fatal(m.waitList, keys) + } + + for _, key := range keys { + if _, ok := m.waitList[key]; !ok { + t.Fatal(key) + } + } + } + + m := New() + checkState(t, m) + + m.Lock("a") + checkState(t, m, "a") + m.Lock("b") + checkState(t, m, "a", "b") + m.Lock("c") + checkState(t, m, "a", "b", "c") + + if m.TryLock("a") { + t.Fatal("a") + } + if m.TryLock("b") { + t.Fatal("b") + } + if m.TryLock("c") { + t.Fatal("c") + } + + if !m.TryLock("d") { + t.Fatal("d") + } + + checkState(t, m, "a", "b", "c", "d") + + if !m.TryLock("e") { + t.Fatal("e") + } + checkState(t, m, "a", "b", "c", "d", "e") + + m.Unlock("c") + checkState(t, m, "a", "b", "d", "e") + m.Unlock("a") + checkState(t, m, "b", "d", "e") + m.Unlock("e") + checkState(t, m, "b", "d") + + wg := sync.WaitGroup{} + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + defer wg.Done() + m.Lock("b") + m.Unlock("b") + }() + } + + time.Sleep(100 * time.Millisecond) + m.Unlock("b") + wg.Wait() + + checkState(t, m, "d") + + m.Unlock("d") + checkState(t, m) +} + +func TestKeyedMutex_unlockUnlocked(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal(r) + } + }() + + m := New() + m.Unlock("aldkfj") +} + +func BenchmarkUncontendedMutex(b *testing.B) { + m := New() + key := "xyz" + + for i := 0; i < b.N; i++ { + m.Lock(key) + m.Unlock(key) + } +} + +func BenchmarkContendedMutex(b *testing.B) { + m := New() + key := "xyz" + + m.Lock(key) + + wg := sync.WaitGroup{} + for i := 0; i < b.N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + m.Lock(key) + m.Unlock(key) + }() + } + + time.Sleep(time.Second) + + b.ResetTimer() + m.Unlock(key) + wg.Wait() +}