package mdb import ( "encoding/json" "errors" "hash/crc64" "unsafe" "git.crumpington.com/public/jldb/lib/errs" "github.com/google/btree" ) type Collection[T any] struct { db *Database name string collectionID uint64 copy func(*T) *T sanitize func(*T) validate func(*T) error indices []*Index[T] uniqueIndices []*Index[T] ByID *Index[T] } type CollectionConfig[T any] struct { Copy func(*T) *T Sanitize func(*T) Validate func(*T) error } func NewCollection[T any](db *Database, name string, conf *CollectionConfig[T]) *Collection[T] { if conf == nil { conf = &CollectionConfig[T]{} } if conf.Copy == nil { conf.Copy = func(from *T) *T { to := new(T) *to = *from return to } } if conf.Sanitize == nil { conf.Sanitize = func(*T) {} } if conf.Validate == nil { conf.Validate = func(*T) error { return nil } } c := &Collection[T]{ db: db, name: name, collectionID: crc64.Checksum([]byte(name), crc64Table), copy: conf.Copy, sanitize: conf.Sanitize, validate: conf.Validate, indices: []*Index[T]{}, uniqueIndices: []*Index[T]{}, } db.addCollection(c.collectionID, c, &collectionState[T]{ Indices: []indexState[T]{}, }) c.ByID = c.addIndex(indexConfig[T]{ Name: "ByID", Unique: true, Compare: func(lhs, rhs *T) int { l := c.getID(lhs) r := c.getID(rhs) if l < r { return -1 } else if l > r { return 1 } return 0 }, }) return c } func (c *Collection[T]) Name() string { return c.name } type indexConfig[T any] struct { Name string Unique bool // If an index isn't unique, an additional comparison by ID is added if // two items are otherwise equal. Compare func(lhs, rhs *T) int // If not nil, indicates if a given item should be in the index. Include func(item *T) bool } // AddIndex: Add an index to the collection. func (c *Collection[T]) addIndex(conf indexConfig[T]) *Index[T] { var less func(*T, *T) bool if conf.Unique { less = func(lhs, rhs *T) bool { return conf.Compare(lhs, rhs) == -1 } } else { less = func(lhs, rhs *T) bool { switch conf.Compare(lhs, rhs) { case -1: return true case 1: return false default: return c.getID(lhs) < c.getID(rhs) } } } indexState := indexState[T]{ BTree: btree.NewG(256, less), } index := &Index[T]{ db: c.db, collectionID: c.collectionID, name: conf.Name, indexID: c.getState(c.db.Snapshot()).addIndex(indexState), include: conf.Include, copy: c.copy, } c.indices = append(c.indices, index) if conf.Unique { c.uniqueIndices = append(c.uniqueIndices, index) } return index } func (c *Collection[T]) Get(tx *Snapshot, id uint64) *T { if tx == nil { tx = c.db.Snapshot() } item := new(T) c.setID(item, id) return c.ByID.Get(tx, item) } func (c *Collection[T]) Has(tx *Snapshot, id uint64) bool { if tx == nil { tx = c.db.Snapshot() } item := new(T) c.setID(item, id) return c.ByID.Has(tx, item) } func (c *Collection[T]) Insert(tx *Snapshot, userItem *T) error { if tx == nil { return c.db.Update(func(tx *Snapshot) error { return c.insert(tx, userItem) }) } return c.insert(tx, userItem) } func (c *Collection[T]) insert(tx *Snapshot, userItem *T) error { if err := c.ensureMutable(tx); err != nil { return err } item := c.copy(userItem) c.sanitize(item) if err := c.validate(item); err != nil { return err } for i := range c.uniqueIndices { if c.uniqueIndices[i].insertConflict(tx, item) { return errs.Duplicate.WithCollection(c.name).WithIndex(c.uniqueIndices[i].name) } } tx.store(c.collectionID, c.getID(item), item) for i := range c.indices { c.indices[i].insert(tx, item) } return nil } func (c *Collection[T]) Update(tx *Snapshot, userItem *T) error { if tx == nil { return c.db.Update(func(tx *Snapshot) error { return c.update(tx, userItem) }) } return c.update(tx, userItem) } func (c *Collection[T]) update(tx *Snapshot, userItem *T) error { if err := c.ensureMutable(tx); err != nil { return err } item := c.copy(userItem) c.sanitize(item) if err := c.validate(item); err != nil { return err } old, ok := c.ByID.get(tx, item) if !ok { return errs.NotFound } for i := range c.uniqueIndices { if c.uniqueIndices[i].updateConflict(tx, item) { return errs.Duplicate.WithCollection(c.name).WithIndex(c.uniqueIndices[i].name) } } tx.store(c.collectionID, c.getID(item), item) for i := range c.indices { c.indices[i].update(tx, old, item) } return nil } func (c *Collection[T]) UpdateFunc(tx *Snapshot, id uint64, update func(item *T) error) error { if tx == nil { return c.db.Update(func(tx *Snapshot) error { return c.updateFunc(tx, id, update) }) } return c.updateFunc(tx, id, update) } func (c *Collection[T]) updateFunc(tx *Snapshot, id uint64, update func(item *T) error) error { item := c.Get(tx, id) if item == nil { return errs.NotFound } if err := update(item); err != nil { return err } c.setID(item, id) // Don't allow the ID to change. return c.update(tx, item) } func (c *Collection[T]) Upsert(tx *Snapshot, item *T) error { if tx == nil { return c.db.Update(func(tx *Snapshot) error { return c.upsert(tx, item) }) } return c.upsert(tx, item) } func (c *Collection[T]) upsert(tx *Snapshot, item *T) error { err := c.Insert(tx, item) if err == nil { return nil } if errors.Is(err, errs.Duplicate) { return c.Update(tx, item) } return err } func (c *Collection[T]) UpsertFunc(tx *Snapshot, id uint64, update func(item *T) error) error { if tx == nil { c.db.Update(func(tx *Snapshot) error { return c.upsertFunc(tx, id, update) }) } return c.upsertFunc(tx, id, update) } func (c *Collection[T]) upsertFunc(tx *Snapshot, id uint64, update func(item *T) error) error { insert := false item := c.Get(tx, id) if item == nil { item = new(T) insert = true } if err := update(item); err != nil { return err } c.setID(item, id) // Don't allow the ID to change. if insert { return c.insert(tx, item) } return c.update(tx, item) } func (c *Collection[T]) Delete(tx *Snapshot, itemID uint64) error { if tx == nil { return c.db.Update(func(tx *Snapshot) error { return c.delete(tx, itemID) }) } return c.delete(tx, itemID) } func (c *Collection[T]) delete(tx *Snapshot, itemID uint64) error { if err := c.ensureMutable(tx); err != nil { return err } return c.deleteItem(tx, itemID) } func (c *Collection[T]) Count(tx *Snapshot) int { if tx == nil { tx = c.db.Snapshot() } return c.ByID.Count(tx) } func (c *Collection[T]) getByID(tx *Snapshot, itemID uint64) (*T, bool) { x := new(T) c.setID(x, itemID) return c.ByID.get(tx, x) } func (c *Collection[T]) ensureMutable(tx *Snapshot) error { if !tx.writable() { return errs.ReadOnly } state := c.getState(tx) if state.Version != tx.version { tx.collections[c.collectionID] = state.clone(tx.version) } return nil } // For initial data loading. func (c *Collection[T]) insertItem(tx *Snapshot, itemID uint64, data []byte) error { item := new(T) if err := json.Unmarshal(data, item); err != nil { return errs.Encoding.WithErr(err).WithCollection(c.name) } // Check for insert conflict. for _, index := range c.uniqueIndices { if index.insertConflict(tx, item) { return errs.Duplicate } } // Do the insert. for _, index := range c.indices { index.insert(tx, item) } return nil } func (c *Collection[T]) deleteItem(tx *Snapshot, itemID uint64) error { item, ok := c.getByID(tx, itemID) if !ok { return errs.NotFound } tx.delete(c.collectionID, itemID) for i := range c.indices { c.indices[i].delete(tx, item) } return nil } // upsertItem inserts or updates the item with itemID and the given serialized // form. It's called by func (c *Collection[T]) upsertItem(tx *Snapshot, itemID uint64, data []byte) error { item, ok := c.getByID(tx, itemID) if ok { tx.delete(c.collectionID, itemID) for i := range c.indices { c.indices[i].delete(tx, item) } } item = new(T) if err := json.Unmarshal(data, item); err != nil { return errs.Encoding.WithErr(err).WithCollection(c.name) } // Do the insert. for _, index := range c.indices { index.insert(tx, item) } return nil } func (c *Collection[T]) getID(t *T) uint64 { return *((*uint64)(unsafe.Pointer(t))) } func (c *Collection[T]) setID(t *T, id uint64) { *((*uint64)(unsafe.Pointer(t))) = id } func (c *Collection[T]) getState(tx *Snapshot) *collectionState[T] { return tx.collections[c.collectionID].(*collectionState[T]) }