174 lines
3.4 KiB
Go
174 lines
3.4 KiB
Go
|
package mdb
|
||
|
|
||
|
import (
|
||
|
"sync"
|
||
|
|
||
|
"git.crumpington.com/private/mdb/keyedmutex"
|
||
|
)
|
||
|
|
||
|
type MapIndex[K comparable, T any] struct {
|
||
|
c *Collection[T]
|
||
|
_name string
|
||
|
keyLock keyedmutex.KeyedMutex[K]
|
||
|
mapLock sync.Mutex
|
||
|
m map[K]*T
|
||
|
getID func(*T) uint64
|
||
|
getKey func(*T) K
|
||
|
include func(*T) bool
|
||
|
}
|
||
|
|
||
|
func NewMapIndex[K comparable, T any](
|
||
|
c *Collection[T],
|
||
|
name string,
|
||
|
getKey func(*T) K,
|
||
|
include func(*T) bool,
|
||
|
) *MapIndex[K, T] {
|
||
|
m := &MapIndex[K, T]{
|
||
|
c: c,
|
||
|
_name: name,
|
||
|
keyLock: keyedmutex.New[K](),
|
||
|
m: map[K]*T{},
|
||
|
getID: c.getID,
|
||
|
getKey: getKey,
|
||
|
include: include,
|
||
|
}
|
||
|
|
||
|
c.indices = append(c.indices, m)
|
||
|
c.uniqueIndices = append(c.uniqueIndices, m)
|
||
|
return m
|
||
|
}
|
||
|
|
||
|
// ----------------------------------------------------------------------------
|
||
|
|
||
|
func (m *MapIndex[K, T]) load(src map[uint64]*T) error {
|
||
|
for _, item := range src {
|
||
|
if m.include != nil && !m.include(item) {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
k := m.getKey(item)
|
||
|
if _, ok := m.m[k]; ok {
|
||
|
return ErrDuplicate
|
||
|
}
|
||
|
|
||
|
m.m[k] = item
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (m *MapIndex[K, T]) Get(k K) (t T, ok bool) {
|
||
|
if item, ok := m.mapGet(k); ok {
|
||
|
return *item, true
|
||
|
}
|
||
|
|
||
|
return t, false
|
||
|
}
|
||
|
|
||
|
func (m *MapIndex[K, T]) Update(k K, update func(T) (T, error)) error {
|
||
|
wrapped := func(item T) (T, error) {
|
||
|
new, err := update(item)
|
||
|
if err != nil {
|
||
|
return item, err
|
||
|
}
|
||
|
|
||
|
if m.getKey(&new) != k {
|
||
|
return item, ErrMismatchedIDs
|
||
|
}
|
||
|
|
||
|
return new, nil
|
||
|
}
|
||
|
|
||
|
if item, ok := m.mapGet(k); ok {
|
||
|
return m.c.Update(m.getID(item), wrapped)
|
||
|
}
|
||
|
|
||
|
return ErrNotFound
|
||
|
}
|
||
|
|
||
|
func (m *MapIndex[K, T]) Delete(k K) {
|
||
|
if item, ok := m.mapGet(k); ok {
|
||
|
m.c.Delete(m.getID(item))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ----------------------------------------------------------------------------
|
||
|
|
||
|
func (m *MapIndex[K, T]) insert(item *T) {
|
||
|
if m.include == nil || m.include(item) {
|
||
|
m.mapSet(m.getKey(item), item)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (m *MapIndex[K, T]) update(old, new *T) {
|
||
|
if m.include == nil || m.include(old) {
|
||
|
m.mapDelete(m.getKey(old))
|
||
|
}
|
||
|
if m.include == nil || m.include(new) {
|
||
|
m.mapSet(m.getKey(new), new)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (m *MapIndex[K, T]) delete(item *T) {
|
||
|
if m.include == nil || m.include(item) {
|
||
|
m.mapDelete(m.getKey(item))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// ----------------------------------------------------------------------------
|
||
|
|
||
|
func (idx *MapIndex[K, T]) name() string {
|
||
|
return idx._name
|
||
|
}
|
||
|
|
||
|
func (idx *MapIndex[K, T]) lock(item *T) {
|
||
|
if idx.include == nil || idx.include(item) {
|
||
|
idx.keyLock.Lock(idx.getKey(item))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (idx *MapIndex[K, T]) unlock(item *T) {
|
||
|
if idx.include == nil || idx.include(item) {
|
||
|
idx.keyLock.Unlock(idx.getKey(item))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// Should hold item lock when calling.
|
||
|
func (idx *MapIndex[K, T]) insertConflict(new *T) bool {
|
||
|
if idx.include != nil && !idx.include(new) {
|
||
|
return false
|
||
|
}
|
||
|
_, ok := idx.Get(idx.getKey(new))
|
||
|
return ok
|
||
|
}
|
||
|
|
||
|
// Should hold item lock when calling.
|
||
|
func (idx *MapIndex[K, T]) updateConflict(new *T) bool {
|
||
|
if idx.include != nil && !idx.include(new) {
|
||
|
return false
|
||
|
}
|
||
|
val, ok := idx.mapGet(idx.getKey(new))
|
||
|
return ok && idx.getID(val) != idx.getID(new)
|
||
|
}
|
||
|
|
||
|
// ----------------------------------------------------------------------------
|
||
|
|
||
|
func (idx *MapIndex[K, T]) mapSet(k K, t *T) {
|
||
|
idx.mapLock.Lock()
|
||
|
idx.m[k] = t
|
||
|
idx.mapLock.Unlock()
|
||
|
}
|
||
|
|
||
|
func (idx *MapIndex[K, T]) mapDelete(k K) {
|
||
|
idx.mapLock.Lock()
|
||
|
delete(idx.m, k)
|
||
|
idx.mapLock.Unlock()
|
||
|
}
|
||
|
|
||
|
func (idx *MapIndex[K, T]) mapGet(k K) (*T, bool) {
|
||
|
idx.mapLock.Lock()
|
||
|
t, ok := idx.m[k]
|
||
|
idx.mapLock.Unlock()
|
||
|
return t, ok
|
||
|
}
|