This repository has been archived on 2022-07-30. You can view files and clone it, but cannot push or open issues/pull-requests.
mdb/collection.go

206 lines
3.8 KiB
Go

package mdb
import (
	"errors"
	"fmt"

	"git.crumpington.com/private/mdb/keyedmutex"
)
// Collection is a typed set of items of type T, each identified by a uint64
// ID extracted with getID. Besides the primary item map it holds the list of
// indices that every write (Insert, Update, Delete) must keep in sync, and the
// subset of unique indices that are locked and conflict-checked on writes.
type Collection[T any] struct {
	// primary mirrors db.kv.Primary(); when false, Insert/Update return
	// ErrReadOnly and Delete panics with ErrReadOnly.
	primary bool
	// db is the owning database; NewCollection registers the collection in
	// db.collections under name.
	db *Database
	// name identifies the collection in db.collections and is passed to
	// newItemMap.
	name string
	// idLock serializes Insert, Update and Delete calls for the same ID.
	idLock keyedmutex.KeyedMutex[uint64]
	// items is the ID -> item map; NewCollection also installs it as the
	// first entry of indices.
	items *itemMap[T]
	// indices are all indices updated on insert/update/delete, including items.
	indices []itemIndex[T]
	// uniqueIndices are locked and checked for duplicates before writes.
	// NewCollection starts this empty; presumably populated by unique-index
	// constructors elsewhere in the package — TODO confirm.
	uniqueIndices []itemUniqueIndex[T]
	// getID extracts an item's ID.
	getID func(*T) uint64
	// sanitize is applied to items in Insert/Update before validate; no-op by
	// default.
	sanitize func(*T)
	// validate rejects an item in Insert/Update by returning a non-nil error;
	// accepts everything by default.
	validate func(*T) error
}
// NewCollection builds a collection called name on top of db's key-value
// store and registers it in db.collections, replacing any collection already
// registered under that name. getID extracts an item's ID.
//
// The new collection's writability is taken from db.kv.Primary(). Its only
// index is the item map itself. Its sanitize and validate hooks are no-ops
// until replaced via SetSanitize / SetValidate.
func NewCollection[T any](db *Database, name string, getID func(*T) uint64) *Collection[T] {
	byID := newItemMap(db.kv, name, getID)

	col := new(Collection[T])
	col.primary = db.kv.Primary()
	col.db = db
	col.name = name
	col.idLock = keyedmutex.New[uint64]()
	col.items = byID
	col.indices = []itemIndex[T]{byID}
	col.uniqueIndices = []itemUniqueIndex[T]{}
	col.getID = byID.getID
	col.sanitize = func(*T) {}
	col.validate = func(*T) error { return nil }

	db.collections[name] = col
	return col
}
// SetSanitize installs the hook that Insert and Update apply to an item
// before validating it. Passing nil restores the default no-op hook.
// Without that, a nil func would be stored and the next Insert or Update
// would panic when calling it.
func (c *Collection[T]) SetSanitize(sanitize func(*T)) {
	if sanitize == nil {
		sanitize = func(*T) {}
	}
	c.sanitize = sanitize
}
// SetValidate installs the hook that Insert and Update call after sanitizing
// an item. A non-nil error from the hook aborts the write and is returned to
// the caller. Passing nil restores the default accept-everything hook.
// Without that, a nil func would be stored and the next Insert or Update
// would panic when calling it.
func (c *Collection[T]) SetValidate(validate func(*T) error) {
	if validate == nil {
		validate = func(*T) error { return nil }
	}
	c.validate = validate
}
// ----------------------------------------------------------------------------
// NextID returns the ID produced by the collection's item map.
// NOTE(review): presumably an ID not yet in use, suitable for Insert —
// confirm against itemMap.nextID.
func (c *Collection[T]) NextID() uint64 {
	next := c.items.nextID()
	return next
}
// ----------------------------------------------------------------------------
// Insert adds item to the collection and returns it after sanitization.
//
// The item is sanitized, then validated, then its ID is locked. The insert
// fails with ErrDuplicate (wrapped, suffixed "ID") if the ID already exists,
// or with ErrDuplicate suffixed by the index name if any unique index
// reports a conflict. ErrReadOnly is returned on a non-primary database,
// and any validate error is returned as-is. On every error path the
// (possibly sanitized) item is still returned.
func (c *Collection[T]) Insert(item T) (T, error) {
	if !c.primary {
		return item, ErrReadOnly
	}

	p := &item
	c.sanitize(p)
	if err := c.validate(p); err != nil {
		return item, err
	}

	id := c.getID(p)
	c.idLock.Lock(id)
	defer c.idLock.Unlock(id)

	if _, exists := c.items.Get(id); exists {
		return item, fmt.Errorf("%w: ID", ErrDuplicate)
	}

	// Each unique-index lock is held until Insert returns (defer in a loop is
	// deliberate), so no other writer can claim the same unique key between
	// the conflict check and the index insert below.
	uniques := c.uniqueIndices
	for i := range uniques {
		u := uniques[i]
		u.lock(p)
		defer u.unlock(p)
		if u.insertConflict(p) {
			return item, fmt.Errorf("%w: %s", ErrDuplicate, u.name())
		}
	}

	all := c.indices
	for i := range all {
		all[i].insert(p)
	}
	return item, nil
}
// Update replaces the item stored under id with the result of update, while
// holding the per-ID lock.
//
// update receives a copy of the current item. If it returns ErrAbortUpdate
// (directly or wrapped), Update returns nil and changes nothing. Any other
// error from update is returned unchanged. The returned item is sanitized,
// its ID must still equal id, and it is then validated.
//
// Errors:
//   - ErrReadOnly on a non-primary database.
//   - ErrNotFound if id is absent.
//   - ErrMismatchedIDs if the (sanitized) new item's ID differs from id.
//   - The validate hook's error.
//   - ErrDuplicate wrapped with the index name on a unique-index conflict.
func (c *Collection[T]) Update(id uint64, update func(T) (T, error)) error {
	if !c.primary {
		return ErrReadOnly
	}

	c.idLock.Lock(id)
	defer c.idLock.Unlock(id)

	old, ok := c.items.Get(id)
	if !ok {
		return ErrNotFound
	}

	newItem, err := update(*old)
	if err != nil {
		// errors.Is so a callback may wrap ErrAbortUpdate with context.
		if errors.Is(err, ErrAbortUpdate) {
			return nil
		}
		return err
	}

	updated := &newItem
	c.sanitize(updated)

	// Check the ID after sanitizing, as Insert reads the ID from the sanitized
	// item. This guarantees the item handed to the indices carries exactly the
	// ID we locked, even if the sanitize hook rewrites IDs.
	if c.getID(updated) != id {
		return ErrMismatchedIDs
	}
	if err := c.validate(updated); err != nil {
		return err
	}

	// Each unique-index lock is held until Update returns (defer in a loop is
	// deliberate), so no other writer can claim the same unique key between
	// the conflict check and the index update below.
	for _, idx := range c.uniqueIndices {
		idx.lock(updated)
		defer idx.unlock(updated)
		if idx.updateConflict(updated) {
			return fmt.Errorf("%w: %s", ErrDuplicate, idx.name())
		}
	}

	for _, idx := range c.indices {
		idx.update(old, updated)
	}
	return nil
}
// Delete removes the item stored under id from every index of the collection,
// including the item map. Deleting an ID that is not present is a no-op.
//
// Delete has no error return, so on a non-primary database it panics with
// ErrReadOnly.
//
// The receiver is a pointer, matching Insert and Update. A value receiver
// would copy the whole Collection, including idLock, on every call. If
// KeyedMutex holds its mutex by value, that copy would stop Delete from
// excluding concurrent writers. NOTE(review): confirm KeyedMutex's layout.
func (c *Collection[T]) Delete(id uint64) {
	if !c.primary {
		panic(ErrReadOnly)
	}

	c.idLock.Lock(id)
	defer c.idLock.Unlock(id)

	item, ok := c.items.Get(id)
	if !ok {
		return
	}

	// Acquire the same unique-index locks that Insert and Update take, and
	// hold them until Delete returns. Deleting cannot create a conflict, so
	// no conflict check is made.
	for _, idx := range c.uniqueIndices {
		idx.lock(item)
		defer idx.unlock(item)
	}

	for _, idx := range c.indices {
		idx.delete(item)
	}
}
// Get returns a shallow copy of the item stored under id and true. If id is
// absent, it returns the zero value of T and false. Slices, maps and
// pointers inside T still alias the stored item.
//
// It uses a pointer receiver, consistent with the other methods, so that a
// read does not copy the entire Collection (including its lock) on each call.
func (c *Collection[T]) Get(id uint64) (t T, ok bool) {
	ptr, ok := c.items.Get(id)
	if !ok {
		return t, false
	}
	return *ptr, true
}
// ----------------------------------------------------------------------------
// loadData has every index, in registration order, load from the item map's
// m field. Any load error is passed to must. NOTE(review): presumably must
// panics on a non-nil error — confirm.
func (c *Collection[T]) loadData() {
	indices := c.indices
	for i := range indices {
		// c.items.m is re-read per index because a load may replace it.
		must(indices[i].load(c.items.m))
	}
}
// onStore decodes data and applies it to every index. If id is not yet
// present the decoded item is inserted; otherwise it replaces the existing
// item via update. No collection locks are taken here.
// NOTE(review): presumably called from the storage/replication callback
// path — confirm its concurrency guarantees.
func (c *Collection[T]) onStore(id uint64, data []byte) {
	decoded := decode[T](data)

	prev, exists := c.items.Get(id)
	if !exists {
		for _, ix := range c.indices {
			ix.insert(decoded)
		}
		return
	}

	for _, ix := range c.indices {
		ix.update(prev, decoded)
	}
}
// onDelete removes the item stored under id from every index; it does nothing
// if id is absent. No collection locks are taken here.
func (c *Collection[T]) onDelete(id uint64) {
	victim, found := c.items.Get(id)
	if !found {
		return
	}
	for _, ix := range c.indices {
		ix.delete(victim)
	}
}