package mdb import ( "unsafe" "github.com/google/btree" ) func NewIndex[T any]( c *Collection[T], name string, compare func(lhs, rhs *T) int, ) *Index[T] { return c.addIndex(indexConfig[T]{ Name: name, Unique: false, Compare: compare, Include: nil, }) } func NewPartialIndex[T any]( c *Collection[T], name string, compare func(lhs, rhs *T) int, include func(*T) bool, ) *Index[T] { return c.addIndex(indexConfig[T]{ Name: name, Unique: false, Compare: compare, Include: include, }) } func NewUniqueIndex[T any]( c *Collection[T], name string, compare func(lhs, rhs *T) int, ) *Index[T] { return c.addIndex(indexConfig[T]{ Name: name, Unique: true, Compare: compare, Include: nil, }) } func NewUniquePartialIndex[T any]( c *Collection[T], name string, compare func(lhs, rhs *T) int, include func(*T) bool, ) *Index[T] { return c.addIndex(indexConfig[T]{ Name: name, Unique: true, Compare: compare, Include: include, }) } // ---------------------------------------------------------------------------- type Index[T any] struct { db *Database name string collectionID uint64 indexID uint64 include func(*T) bool copy func(*T) *T } func (i *Index[T]) ensureSnapshot(tx *Snapshot) *Snapshot { if tx == nil { tx = i.db.Snapshot() } return tx } func (i *Index[T]) Get(tx *Snapshot, in *T) *T { tx = i.ensureSnapshot(tx) if tPtr, ok := i.get(tx, in); ok { return i.copy(tPtr) } return nil } func (i *Index[T]) get(tx *Snapshot, in *T) (*T, bool) { return i.btree(tx).Get(in) } func (i *Index[T]) Has(tx *Snapshot, in *T) bool { tx = i.ensureSnapshot(tx) return i.btree(tx).Has(in) } func (i *Index[T]) Min(tx *Snapshot) *T { tx = i.ensureSnapshot(tx) if tPtr, ok := i.btree(tx).Min(); ok { return i.copy(tPtr) } return nil } func (i *Index[T]) Max(tx *Snapshot) *T { tx = i.ensureSnapshot(tx) if tPtr, ok := i.btree(tx).Max(); ok { return i.copy(tPtr) } return nil } func (i *Index[T]) Ascend(tx *Snapshot, each func(*T) bool) { tx = i.ensureSnapshot(tx) i.btreeForIter(tx).Ascend(func(t *T) bool { return each(i.copy(t)) }) } func (i *Index[T]) AscendAfter(tx *Snapshot, after *T, each func(*T) bool) { tx = i.ensureSnapshot(tx) i.btreeForIter(tx).AscendGreaterOrEqual(after, func(t *T) bool { return each(i.copy(t)) }) } func (i *Index[T]) Descend(tx *Snapshot, each func(*T) bool) { tx = i.ensureSnapshot(tx) i.btreeForIter(tx).Descend(func(t *T) bool { return each(i.copy(t)) }) } func (i *Index[T]) DescendAfter(tx *Snapshot, after *T, each func(*T) bool) { tx = i.ensureSnapshot(tx) i.btreeForIter(tx).DescendLessOrEqual(after, func(t *T) bool { return each(i.copy(t)) }) } func (i *Index[T]) Count(tx *Snapshot) int { tx = i.ensureSnapshot(tx) return i.btree(tx).Len() } // ---------------------------------------------------------------------------- func (i *Index[T]) insertConflict(tx *Snapshot, item *T) bool { return i.btree(tx).Has(item) } func (i *Index[T]) updateConflict(tx *Snapshot, item *T) bool { current, ok := i.btree(tx).Get(item) return ok && i.getID(current) != i.getID(item) } // This should only be called after insertConflict. Additionally, the caller // should ensure that the index has been properly cloned for write before // writing. func (i *Index[T]) insert(tx *Snapshot, item *T) { if i.include != nil && !i.include(item) { return } i.btree(tx).ReplaceOrInsert(item) } func (i *Index[T]) update(tx *Snapshot, old, new *T) { bt := i.btree(tx) bt.Delete(old) // The insert call will also check the include function if available. i.insert(tx, new) } func (i *Index[T]) delete(tx *Snapshot, item *T) { i.btree(tx).Delete(item) } // ---------------------------------------------------------------------------- func (i *Index[T]) getState(tx *Snapshot) indexState[T] { return tx.collections[i.collectionID].(*collectionState[T]).Indices[i.indexID] } // Get the current btree for get/has/update/delete, etc. func (i *Index[T]) btree(tx *Snapshot) *btree.BTreeG[*T] { return i.getState(tx).BTree } func (i *Index[T]) btreeForIter(tx *Snapshot) *btree.BTreeG[*T] { cState := tx.collections[i.collectionID].(*collectionState[T]) bt := cState.Indices[i.indexID].BTree // If snapshot and index are writable, return a clone. if tx.writable() && cState.Version == tx.version { bt = bt.Clone() } return bt } func (i *Index[T]) getID(t *T) uint64 { return *((*uint64)(unsafe.Pointer(t))) }