package mdb

import (
	"errors"
	"fmt"

	"git.crumpington.com/private/mdb/keyedmutex"
)

// Collection is a named, typed set of items of type T stored in a Database.
// Each item is identified by a uint64 ID. Every write is applied to all of the
// collection's indices; the item map itself is always the first index.
type Collection[T any] struct {
	// primary is captured from db.kv.Primary() at construction. Write
	// methods refuse to run when it is false.
	primary bool
	db      *Database
	name    string

	// idLock serializes writers that operate on the same item ID.
	idLock keyedmutex.KeyedMutex[uint64]

	// items is also indices[0]; uniqueIndices are locked and checked for
	// conflicts before a write is applied to indices.
	items         *itemMap[T]
	indices       []itemIndex[T]
	uniqueIndices []itemUniqueIndex[T]

	getID    func(*T) uint64
	sanitize func(*T)
	validate func(*T) error
}

// NewCollection creates a collection called name, registers it in
// db.collections (replacing any collection already registered under that
// name), and returns it.
//
// getID is handed to the collection's item map; the collection reads item IDs
// through that map's getID. The new collection starts with the item map as its
// only index, a no-op sanitizer, and a validator that accepts every item.
func NewCollection[T any](db *Database, name string, getID func(*T) uint64) *Collection[T] {
	items := newItemMap(db.kv, name, getID)
	c := &Collection[T]{
		primary:       db.kv.Primary(),
		db:            db,
		name:          name,
		idLock:        keyedmutex.New[uint64](),
		items:         items,
		indices:       []itemIndex[T]{items},
		uniqueIndices: []itemUniqueIndex[T]{},
		getID:         items.getID,
		sanitize:      func(*T) {},
		validate:      func(*T) error { return nil },
	}
	db.collections[name] = c
	return c
}

// SetSanitize sets the function applied to an item before it is validated
// on Insert and Update. Passing nil restores the default no-op sanitizer.
func (c *Collection[T]) SetSanitize(sanitize func(*T)) {
	if sanitize == nil {
		// A nil func would panic on the next write; fall back to the default.
		sanitize = func(*T) {}
	}
	c.sanitize = sanitize
}

// SetValidate sets the function that must return nil for an item to be
// accepted by Insert and Update. Passing nil restores the default validator,
// which accepts every item.
func (c *Collection[T]) SetValidate(validate func(*T) error) {
	if validate == nil {
		// A nil func would panic on the next write; fall back to the default.
		validate = func(*T) error { return nil }
	}
	c.validate = validate
}

// ----------------------------------------------------------------------------

// NextID returns the value reported by the item map's nextID.
// NOTE(review): presumably an ID not yet used in this collection — confirm
// against itemMap.nextID.
func (c *Collection[T]) NextID() uint64 {
	return c.items.nextID()
}

// ----------------------------------------------------------------------------

// Insert sanitizes and validates item, then adds it to every index and
// returns the sanitized item.
//
// Errors:
//   - ErrReadOnly if the database is not primary;
//   - the validator's error if validation fails;
//   - an error wrapping ErrDuplicate if the ID already exists or a unique
//     index reports a conflict (the message names the index).
func (c *Collection[T]) Insert(item T) (T, error) {
	if !c.primary {
		return item, ErrReadOnly
	}

	c.sanitize(&item)
	if err := c.validate(&item); err != nil {
		return item, err
	}

	id := c.getID(&item)
	c.idLock.Lock(id)
	defer c.idLock.Unlock(id)

	if _, ok := c.items.Get(id); ok {
		return item, fmt.Errorf("%w: ID", ErrDuplicate)
	}

	// Lock every unique index and check for conflicts. The unlocks are
	// deferred inside the loop on purpose: all locks stay held until the
	// indices below have been written, closing the check-then-write gap.
	for _, idx := range c.uniqueIndices {
		idx.lock(&item)
		defer idx.unlock(&item)
		if idx.insertConflict(&item) {
			return item, fmt.Errorf("%w: %s", ErrDuplicate, idx.name())
		}
	}

	for _, idx := range c.indices {
		idx.insert(&item)
	}
	return item, nil
}

