373 lines
7.2 KiB
Go
373 lines
7.2 KiB
Go
package mdb
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/json"
|
|
"errors"
|
|
"hash/crc64"
|
|
"unsafe"
|
|
|
|
"git.crumpington.com/public/jldb/lib/errs"
|
|
|
|
"github.com/google/btree"
|
|
)
|
|
|
|
type Collection[T any] struct {
|
|
db *Database
|
|
name string
|
|
collectionID uint64
|
|
|
|
copy func(*T) *T
|
|
sanitize func(*T)
|
|
validate func(*T) error
|
|
|
|
indices []*Index[T]
|
|
uniqueIndices []*Index[T]
|
|
|
|
ByID *Index[T]
|
|
|
|
buf *bytes.Buffer
|
|
}
|
|
|
|
// CollectionConfig holds optional per-item hooks for NewCollection.
// NewCollection substitutes a default for any field left nil.
type CollectionConfig[T any] struct {
	// Copy returns a copy of an item. Items passed to Insert/Update are
	// copied with it before being stored. Default: shallow struct copy.
	Copy func(*T) *T

	// Sanitize normalizes an item in place before it is validated.
	// Default: no-op.
	Sanitize func(*T)

	// Validate rejects an item by returning a non-nil error.
	// Default: accept every item.
	Validate func(*T) error
}
|
|
|
|
func NewCollection[T any](db *Database, name string, conf *CollectionConfig[T]) *Collection[T] {
|
|
if conf == nil {
|
|
conf = &CollectionConfig[T]{}
|
|
}
|
|
|
|
if conf.Copy == nil {
|
|
conf.Copy = func(from *T) *T {
|
|
to := new(T)
|
|
*to = *from
|
|
return to
|
|
}
|
|
}
|
|
|
|
if conf.Sanitize == nil {
|
|
conf.Sanitize = func(*T) {}
|
|
}
|
|
|
|
if conf.Validate == nil {
|
|
conf.Validate = func(*T) error {
|
|
return nil
|
|
}
|
|
}
|
|
|
|
c := &Collection[T]{
|
|
db: db,
|
|
name: name,
|
|
collectionID: crc64.Checksum([]byte(name), crc64Table),
|
|
copy: conf.Copy,
|
|
sanitize: conf.Sanitize,
|
|
validate: conf.Validate,
|
|
indices: []*Index[T]{},
|
|
uniqueIndices: []*Index[T]{},
|
|
buf: &bytes.Buffer{},
|
|
}
|
|
|
|
db.addCollection(c.collectionID, c, &collectionState[T]{
|
|
Indices: []indexState[T]{},
|
|
})
|
|
|
|
c.ByID = c.addIndex(indexConfig[T]{
|
|
Name: "ByID",
|
|
Unique: true,
|
|
Compare: func(lhs, rhs *T) int {
|
|
l := c.getID(lhs)
|
|
r := c.getID(rhs)
|
|
if l < r {
|
|
return -1
|
|
} else if l > r {
|
|
return 1
|
|
}
|
|
return 0
|
|
},
|
|
})
|
|
|
|
return c
|
|
}
|
|
|
|
func (c *Collection[T]) Name() string {
|
|
return c.name
|
|
}
|
|
|
|
// indexConfig describes an index to be created by addIndex.
type indexConfig[T any] struct {
	// Name identifies the index; it is reported in Duplicate errors.
	Name string

	// Unique makes the index participate in conflict checks on insert and
	// update.
	Unique bool

	// Compare orders two items: -1, 0, or +1 when lhs sorts before, equal to,
	// or after rhs. For a non-unique index, items that compare equal are
	// additionally ordered by ID.
	Compare func(lhs, rhs *T) int

	// Include, when non-nil, reports whether item belongs in the index.
	Include func(item *T) bool
}
|
|
|
|
// AddIndex: Add an index to the collection.
|
|
func (c *Collection[T]) addIndex(conf indexConfig[T]) *Index[T] {
|
|
var less func(*T, *T) bool
|
|
|
|
if conf.Unique {
|
|
less = func(lhs, rhs *T) bool {
|
|
return conf.Compare(lhs, rhs) == -1
|
|
}
|
|
} else {
|
|
less = func(lhs, rhs *T) bool {
|
|
switch conf.Compare(lhs, rhs) {
|
|
case -1:
|
|
return true
|
|
case 1:
|
|
return false
|
|
default:
|
|
return c.getID(lhs) < c.getID(rhs)
|
|
}
|
|
}
|
|
}
|
|
|
|
indexState := indexState[T]{
|
|
BTree: btree.NewG(256, less),
|
|
}
|
|
|
|
index := &Index[T]{
|
|
db: c.db,
|
|
collectionID: c.collectionID,
|
|
name: conf.Name,
|
|
indexID: c.getState(c.db.Snapshot()).addIndex(indexState),
|
|
include: conf.Include,
|
|
copy: c.copy,
|
|
}
|
|
|
|
c.indices = append(c.indices, index)
|
|
if conf.Unique {
|
|
c.uniqueIndices = append(c.uniqueIndices, index)
|
|
}
|
|
|
|
return index
|
|
}
|
|
|
|
func (c *Collection[T]) Get(tx *Snapshot, id uint64) *T {
|
|
if tx == nil {
|
|
tx = c.db.Snapshot()
|
|
}
|
|
item := new(T)
|
|
c.setID(item, id)
|
|
return c.ByID.Get(tx, item)
|
|
}
|
|
|
|
func (c *Collection[T]) Insert(tx *Snapshot, userItem *T) error {
|
|
if tx == nil {
|
|
return c.db.Update(func(tx *Snapshot) error {
|
|
return c.insert(tx, userItem)
|
|
})
|
|
}
|
|
|
|
return c.insert(tx, userItem)
|
|
}
|
|
|
|
func (c *Collection[T]) insert(tx *Snapshot, userItem *T) error {
|
|
if err := c.ensureMutable(tx); err != nil {
|
|
return err
|
|
}
|
|
|
|
item := c.copy(userItem)
|
|
c.sanitize(item)
|
|
|
|
if err := c.validate(item); err != nil {
|
|
return err
|
|
}
|
|
|
|
for i := range c.uniqueIndices {
|
|
if c.uniqueIndices[i].insertConflict(tx, item) {
|
|
return errs.Duplicate.WithCollection(c.name).WithIndex(c.uniqueIndices[i].name)
|
|
}
|
|
}
|
|
|
|
tx.store(c.collectionID, c.getID(item), item)
|
|
|
|
for i := range c.indices {
|
|
c.indices[i].insert(tx, item)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Collection[T]) Update(tx *Snapshot, userItem *T) error {
|
|
if tx == nil {
|
|
return c.db.Update(func(tx *Snapshot) error {
|
|
return c.update(tx, userItem)
|
|
})
|
|
}
|
|
return c.update(tx, userItem)
|
|
}
|
|
|
|
func (c *Collection[T]) update(tx *Snapshot, userItem *T) error {
|
|
if err := c.ensureMutable(tx); err != nil {
|
|
return err
|
|
}
|
|
|
|
item := c.copy(userItem)
|
|
c.sanitize(item)
|
|
|
|
if err := c.validate(item); err != nil {
|
|
return err
|
|
}
|
|
|
|
old, ok := c.ByID.get(tx, item)
|
|
if !ok {
|
|
return errs.NotFound
|
|
}
|
|
|
|
for i := range c.uniqueIndices {
|
|
if c.uniqueIndices[i].updateConflict(tx, item) {
|
|
return errs.Duplicate.WithCollection(c.name).WithIndex(c.uniqueIndices[i].name)
|
|
}
|
|
}
|
|
|
|
tx.store(c.collectionID, c.getID(item), item)
|
|
|
|
for i := range c.indices {
|
|
c.indices[i].update(tx, old, item)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *Collection[T]) Upsert(tx *Snapshot, item *T) error {
|
|
if tx == nil {
|
|
return c.db.Update(func(tx *Snapshot) error {
|
|
return c.upsert(tx, item)
|
|
})
|
|
}
|
|
return c.upsert(tx, item)
|
|
}
|
|
|
|
// upsert inserts item, or updates the stored item with the same ID.
//
// An insert is tried first. If it fails with errs.Duplicate, an update is
// tried. If that update reports errs.NotFound, no item with this ID exists and
// the Duplicate came from some other unique index. In that case the original
// Duplicate error, which names the conflicting index, is returned instead of
// a misleading NotFound.
func (c *Collection[T]) upsert(tx *Snapshot, item *T) error {
	err := c.Insert(tx, item)
	if err == nil || !errors.Is(err, errs.Duplicate) {
		return err
	}

	updateErr := c.Update(tx, item)
	if errors.Is(updateErr, errs.NotFound) {
		return err
	}
	return updateErr
}
|
|
|
|
func (c *Collection[T]) Delete(tx *Snapshot, itemID uint64) error {
|
|
if tx == nil {
|
|
return c.db.Update(func(tx *Snapshot) error {
|
|
return c.delete(tx, itemID)
|
|
})
|
|
}
|
|
return c.delete(tx, itemID)
|
|
}
|
|
|
|
func (c *Collection[T]) delete(tx *Snapshot, itemID uint64) error {
|
|
if err := c.ensureMutable(tx); err != nil {
|
|
return err
|
|
}
|
|
|
|
return c.deleteItem(tx, itemID)
|
|
}
|
|
|
|
func (c *Collection[T]) Count(tx *Snapshot) int {
|
|
return c.ByID.Count(tx)
|
|
}
|
|
|
|
func (c *Collection[T]) getByID(tx *Snapshot, itemID uint64) (*T, bool) {
|
|
x := new(T)
|
|
c.setID(x, itemID)
|
|
return c.ByID.get(tx, x)
|
|
}
|
|
|
|
func (c *Collection[T]) ensureMutable(tx *Snapshot) error {
|
|
if !tx.writable() {
|
|
return errs.ReadOnly
|
|
}
|
|
|
|
state := c.getState(tx)
|
|
if state.Version != tx.version {
|
|
tx.collections[c.collectionID] = state.clone(tx.version)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// insertItem decodes an item from data (JSON) and adds it to every index.
// It is used for initial data loading.
//
// Unlike insert, it does not run the sanitize/validate hooks or call
// tx.store. It checks all unique indices before inserting anything. A conflict
// returns errs.Duplicate, annotated with the collection and index name as
// insert does, and leaves the indices untouched.
//
// NOTE(review): itemID is not used. The item's ID comes from the decoded data.
// Confirm whether a mismatch between the two should be reported.
func (c *Collection[T]) insertItem(tx *Snapshot, itemID uint64, data []byte) error {
	item := new(T)
	if err := json.Unmarshal(data, item); err != nil {
		return errs.Encoding.WithErr(err).WithCollection(c.name)
	}

	// Check for insert conflicts before touching any index.
	for _, index := range c.uniqueIndices {
		if index.insertConflict(tx, item) {
			return errs.Duplicate.WithCollection(c.name).WithIndex(index.name)
		}
	}

	// Do the insert.
	for _, index := range c.indices {
		index.insert(tx, item)
	}

	return nil
}
|
|
|
|
func (c *Collection[T]) deleteItem(tx *Snapshot, itemID uint64) error {
|
|
item, ok := c.getByID(tx, itemID)
|
|
if !ok {
|
|
return errs.NotFound
|
|
}
|
|
|
|
tx.delete(c.collectionID, itemID)
|
|
|
|
for i := range c.indices {
|
|
c.indices[i].delete(tx, item)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// upsertItem replaces the item stored under itemID with the item decoded from
// data (JSON), or inserts it when no item with that ID exists.
//
// NOTE(review): presumably used to apply changes outside the normal
// Insert/Update path (e.g. replaying stored data). Confirm against callers.
// Like insertItem, it does not run the sanitize/validate hooks, check
// unique-index conflicts, or call tx.store.
//
// data is decoded before anything is modified. A decoding error therefore
// leaves the existing item and its index entries untouched, rather than
// leaving the item half-deleted.
func (c *Collection[T]) upsertItem(tx *Snapshot, itemID uint64, data []byte) error {
	item := new(T)
	if err := json.Unmarshal(data, item); err != nil {
		return errs.Encoding.WithErr(err).WithCollection(c.name)
	}

	// Remove the previous version, if any, from tx and every index.
	if old, ok := c.getByID(tx, itemID); ok {
		tx.delete(c.collectionID, itemID)
		for i := range c.indices {
			c.indices[i].delete(tx, old)
		}
	}

	// Do the insert.
	for _, index := range c.indices {
		index.insert(tx, item)
	}

	return nil
}
|
|
|
|
// getID returns the item's ID: the uint64 stored in the first 8 bytes of *t.
// It relies on T's first field being a uint64 ID, which is not checked.
func (c *Collection[T]) getID(t *T) uint64 {
	idPtr := (*uint64)(unsafe.Pointer(t))
	return *idPtr
}
|
|
|
|
// setID writes id into the uint64 stored in the first 8 bytes of *t. It is the
// counterpart of getID and has the same first-field requirement on T.
func (c *Collection[T]) setID(t *T, id uint64) {
	idPtr := (*uint64)(unsafe.Pointer(t))
	*idPtr = id
}
|
|
|
|
func (c *Collection[T]) getState(tx *Snapshot) *collectionState[T] {
|
|
return tx.collections[c.collectionID].(*collectionState[T])
|
|
}
|