package mdb

import (
	"sync"

	"git.crumpington.com/private/mdb/keyedmutex"
)

// MapIndex is a unique index over the items of a Collection. It maps a key of
// type K, derived from each item by getKey, to a pointer to the item.
//
// An optional include predicate limits which items are indexed; when it is nil
// every item is indexed.
type MapIndex[K comparable, T any] struct {
	c       *Collection[T]           // owning collection; Update and Delete are delegated to it
	_name   string                   // index name returned by name()
	keyLock keyedmutex.KeyedMutex[K] // per-key lock used by lock/unlock
	mapLock sync.Mutex               // guards m
	m       map[K]*T                 // key -> indexed item
	getID   func(*T) uint64          // extracts an item's ID (taken from the collection)
	getKey  func(*T) K               // extracts an item's index key
	include func(*T) bool            // optional filter; nil means every item is indexed
}

// NewMapIndex creates a MapIndex called name on collection c and registers it
// with the collection by appending it to both c.indices and c.uniqueIndices.
//
// getKey extracts the index key from an item. include, if non-nil, restricts
// the index to the items for which it returns true; a nil include indexes
// every item.
func NewMapIndex[K comparable, T any](
	c *Collection[T],
	name string,
	getKey func(*T) K,
	include func(*T) bool,
) *MapIndex[K, T] {
	idx := &MapIndex[K, T]{
		c:       c,
		_name:   name,
		keyLock: keyedmutex.New[K](),
		m:       map[K]*T{},
		getID:   c.getID,
		getKey:  getKey,
		include: include,
	}
	c.indices = append(c.indices, idx)
	c.uniqueIndices = append(c.uniqueIndices, idx)
	return idx
}

// ----------------------------------------------------------------------------

// includes reports whether item belongs in this index: always true when no
// include predicate was supplied, otherwise the predicate's result.
func (idx *MapIndex[K, T]) includes(item *T) bool {
	return idx.include == nil || idx.include(item)
}

// load fills the index from src, skipping items rejected by the include
// predicate. It returns ErrDuplicate as soon as two included items produce the
// same key; entries added before the duplicate was found remain in the map.
//
// load writes the map without taking mapLock.
// NOTE(review): presumably only called before the index is shared between
// goroutines - confirm against the caller.
func (idx *MapIndex[K, T]) load(src map[uint64]*T) error {
	for _, item := range src {
		if !idx.includes(item) {
			continue
		}
		k := idx.getKey(item)
		if _, ok := idx.m[k]; ok {
			return ErrDuplicate
		}
		idx.m[k] = item
	}
	return nil
}

// Get returns a copy of the item indexed under k. When no item is indexed
// under k, ok is false and t is the zero value of T.
func (idx *MapIndex[K, T]) Get(k K) (t T, ok bool) {
	item, found := idx.mapGet(k)
	if !found {
		return t, false
	}
	return *item, true
}

// Update looks up the item indexed under k and updates it through the owning
// collection, addressed by the item's ID. It returns ErrNotFound when no item
// is indexed under k; otherwise it returns the result of the collection's
// Update.
//
// If update returns an error, the unmodified item is handed back to the
// collection together with that error, regardless of what value update
// returned.
func (idx *MapIndex[K, T]) Update(k K, update func(T) (T, error)) error {
	item, ok := idx.mapGet(k)
	if !ok {
		return ErrNotFound
	}

	wrapped := func(cur T) (T, error) {
		updated, err := update(cur)
		if err != nil {
			return cur, err
		}
		return updated, nil
	}
	return idx.c.Update(idx.getID(item), wrapped)
}

// Delete deletes the item indexed under k from the owning collection,
// addressed by the item's ID. It does nothing when no item is indexed under k.
func (idx *MapIndex[K, T]) Delete(k K) {
	if item, ok := idx.mapGet(k); ok {
		idx.c.Delete(idx.getID(item))
	}
}

// ----------------------------------------------------------------------------

// insert adds item to the index under its key if the item is included.
func (idx *MapIndex[K, T]) insert(item *T) {
	if idx.includes(item) {
		idx.mapSet(idx.getKey(item), item)
	}
}

// update replaces old with new in the index: old's key is removed if old was
// included, and new is stored under its key if new is included.
//
// Both map mutations happen inside a single mapLock critical section so that a
// concurrent reader never observes the intermediate state in which the old
// entry is gone but the new one is not yet present (which would otherwise make
// Get/Update/Delete miss an item whose key did not even change).
func (idx *MapIndex[K, T]) update(old, new *T) {
	// Run the caller-supplied callbacks before locking so they never execute
	// while mapLock is held.
	var oldKey, newKey K
	dropOld := idx.includes(old)
	if dropOld {
		oldKey = idx.getKey(old)
	}
	addNew := idx.includes(new)
	if addNew {
		newKey = idx.getKey(new)
	}

	idx.mapLock.Lock()
	defer idx.mapLock.Unlock()
	if dropOld {
		delete(idx.m, oldKey)
	}
	if addNew {
		idx.m[newKey] = new
	}
}

// delete removes item's key from the index if the item is included.
func (idx *MapIndex[K, T]) delete(item *T) {
	if idx.includes(item) {
		idx.mapDelete(idx.getKey(item))
	}
}

// ----------------------------------------------------------------------------

// name returns the name the index was created with.
func (idx *MapIndex[K, T]) name() string {
	return idx._name
}

// lock acquires the per-key lock for item's key. Items that are not included
// in the index are not locked.
func (idx *MapIndex[K, T]) lock(item *T) {
	if idx.includes(item) {
		idx.keyLock.Lock(idx.getKey(item))
	}
}

// unlock releases the per-key lock for item's key. Items that are not included
// in the index are ignored, mirroring lock.
func (idx *MapIndex[K, T]) unlock(item *T) {
	if idx.includes(item) {
		idx.keyLock.Unlock(idx.getKey(item))
	}
}

// insertConflict reports whether inserting new would collide with an entry
// already indexed under the same key. Items that are not included never
// conflict.
//
// Should hold item lock when calling.
func (idx *MapIndex[K, T]) insertConflict(new *T) bool {
	if !idx.includes(new) {
		return false
	}
	// Only presence matters here, so look at the map directly instead of going
	// through Get, which would copy the stored item.
	_, ok := idx.mapGet(idx.getKey(new))
	return ok
}

// updateConflict reports whether storing new would collide with a different
// item (one with another ID) already indexed under the same key. Items that
// are not included never conflict.
//
// Should hold item lock when calling.
func (idx *MapIndex[K, T]) updateConflict(new *T) bool {
	if !idx.includes(new) {
		return false
	}
	cur, ok := idx.mapGet(idx.getKey(new))
	return ok && idx.getID(cur) != idx.getID(new)
}

// ----------------------------------------------------------------------------

// mapSet stores t under k while holding mapLock.
func (idx *MapIndex[K, T]) mapSet(k K, t *T) {
	idx.mapLock.Lock()
	idx.m[k] = t
	idx.mapLock.Unlock()
}

// mapDelete removes k from the map while holding mapLock.
func (idx *MapIndex[K, T]) mapDelete(k K) {
	idx.mapLock.Lock()
	delete(idx.m, k)
	idx.mapLock.Unlock()
}

// mapGet returns the item pointer stored under k and whether it was present,
// reading the map while holding mapLock.
func (idx *MapIndex[K, T]) mapGet(k K) (*T, bool) {
	idx.mapLock.Lock()
	t, ok := idx.m[k]
	idx.mapLock.Unlock()
	return t, ok
}