// Update replaces the item stored under id with the result of update, which
// receives a copy of the current item while the ID lock is held.
//
// If update returns an error matching ErrAbortUpdate (via errors.Is), the
// update is cancelled and Update returns nil. Any other error from update is
// returned as-is.
//
// Errors:
//   - ErrReadOnly if the database is not primary;
//   - ErrNotFound if no item is stored under id;
//   - ErrMismatchedIDs if the sanitized result carries a different ID;
//   - the validator's error if validation fails;
//   - an error wrapping ErrDuplicate if a unique index reports a conflict.
func (c *Collection[T]) Update(id uint64, update func(T) (T, error)) error {
	if !c.primary {
		return ErrReadOnly
	}

	c.idLock.Lock(id)
	defer c.idLock.Unlock(id)

	old, ok := c.items.Get(id)
	if !ok {
		return ErrNotFound
	}

	updated, err := update(*old)
	if err != nil {
		if errors.Is(err, ErrAbortUpdate) {
			return nil
		}
		return err
	}
	item := &updated

	c.sanitize(item)
	// Check the ID after sanitizing (as Insert does) so the item written to
	// the indices is guaranteed to carry the ID it is stored under.
	if c.getID(item) != id {
		return ErrMismatchedIDs
	}
	if err := c.validate(item); err != nil {
		return err
	}

	// Lock every unique index and check for conflicts; locks are held until
	// return so the indices below are written under them.
	for _, idx := range c.uniqueIndices {
		idx.lock(item)
		defer idx.unlock(item)
		if idx.updateConflict(item) {
			return fmt.Errorf("%w: %s", ErrDuplicate, idx.name())
		}
	}

	for _, idx := range c.indices {
		idx.update(old, item)
	}
	return nil
}

// Delete removes the item stored under id from every index. Deleting an ID
// that is not present is a no-op.
//
// Delete has no error return, so on a non-primary database it panics with
// ErrReadOnly.
func (c *Collection[T]) Delete(id uint64) {
	if !c.primary {
		panic(ErrReadOnly)
	}

	c.idLock.Lock(id)
	defer c.idLock.Unlock(id)

	item, ok := c.items.Get(id)
	if !ok {
		return
	}

	// Hold every unique-index lock (released on return) while the item is
	// removed from the indices. Deletes cannot conflict, so nothing is
	// checked here.
	for _, idx := range c.uniqueIndices {
		idx.lock(item)
		defer idx.unlock(item)
	}

	for _, idx := range c.indices {
		idx.delete(item)
	}
}

// Get returns a (shallow) copy of the item stored under id and true, or the
// zero value of T and false if there is no such item.
func (c *Collection[T]) Get(id uint64) (t T, ok bool) {
	ptr, ok := c.items.Get(id)
	if !ok {
		return t, false
	}
	return *ptr, true
}

// ----------------------------------------------------------------------------

// loadData loads every index from the item map's backing map c.items.m,
// panicking (via must) on the first load error.
func (c *Collection[T]) loadData() {
	for _, idx := range c.indices {
		must(idx.load(c.items.m))
	}
}

// onStore decodes data as an item stored under id and applies it to every
// index: as an insert if id is not present yet, otherwise as an update of the
// existing item.
// NOTE(review): presumably invoked for writes reported by the underlying kv
// store (e.g. on replicas); it takes no locks and skips sanitize/validate —
// confirm with the caller.
func (c *Collection[T]) onStore(id uint64, data []byte) {
	item := decode[T](data)

	old, ok := c.items.Get(id)
	if !ok {
		for _, idx := range c.indices {
			idx.insert(item)
		}
		return
	}

	for _, idx := range c.indices {
		idx.update(old, item)
	}
}

// onDelete removes the item stored under id from every index, if present.
// NOTE(review): presumably the delete counterpart of onStore; takes no locks.
func (c *Collection[T]) onDelete(id uint64) {
	if item, ok := c.items.Get(id); ok {
		for _, idx := range c.indices {
			idx.delete(item)
		}
	}
}