jldb/mdb/collection.go

432 lines
8.5 KiB
Go

package mdb
import (
"encoding/json"
"errors"
"hash/crc64"
"unsafe"
"git.crumpington.com/public/jldb/lib/errs"
"github.com/google/btree"
)
// Collection is a typed set of items of type T stored in a Database and kept
// in one or more B-tree indices. The first 8 bytes of every T are read and
// written as the item's uint64 ID (see getID/setID), so T must begin with a
// uint64 ID field.
type Collection[T any] struct {
	db           *Database
	name         string
	// crc64 checksum of name (computed in NewCollection); used as the key
	// for this collection's state and stored items.
	collectionID uint64

	// Hooks from CollectionConfig, applied to items on insert and update.
	copy     func(*T) *T
	sanitize func(*T)
	validate func(*T) error

	// indices holds every index, including ByID; uniqueIndices is the subset
	// created with Unique set, checked for conflicts before writes.
	indices       []*Index[T]
	uniqueIndices []*Index[T]

	// ByID is the unique primary index, ordering items by ID.
	ByID *Index[T]
}
// CollectionConfig holds optional hooks passed to NewCollection. NewCollection
// uses a default for any nil field: a shallow copy (*to = *from) for Copy, a
// no-op for Sanitize, and a Validate that always returns nil.
type CollectionConfig[T any] struct {
	// Copy returns a copy of an item. Inserts and updates store a copy of the
	// caller's item rather than the item itself, and the function is also
	// handed to each of the collection's indices.
	Copy func(*T) *T
	// Sanitize may modify the copied item in place before it is validated.
	Sanitize func(*T)
	// Validate rejects an item by returning a non-nil error, which the
	// insert or update then returns unchanged.
	Validate func(*T) error
}
// NewCollection creates a collection called name, registers it with db, and
// adds its primary ByID index.
//
// conf may be nil. Nil hooks are replaced with defaults: a shallow copy for
// Copy, a no-op for Sanitize, and a Validate that always returns nil. The
// caller's conf is not modified.
//
// The collection reads and writes the first 8 bytes of every T as the item's
// uint64 ID (see getID/setID), so T must begin with a uint64 ID field.
func NewCollection[T any](db *Database, name string, conf *CollectionConfig[T]) *Collection[T] {
	// Fill in defaults on a local copy so they don't leak into the caller's
	// config struct, which may be shared or reused.
	var cfg CollectionConfig[T]
	if conf != nil {
		cfg = *conf
	}
	if cfg.Copy == nil {
		cfg.Copy = func(from *T) *T {
			to := new(T)
			*to = *from
			return to
		}
	}
	if cfg.Sanitize == nil {
		cfg.Sanitize = func(*T) {}
	}
	if cfg.Validate == nil {
		cfg.Validate = func(*T) error { return nil }
	}

	c := &Collection[T]{
		db:            db,
		name:          name,
		collectionID:  crc64.Checksum([]byte(name), crc64Table),
		copy:          cfg.Copy,
		sanitize:      cfg.Sanitize,
		validate:      cfg.Validate,
		indices:       []*Index[T]{},
		uniqueIndices: []*Index[T]{},
	}

	db.addCollection(c.collectionID, c, &collectionState[T]{
		Indices: []indexState[T]{},
	})

	// The primary index orders items by their ID.
	c.ByID = c.addIndex(indexConfig[T]{
		Name:   "ByID",
		Unique: true,
		Compare: func(lhs, rhs *T) int {
			l, r := c.getID(lhs), c.getID(rhs)
			switch {
			case l < r:
				return -1
			case l > r:
				return 1
			default:
				return 0
			}
		},
	})
	return c
}
// Name reports the name the collection was created with.
func (c *Collection[T]) Name() string { return c.name }
// indexConfig describes an index to be created by addIndex.
type indexConfig[T any] struct {
	// Name identifies the index; it is reported in errs.Duplicate errors.
	Name string
	// Unique indices are checked for conflicts (errs.Duplicate) before an
	// item is inserted or updated.
	Unique bool
	// Compare orders items: return -1, 0, or 1 as lhs sorts before, equal
	// to, or after rhs.
	// If an index isn't unique, an additional comparison by ID is added if
	// two items are otherwise equal.
	Compare func(lhs, rhs *T) int
	// If not nil, indicates if a given item should be in the index.
	Include func(item *T) bool
}
// addIndex creates an index described by conf, registers its state in the
// database's current snapshot, and adds it to the collection's index lists.
//
// conf.Compare is interpreted by sign: negative means lhs sorts before rhs,
// zero means equal, positive means after. Implementations are therefore not
// required to return exactly -1 or 1. For non-unique indices, items that
// compare equal are further ordered by ID.
func (c *Collection[T]) addIndex(conf indexConfig[T]) *Index[T] {
	var less func(*T, *T) bool
	if conf.Unique {
		less = func(lhs, rhs *T) bool {
			return conf.Compare(lhs, rhs) < 0
		}
	} else {
		less = func(lhs, rhs *T) bool {
			if order := conf.Compare(lhs, rhs); order != 0 {
				return order < 0
			}
			// Otherwise-equal items are ordered by ID.
			return c.getID(lhs) < c.getID(rhs)
		}
	}

	state := indexState[T]{
		BTree: btree.NewG(256, less),
	}

	index := &Index[T]{
		db:           c.db,
		collectionID: c.collectionID,
		name:         conf.Name,
		indexID:      c.getState(c.db.Snapshot()).addIndex(state),
		include:      conf.Include,
		copy:         c.copy,
	}

	c.indices = append(c.indices, index)
	if conf.Unique {
		c.uniqueIndices = append(c.uniqueIndices, index)
	}
	return index
}
// Get looks up the item whose ID is id through the ByID index and returns
// it, or nil if there is no such item. With a nil tx the lookup runs against
// a fresh database snapshot.
func (c *Collection[T]) Get(tx *Snapshot, id uint64) *T {
	key := new(T)
	c.setID(key, id)

	if tx == nil {
		tx = c.db.Snapshot()
	}
	return c.ByID.Get(tx, key)
}
// Has reports whether the ByID index holds an item with the given id. A nil
// tx means "use a fresh database snapshot".
func (c *Collection[T]) Has(tx *Snapshot, id uint64) bool {
	snap := tx
	if snap == nil {
		snap = c.db.Snapshot()
	}

	var probe T
	c.setID(&probe, id)
	return c.ByID.Has(snap, &probe)
}
// Insert adds a copy of userItem to the collection within tx. When tx is nil
// the insert is wrapped in its own c.db.Update call.
func (c *Collection[T]) Insert(tx *Snapshot, userItem *T) error {
	if tx != nil {
		return c.insert(tx, userItem)
	}
	return c.db.Update(func(wtx *Snapshot) error {
		return c.insert(wtx, userItem)
	})
}
// insert stores a sanitized, validated copy of userItem in tx and adds it to
// every index. It fails with errs.ReadOnly when tx is not writable, with the
// validator's error, or with errs.Duplicate (annotated with collection and
// index names) when a unique index reports a conflict.
func (c *Collection[T]) insert(tx *Snapshot, userItem *T) error {
	if err := c.ensureMutable(tx); err != nil {
		return err
	}

	item := c.copy(userItem)
	c.sanitize(item)
	if err := c.validate(item); err != nil {
		return err
	}

	// All unique indices are consulted before anything is stored or indexed.
	for _, idx := range c.uniqueIndices {
		if idx.insertConflict(tx, item) {
			return errs.Duplicate.WithCollection(c.name).WithIndex(idx.name)
		}
	}

	tx.store(c.collectionID, c.getID(item), item)
	for _, idx := range c.indices {
		idx.insert(tx, item)
	}
	return nil
}
// Update replaces the stored item that shares userItem's ID with a copy of
// userItem. A nil tx runs the update inside c.db.Update.
func (c *Collection[T]) Update(tx *Snapshot, userItem *T) error {
	run := func(s *Snapshot) error { return c.update(s, userItem) }
	if tx == nil {
		return c.db.Update(run)
	}
	return run(tx)
}
// update overwrites the item with userItem's ID using a sanitized, validated
// copy of userItem, then refreshes every index from the old item to the new
// one. Errors: errs.ReadOnly for a read-only tx, the validator's error,
// errs.NotFound when no item has that ID, or an annotated errs.Duplicate on
// a unique-index conflict.
func (c *Collection[T]) update(tx *Snapshot, userItem *T) error {
	if err := c.ensureMutable(tx); err != nil {
		return err
	}

	next := c.copy(userItem)
	c.sanitize(next)
	if err := c.validate(next); err != nil {
		return err
	}

	prev, found := c.ByID.get(tx, next)
	if !found {
		return errs.NotFound
	}

	for _, idx := range c.uniqueIndices {
		if !idx.updateConflict(tx, next) {
			continue
		}
		return errs.Duplicate.WithCollection(c.name).WithIndex(idx.name)
	}

	tx.store(c.collectionID, c.getID(next), next)
	for _, idx := range c.indices {
		idx.update(tx, prev, next)
	}
	return nil
}
// UpdateFunc loads the item with the given id, lets update modify it, and
// writes it back. With a nil tx, the whole read-modify-write runs inside a
// single c.db.Update call.
func (c *Collection[T]) UpdateFunc(tx *Snapshot, id uint64, update func(item *T) error) error {
	if tx != nil {
		return c.updateFunc(tx, id, update)
	}
	return c.db.Update(func(wtx *Snapshot) error {
		return c.updateFunc(wtx, id, update)
	})
}
// updateFunc fetches item id via Get, passes it to update, forces the ID
// back to id, and saves the result with update. It returns errs.NotFound
// when the item doesn't exist, or update's own error unchanged.
func (c *Collection[T]) updateFunc(tx *Snapshot, id uint64, update func(item *T) error) error {
	current := c.Get(tx, id)
	if current == nil {
		return errs.NotFound
	}

	err := update(current)
	if err != nil {
		return err
	}

	// The callback must not be able to move the item to another ID.
	c.setID(current, id)
	return c.update(tx, current)
}
// Upsert inserts item, or updates the stored item with the same ID. A nil
// tx runs the operation inside c.db.Update.
func (c *Collection[T]) Upsert(tx *Snapshot, item *T) error {
	op := func(s *Snapshot) error { return c.upsert(s, item) }
	if tx == nil {
		return c.db.Update(op)
	}
	return op(tx)
}
// upsert tries to insert item into tx. If that fails with errs.Duplicate and
// an item with the same ID already exists, it updates that item instead.
//
// A duplicate can also come from a secondary unique index while the ID
// itself is new. In that case, falling back to update would only produce a
// misleading errs.NotFound, so the original duplicate error is returned.
// Any other error, including errs.ReadOnly, is returned as-is.
func (c *Collection[T]) upsert(tx *Snapshot, item *T) error {
	err := c.insert(tx, item)
	if !errors.Is(err, errs.Duplicate) {
		return err // nil on success, or a non-duplicate failure
	}

	if _, exists := c.getByID(tx, c.getID(item)); !exists {
		return err
	}
	return c.update(tx, item)
}
// UpsertFunc applies update to the item with the given id, creating a new
// zero-valued item first if none exists, and stores the result. When tx is
// nil the operation runs inside c.db.Update and that call's error is
// returned.
func (c *Collection[T]) UpsertFunc(tx *Snapshot, id uint64, update func(item *T) error) error {
	if tx == nil {
		// Return here: the work is done inside db.Update. Falling through
		// would discard its error and rerun upsertFunc with a nil tx.
		return c.db.Update(func(tx *Snapshot) error {
			return c.upsertFunc(tx, id, update)
		})
	}
	return c.upsertFunc(tx, id, update)
}
// upsertFunc hands update either the existing item with the given id or a
// fresh zero T when there is none. It then resets the ID to id and inserts
// or updates accordingly. update's error is returned unchanged.
func (c *Collection[T]) upsertFunc(tx *Snapshot, id uint64, update func(item *T) error) error {
	item := c.Get(tx, id)
	exists := item != nil
	if !exists {
		item = new(T)
	}

	if err := update(item); err != nil {
		return err
	}
	c.setID(item, id) // The callback may not change the ID.

	if !exists {
		return c.insert(tx, item)
	}
	return c.update(tx, item)
}
// Delete removes the item with itemID. A nil tx wraps the deletion in its
// own c.db.Update call.
func (c *Collection[T]) Delete(tx *Snapshot, itemID uint64) error {
	if tx != nil {
		return c.delete(tx, itemID)
	}
	return c.db.Update(func(wtx *Snapshot) error {
		return c.delete(wtx, itemID)
	})
}
// delete checks that tx is writable (errs.ReadOnly otherwise) and then
// removes the item with itemID via deleteItem.
func (c *Collection[T]) delete(tx *Snapshot, itemID uint64) error {
	err := c.ensureMutable(tx)
	if err == nil {
		err = c.deleteItem(tx, itemID)
	}
	return err
}
// Count returns the number of entries in the ByID index, reading from a
// fresh database snapshot when tx is nil.
func (c *Collection[T]) Count(tx *Snapshot) int {
	snap := tx
	if snap == nil {
		snap = c.db.Snapshot()
	}
	return c.ByID.Count(snap)
}
// getByID looks up itemID in the ByID index through its unexported get
// method. The boolean reports whether the item was found.
func (c *Collection[T]) getByID(tx *Snapshot, itemID uint64) (*T, bool) {
	var key T
	c.setID(&key, itemID)
	return c.ByID.get(tx, &key)
}
// ensureMutable returns errs.ReadOnly unless tx is writable. For a writable
// tx, if this collection's state in tx carries a different version than tx,
// it is replaced with a clone stamped with tx's version. Presumably this is
// so tx's writes never touch state shared with other snapshots (confirm
// against collectionState.clone).
func (c *Collection[T]) ensureMutable(tx *Snapshot) error {
	if tx.writable() {
		current := c.getState(tx)
		if current.Version != tx.version {
			tx.collections[c.collectionID] = current.clone(tx.version)
		}
		return nil
	}
	return errs.ReadOnly
}
// insertItem decodes data as JSON into a new T and adds it to every index in
// tx. It is used for initial data loading. Unlike insert, it does not call
// ensureMutable or tx.store.
//
// It returns errs.Encoding if data cannot be decoded. If a unique index
// reports a conflict, it returns errs.Duplicate annotated with the
// collection and index names, as insert does; errors.Is still matches
// errs.Duplicate.
//
// NOTE(review): itemID is unused. The indices key the item by the ID decoded
// from data. Confirm whether the two should be checked for agreement.
func (c *Collection[T]) insertItem(tx *Snapshot, itemID uint64, data []byte) error {
	item := new(T)
	if err := json.Unmarshal(data, item); err != nil {
		return errs.Encoding.WithErr(err).WithCollection(c.name)
	}

	// Check every unique index before inserting into any of them, so a
	// conflict leaves the indices untouched.
	for _, index := range c.uniqueIndices {
		if index.insertConflict(tx, item) {
			return errs.Duplicate.WithCollection(c.name).WithIndex(index.name)
		}
	}

	for _, index := range c.indices {
		index.insert(tx, item)
	}
	return nil
}
// deleteItem removes the item with itemID from tx and from every index,
// returning errs.NotFound when no such item exists. It does not check that
// tx is writable; delete does that before calling it.
func (c *Collection[T]) deleteItem(tx *Snapshot, itemID uint64) error {
	existing, found := c.getByID(tx, itemID)
	if !found {
		return errs.NotFound
	}

	tx.delete(c.collectionID, itemID)
	for _, idx := range c.indices {
		idx.delete(tx, existing)
	}
	return nil
}
// upsertItem inserts or replaces the item with itemID using its serialized
// (JSON) form in data. Any existing item with itemID is removed from tx
// (via tx.delete) and from every index, and then the decoded item is added
// to every index. Unlike update, no validation, unique-conflict check, or
// tx.store call is made.
//
// data is decoded before anything is changed, so a decode failure
// (errs.Encoding) leaves tx and the indices untouched.
//
// NOTE(review): the new item is indexed under the ID decoded from data,
// which is assumed to equal itemID. Confirm with callers.
func (c *Collection[T]) upsertItem(tx *Snapshot, itemID uint64, data []byte) error {
	item := new(T)
	if err := json.Unmarshal(data, item); err != nil {
		return errs.Encoding.WithErr(err).WithCollection(c.name)
	}

	if old, ok := c.getByID(tx, itemID); ok {
		tx.delete(c.collectionID, itemID)
		for _, index := range c.indices {
			index.delete(tx, old)
		}
	}

	for _, index := range c.indices {
		index.insert(tx, item)
	}
	return nil
}
// getID reinterprets the first 8 bytes of *t as a uint64 and returns it as
// the item's ID. T must therefore begin with a uint64 ID field.
func (c *Collection[T]) getID(t *T) uint64 {
	idField := (*uint64)(unsafe.Pointer(t))
	return *idField
}
// setID writes id into the first 8 bytes of *t, the counterpart of getID.
// T must therefore begin with a uint64 ID field.
func (c *Collection[T]) setID(t *T, id uint64) {
	idField := (*uint64)(unsafe.Pointer(t))
	*idField = id
}
// getState returns this collection's state as held by tx. The single-value
// type assertion panics if the entry for this collection is nil or holds a
// different type.
func (c *Collection[T]) getState(tx *Snapshot) *collectionState[T] {
	raw := tx.collections[c.collectionID]
	return raw.(*collectionState[T])
}