diff --git a/btreeindex.go b/btreeindex.go new file mode 100644 index 0000000..a33fd7d --- /dev/null +++ b/btreeindex.go @@ -0,0 +1,157 @@ +package mdb + +import ( + "sync" + "sync/atomic" + + "github.com/google/btree" +) + +type BTreeIndex[T any] struct { + c *Collection[T] + modLock sync.Mutex + bt atomic.Value // *btree.BTreeG[*T] + getID func(*T) uint64 + less func(*T, *T) bool + include func(*T) bool +} + +func NewBTreeIndex[T any]( + c *Collection[T], + less func(*T, *T) bool, + include func(*T) bool, +) *BTreeIndex[T] { + + t := &BTreeIndex[T]{ + c: c, + getID: c.getID, + less: less, + include: include, + } + + btree := btree.NewG(64, less) + t.bt.Store(btree) + c.indices = append(c.indices, t) + + return t +} + +func (t *BTreeIndex[T]) load(m map[uint64]*T) error { + btree := btree.NewG(64, t.less) + t.bt.Store(btree) + for _, item := range m { + if t.include == nil || t.include(item) { + if x, _ := btree.ReplaceOrInsert(item); x != nil { + return ErrDuplicate + } + } + } + return nil +} + +// ---------------------------------------------------------------------------- + +func (t *BTreeIndex[T]) Ascend() *BTreeIterator[T] { + iter := newBTreeIterator[T]() + go func() { + t.btree().Ascend(iter.each) + iter.done() + }() + return iter +} + +func (t *BTreeIndex[T]) AscendAfter(pivot T) *BTreeIterator[T] { + iter := newBTreeIterator[T]() + go func() { + t.btree().AscendGreaterOrEqual(&pivot, iter.each) + iter.done() + }() + return iter +} + +func (t *BTreeIndex[T]) Descend() *BTreeIterator[T] { + iter := newBTreeIterator[T]() + go func() { + t.btree().Descend(iter.each) + iter.done() + }() + return iter +} + +func (t *BTreeIndex[T]) DescendAfter(pivot T) *BTreeIterator[T] { + iter := newBTreeIterator[T]() + go func() { + t.btree().DescendLessOrEqual(&pivot, iter.each) + iter.done() + }() + return iter +} + +func (t *BTreeIndex[T]) Get(item T) (T, bool) { + ptr, ok := t.btree().Get(&item) + if !ok { + return item, false + } + return *ptr, true +} + +func (t *BTreeIndex[T]) Min() (item T, ok bool) { + if ptr, ok := t.btree().Min(); ok { + return *ptr, ok + } + return item, false +} + +func (t *BTreeIndex[T]) Max() (item T, ok bool) { + if ptr, ok := t.btree().Max(); ok { + return *ptr, ok + } + return item, false +} + +func (t *BTreeIndex[T]) Len() int { + return t.btree().Len() +} + +// ---------------------------------------------------------------------------- + +func (t *BTreeIndex[T]) insert(item *T) { + if t.include == nil || t.include(item) { + t.modify(func(bt *btree.BTreeG[*T]) { + bt.ReplaceOrInsert(item) + }) + } +} + +func (t *BTreeIndex[T]) update(old, new *T) { + t.modify(func(bt *btree.BTreeG[*T]) { + if t.include == nil || t.include(old) { + bt.Delete(old) + } + if t.include == nil || t.include(new) { + bt.ReplaceOrInsert(new) + } + }) +} + +func (t *BTreeIndex[T]) delete(item *T) { + if t.include == nil || t.include(item) { + t.modify(func(bt *btree.BTreeG[*T]) { + bt.Delete(item) + }) + } +} + +// ---------------------------------------------------------------------------- + +func (t *BTreeIndex[T]) btree() *btree.BTreeG[*T] { + return t.bt.Load().(*btree.BTreeG[*T]) +} + +func (t *BTreeIndex[T]) modify(mod func(clone *btree.BTreeG[*T])) { + t.modLock.Lock() + defer t.modLock.Unlock() + clone := t.btree().Clone() + mod(clone) + t.bt.Store(clone) +} diff --git a/btreeindex_test.go b/btreeindex_test.go new file mode 100644 index 0000000..6d4adef --- /dev/null +++ b/btreeindex_test.go @@ -0,0 +1,382 @@ +package mdb + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + 
"testing" +) + +func TestBTreeIndex(t *testing.T) { + type Item struct { + ID uint64 + Name string + } + + checkIndexOne := func(idx *BTreeIndex[Item], expected ...Item) error { + if idx.Len() != len(expected) { + return fmt.Errorf("Expected %d items but found %d.", len(expected), idx.Len()) + } + + if len(expected) == 0 { + return nil + } + + for _, item := range expected { + item2, ok := idx.Get(item) + if !ok { + return fmt.Errorf("Missing expected item: %v", item) + } + if !reflect.DeepEqual(item, item2) { + return fmt.Errorf("Items not equal: %v != %v", item2, item) + } + } + + item, ok := idx.Min() + if !ok { + return fmt.Errorf("Min item not found, expected: %v", expected[0]) + } + if !reflect.DeepEqual(item, expected[0]) { + return fmt.Errorf("Min items not equal: %v != %v", item, expected[0]) + } + + item, ok = idx.Max() + i := len(expected) - 1 + if !ok { + return fmt.Errorf("Max item not found, expected: %v", expected[i]) + } + if !reflect.DeepEqual(item, expected[i]) { + return fmt.Errorf("Max items not equal: %v != %v", item, expected[i]) + } + + i = 0 + + iter := idx.Ascend() + defer iter.Close() + for iter.Next() { + if !reflect.DeepEqual(iter.Value(), expected[i]) { + return fmt.Errorf("Items not equal (%d): %v != %v", i, iter.Value(), expected[i]) + } + i++ + } + + i = len(expected) - 1 + iter = idx.Descend() + defer iter.Close() + for iter.Next() { + if !reflect.DeepEqual(iter.Value(), expected[i]) { + return fmt.Errorf("Items not equal (%d): %v != %v", i, iter.Value(), expected[i]) + } + i-- + } + + i = 1 + iter = idx.AscendAfter(expected[1]) + defer iter.Close() + for iter.Next() { + if !reflect.DeepEqual(iter.Value(), expected[i]) { + return fmt.Errorf("Items not equal (%d): %v != %v", i, iter.Value(), expected[i]) + } + i++ + } + + i = len(expected) - 2 + iter = idx.DescendAfter(expected[len(expected)-2]) + defer iter.Close() + for iter.Next() { + if !reflect.DeepEqual(iter.Value(), expected[i]) { + return fmt.Errorf("Items not equal (%d): %v != %v", i, iter.Value(), expected[i]) + } + i-- + } + + return nil + } + + checkIndex := func(idx *BTreeIndex[Item], expected ...Item) error { + idx.c.db.waitForWAL() + + if err := checkIndexOne(idx, expected...); err != nil { + return fmt.Errorf("%w: original", err) + } + + db := NewPrimary(idx.c.db.root) + defer db.Close() + c := NewCollection(db, "collection", func(i *Item) uint64 { return i.ID }) + idx = NewBTreeIndex(c, + func(i, j *Item) bool { return i.Name < j.Name }, + func(i *Item) bool { return i.Name != "" }) + db.Start() + + return checkIndexOne(idx, expected...) 
+ } + + run := func(name string, inner func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item])) { + t.Run(name, func(t *testing.T) { + root := filepath.Join(os.TempDir(), randString()) + //defer os.RemoveAll(root) + + db := NewPrimary(root) + defer db.Close() + + c := NewCollection(db, "collection", func(i *Item) uint64 { return i.ID }) + idx := NewBTreeIndex(c, + func(i, j *Item) bool { return i.Name < j.Name }, + func(i *Item) bool { return i.Name != "" }) + + db.Start() + + inner(t, c, idx) + }) + } + + run("no items", func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item]) { + if err := checkIndex(idx); err != nil { + t.Fatal(err) + } + }) + + run("insert some", func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item]) { + item1 := Item{1, "one"} + item2 := Item{2, ""} + item3 := Item{3, "three"} + item4 := Item{4, ""} + item5 := Item{5, "five"} + + c.Insert(item1) + c.Insert(item2) + c.Insert(item3) + c.Insert(item4) + c.Insert(item5) + + if err := checkIndex(idx, item5, item1, item3); err != nil { + t.Fatal(err) + } + }) + + run("partial iteration", func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item]) { + item1 := Item{1, "one"} + item2 := Item{2, ""} + item3 := Item{3, "three"} + item4 := Item{4, ""} + item5 := Item{5, "five"} + + c.Insert(item1) + c.Insert(item2) + c.Insert(item3) + c.Insert(item4) + c.Insert(item5) + + iter := idx.Ascend() + defer iter.Close() + + if !iter.Next() { + t.Fatal("Expected", item5) + } + + if !reflect.DeepEqual(iter.Value(), item5) { + t.Fatal(iter.Value(), item5) + } + }) + + run("get", func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item]) { + item1 := Item{1, "one"} + item2 := Item{2, ""} + c.Insert(item1) + c.Insert(item2) + + item, ok := idx.Get(Item{0, "one"}) + if !ok || !reflect.DeepEqual(item, item1) { + t.Fatal(ok, item, item1) + } + }) + + run("get not found", func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item]) { + item1 := Item{1, "one"} + item2 := Item{2, ""} + c.Insert(item1) + c.Insert(item2) + + if item, ok := idx.Get(Item{0, "three"}); ok { + t.Fatal(item) + } + }) + + run("min max on empty", func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item]) { + item2 := Item{2, ""} + c.Insert(item2) + + if item, ok := idx.Min(); ok { + t.Fatal(item) + } + if item, ok := idx.Max(); ok { + t.Fatal(item) + } + }) + + run("min max with one item", func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item]) { + item1 := Item{1, "one"} + item2 := Item{2, ""} + + c.Insert(item1) + c.Insert(item2) + + i1, ok := idx.Min() + if !ok { + t.Fatal(ok) + } + i2, ok := idx.Max() + if !ok { + t.Fatal(ok) + } + + if !reflect.DeepEqual(i1, i2) { + t.Fatal(i1, i2) + } + }) + + run("update outside of index", func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item]) { + item1 := Item{1, "one"} + item2 := Item{2, ""} + item3 := Item{3, "three"} + item4 := Item{4, ""} + item5 := Item{5, "five"} + + c.Insert(item1) + c.Insert(item2) + c.Insert(item3) + c.Insert(item4) + c.Insert(item5) + + c.Update(2, func(in Item) (Item, error) { + return in, nil + }) + + if err := checkIndex(idx, item5, item1, item3); err != nil { + t.Fatal(err) + } + }) + + run("update into index", func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item]) { + item1 := Item{1, "one"} + item2 := Item{2, ""} + item3 := Item{3, "three"} + item4 := Item{4, ""} + item5 := Item{5, "five"} + + c.Insert(item1) + c.Insert(item2) + c.Insert(item3) + c.Insert(item4) + c.Insert(item5) + + err := c.Update(2, func(in Item) (Item, error) { + in.Name 
= "two" + return in, nil + }) + if err != nil { + t.Fatal(err) + } + + item2.Name = "two" + + if err := checkIndex(idx, item5, item1, item3, item2); err != nil { + t.Fatal(err) + } + }) + + run("update out of index", func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item]) { + item1 := Item{1, "one"} + item2 := Item{2, ""} + item3 := Item{3, "three"} + item4 := Item{4, ""} + item5 := Item{5, "five"} + + c.Insert(item1) + c.Insert(item2) + c.Insert(item3) + c.Insert(item4) + c.Insert(item5) + + err := c.Update(1, func(in Item) (Item, error) { + in.Name = "" + return in, nil + }) + if err != nil { + t.Fatal(err) + } + + if err := checkIndex(idx, item5, item3); err != nil { + t.Fatal(err) + } + }) + + run("update within index", func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item]) { + item1 := Item{1, "one"} + item2 := Item{2, ""} + item3 := Item{3, "three"} + item4 := Item{4, ""} + item5 := Item{5, "five"} + + c.Insert(item1) + c.Insert(item2) + c.Insert(item3) + c.Insert(item4) + c.Insert(item5) + + err := c.Update(1, func(in Item) (Item, error) { + in.Name = "xone" + return in, nil + }) + if err != nil { + t.Fatal(err) + } + + item1.Name = "xone" + + if err := checkIndex(idx, item5, item3, item1); err != nil { + t.Fatal(err) + } + }) + + run("delete outside index", func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item]) { + item1 := Item{1, "one"} + item2 := Item{2, ""} + item3 := Item{3, "three"} + item4 := Item{4, ""} + item5 := Item{5, "five"} + + c.Insert(item1) + c.Insert(item2) + c.Insert(item3) + c.Insert(item4) + c.Insert(item5) + + c.Delete(item2.ID) + + if err := checkIndex(idx, item5, item1, item3); err != nil { + t.Fatal(err) + } + }) + + run("delete within index", func(t *testing.T, c *Collection[Item], idx *BTreeIndex[Item]) { + item1 := Item{1, "one"} + item2 := Item{2, ""} + item3 := Item{3, "three"} + item4 := Item{4, ""} + item5 := Item{5, "five"} + + c.Insert(item1) + c.Insert(item2) + c.Insert(item3) + c.Insert(item4) + c.Insert(item5) + + c.Delete(item1.ID) + + if err := checkIndex(idx, item5, item3); err != nil { + t.Fatal(err) + } + }) +} diff --git a/btreeiterator.go b/btreeiterator.go new file mode 100644 index 0000000..0ead187 --- /dev/null +++ b/btreeiterator.go @@ -0,0 +1,44 @@ +package mdb + +type BTreeIterator[T any] struct { + out chan *T + close chan struct{} + current *T +} + +func newBTreeIterator[T any]() *BTreeIterator[T] { + return &BTreeIterator[T]{ + out: make(chan *T), + close: make(chan struct{}, 1), + } +} + +func (iter *BTreeIterator[T]) each(i *T) bool { + select { + case iter.out <- i: + return true + case <-iter.close: + return false + } +} + +func (iter *BTreeIterator[T]) done() { + close(iter.out) +} + +func (iter *BTreeIterator[T]) Next() bool { + val, ok := <-iter.out + if ok { + iter.current = val + return true + } + return false +} + +func (iter *BTreeIterator[T]) Close() { + iter.close <- struct{}{} +} + +func (iter *BTreeIterator[T]) Value() T { + return *iter.current +} diff --git a/codec.go b/codec.go new file mode 100644 index 0000000..d8f4ad6 --- /dev/null +++ b/codec.go @@ -0,0 +1,17 @@ +package mdb + +import ( + "encoding/json" +) + +func decode[T any](data []byte) *T { + item := new(T) + must(json.Unmarshal(data, item)) + return item +} + +func encode(item any) []byte { + buf, err := json.Marshal(item) + must(err) + return buf +} diff --git a/collection.go b/collection.go new file mode 100644 index 0000000..43dd7e3 --- /dev/null +++ b/collection.go @@ -0,0 +1,217 @@ +package mdb + +import ( + "fmt" + "log" + 
+ "git.crumpington.com/private/mdb/keyedmutex" +) + +type Collection[T any] struct { + primary bool + db *Database + name string + idLock keyedmutex.KeyedMutex[uint64] + items *itemMap[T] + indices []itemIndex[T] + uniqueIndices []itemUniqueIndex[T] + getID func(*T) uint64 + sanitize func(*T) + validate func(*T) error +} + +func NewCollection[T any](db *Database, name string, getID func(*T) uint64) *Collection[T] { + items := newItemMap(db.kv, name, getID) + c := &Collection[T]{ + primary: db.kv.Primary(), + db: db, + name: name, + idLock: keyedmutex.New[uint64](), + items: items, + indices: []itemIndex[T]{items}, + uniqueIndices: []itemUniqueIndex[T]{}, + getID: items.getID, + sanitize: func(*T) {}, + validate: func(*T) error { return nil }, + } + + db.collections[name] = c + return c +} + +func (c *Collection[T]) SetSanitize(sanitize func(*T)) { + c.sanitize = sanitize +} + +func (c *Collection[T]) SetValidate(validate func(*T) error) { + c.validate = validate +} + +// ---------------------------------------------------------------------------- + +func (c *Collection[T]) NextID() uint64 { + return c.items.nextID() +} + +// ---------------------------------------------------------------------------- + +func (c *Collection[T]) Insert(item T) (T, error) { + if !c.primary { + return item, ErrReadOnly + } + + c.sanitize(&item) + + if err := c.validate(&item); err != nil { + return item, err + } + + id := c.getID(&item) + c.idLock.Lock(id) + defer c.idLock.Unlock(id) + + if _, ok := c.items.Get(id); ok { + return item, fmt.Errorf("%w: ID", ErrDuplicate) + } + + // Acquire locks and check for insert conflicts. + for _, idx := range c.uniqueIndices { + idx.lock(&item) + defer idx.unlock(&item) + + if idx.insertConflict(&item) { + return item, fmt.Errorf("%w: %s", ErrDuplicate, idx.name()) + } + } + + for _, idx := range c.indices { + idx.insert(&item) + } + + return item, nil +} + +func (c *Collection[T]) Update(id uint64, update func(T) (T, error)) error { + if !c.primary { + return ErrReadOnly + } + + c.idLock.Lock(id) + defer c.idLock.Unlock(id) + + old, ok := c.items.Get(id) + if !ok { + return ErrNotFound + } + + newItem, err := update(*old) + if err != nil { + if err == ErrAbortUpdate { + return nil + } + return err + } + + new := &newItem + + if c.getID(new) != id { + return ErrMismatchedIDs + } + + c.sanitize(new) + + if err := c.validate(new); err != nil { + return err + } + + // Acquire locks and check for update conflicts. + for _, idx := range c.uniqueIndices { + idx.lock(new) + defer idx.unlock(new) + + if idx.updateConflict(new) { + return fmt.Errorf("%w: %s", ErrDuplicate, idx.name()) + } + } + + for _, idx := range c.indices { + idx.update(old, new) + } + + return nil +} + +func (c Collection[T]) Delete(id uint64) { + if !c.primary { + panic(ErrReadOnly) + } + + c.idLock.Lock(id) + defer c.idLock.Unlock(id) + + item, ok := c.items.Get(id) + if !ok { + return + } + + // Acquire locks and check for insert conflicts. 
+ for _, idx := range c.uniqueIndices { + idx.lock(item) + defer idx.unlock(item) + } + + for _, idx := range c.indices { + idx.delete(item) + } +} + +func (c Collection[T]) Get(id uint64) (t T, ok bool) { + ptr, ok := c.items.Get(id) + if !ok { + return t, false + } + return *ptr, true +} + +// ---------------------------------------------------------------------------- + +func (c *Collection[T]) loadData() { + toRemove := []int{} + for i, idx := range c.indices { + if err := idx.load(c.items.m); err != nil { + log.Printf("Removing index %d because of error: %v", i, err) + toRemove = append([]int{i}, toRemove...) + } + } + + for _, i := range toRemove { + c.indices = append(c.indices[:i], c.indices[i+1:]...) + } +} + +func (c *Collection[T]) onStore(collection string, id uint64, data []byte) { + item := decode[T](data) + old, ok := c.items.Get(id) + if !ok { + // Insert. + for _, idx := range c.indices { + idx.insert(item) + } + } else { + // Otherwise update. + for _, idx := range c.indices { + idx.update(old, item) + } + } +} + +func (c *Collection[T]) onDelete(collection string, id uint64) { + item, ok := c.items.Get(id) + if !ok { + return + } + + for _, idx := range c.indices { + idx.delete(item) + } +} diff --git a/collection_test.go b/collection_test.go new file mode 100644 index 0000000..63c2751 --- /dev/null +++ b/collection_test.go @@ -0,0 +1,426 @@ +package mdb + +import ( + "errors" + "fmt" + "log" + "os" + "path/filepath" + "reflect" + "strings" + "sync" + "testing" +) + +func TestCollection(t *testing.T) { + type Item struct { + ID uint64 + Name string // Full map. + ExtID string // Partial map. + } + + sanitize := func(item *Item) { + item.Name = strings.TrimSpace(item.Name) + item.ExtID = strings.TrimSpace(item.ExtID) + } + + ErrInvalidExtID := errors.New("InvalidExtID") + + validate := func(item *Item) error { + if len(item.ExtID) != 0 && !strings.HasPrefix(item.ExtID, "x") { + return ErrInvalidExtID + } + return nil + } + + run := func(name string, inner func(t *testing.T, c *Collection[Item])) { + + t.Run(name, func(t *testing.T) { + root := filepath.Join(os.TempDir(), randString()) + //defer os.RemoveAll(root) + + db := NewPrimary(root) + defer db.Close() + + c := NewCollection(db, randString(), func(i *Item) uint64 { return i.ID }) + c.SetSanitize(sanitize) + c.SetValidate(validate) + + NewMapIndex(c, + "Name", + func(i *Item) string { return i.Name }, + nil) + + NewMapIndex(c, + "ExtID", + func(i *Item) string { return i.ExtID }, + func(i *Item) bool { return i.ExtID != "" }) + + inner(t, c) + }) + } + + verifyCollectionOnce := func(c *Collection[Item], expected ...Item) error { + if len(c.items.m) != len(expected) { + return fmt.Errorf("Expected %d items, but got %d.", len(expected), len(c.items.m)) + } + + for _, item := range expected { + i, ok := c.Get(item.ID) + if !ok { + return fmt.Errorf("Missing expected item: %v", item) + } + if !reflect.DeepEqual(i, item) { + return fmt.Errorf("Items aren't equal: %v != %v", i, item) + } + } + + return nil + } + + verifyCollection := func(c *Collection[Item], expected ...Item) error { + if err := verifyCollectionOnce(c, expected...); err != nil { + return fmt.Errorf("%w: original", err) + } + + // Reload the collection and verify again. + c.db.Close() + + db := NewSecondary(c.db.root) + c2 := NewCollection(db, c.name, func(i *Item) uint64 { return i.ID }) + db.Start() + defer db.Close() + return verifyCollectionOnce(c2, expected...) 
+ } + + run("empty", func(t *testing.T, c *Collection[Item]) { + err := verifyCollection(c) + if err != nil { + t.Fatal(err) + } + }) + + run("check NextID", func(t *testing.T, c *Collection[Item]) { + id := c.NextID() + for i := 0; i < 100; i++ { + next := c.NextID() + if next <= id { + t.Fatal(next, id) + } + id = next + } + }) + + run("insert", func(t *testing.T, c *Collection[Item]) { + item, err := c.Insert(Item{1, "Name", "xid"}) + if err != nil { + t.Fatal(err) + } + err = verifyCollection(c, item) + if err != nil { + t.Fatal(err) + } + }) + + run("insert concurrent differnt items", func(t *testing.T, c *Collection[Item]) { + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + c.Insert(Item{ + ID: c.NextID(), + Name: fmt.Sprintf("Name.%03d", i), + ExtID: fmt.Sprintf("x.%03d", i), + }) + }(i) + } + + wg.Wait() + }) + + run("insert concurrent same item", func(t *testing.T, c *Collection[Item]) { + wg := sync.WaitGroup{} + for i := 0; i < 100; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + c.Insert(Item{ + ID: 1, + Name: "My name", + }) + }(i) + } + + wg.Wait() + + if err := verifyCollection(c, Item{1, "My name", ""}); err != nil { + t.Fatal(err) + } + }) + + run("insert invalid", func(t *testing.T, c *Collection[Item]) { + item, err := c.Insert(Item{ + ID: c.NextID(), + Name: "Hello", + ExtID: "123"}) + if !errors.Is(err, ErrInvalidExtID) { + t.Fatal(item, err) + } + }) + + run("insert duplicate ID", func(t *testing.T, c *Collection[Item]) { + item, err := c.Insert(Item{ + ID: c.NextID(), + Name: "Hello", + }) + if err != nil { + t.Fatal(err) + } + + item2, err := c.Insert(Item{ID: item.ID, Name: "Item"}) + if !errors.Is(err, ErrDuplicate) { + t.Fatal(err, item2) + } + }) + + run("insert duplicate name", func(t *testing.T, c *Collection[Item]) { + _, err := c.Insert(Item{ + ID: c.NextID(), + Name: "Hello", + }) + if err != nil { + t.Fatal(err) + } + + item2, err := c.Insert(Item{ID: c.NextID(), Name: "Hello"}) + if !errors.Is(err, ErrDuplicate) { + t.Fatal(err, item2) + } + }) + + run("insert duplicate ext ID", func(t *testing.T, c *Collection[Item]) { + _, err := c.Insert(Item{ + ID: c.NextID(), + Name: "Hello", + ExtID: "x1", + }) + if err != nil { + t.Fatal(err) + } + + item2, err := c.Insert(Item{ID: c.NextID(), Name: "name", ExtID: "x1"}) + if !errors.Is(err, ErrDuplicate) { + t.Fatal(err, item2) + } + }) + + run("get not found", func(t *testing.T, c *Collection[Item]) { + item, err := c.Insert(Item{ + ID: c.NextID(), + Name: "Hello", + ExtID: "x1", + }) + if err != nil { + t.Fatal(err) + } + + if i, ok := c.Get(item.ID + 1); ok { + t.Fatal(i) + } + }) + + run("update", func(t *testing.T, c *Collection[Item]) { + item1, err := c.Insert(Item{ + ID: c.NextID(), + Name: "Hello", + ExtID: "x1", + }) + if err != nil { + t.Fatal(err) + } + + err = c.Update(item1.ID, func(item Item) (Item, error) { + item.Name = "name" + item.ExtID = "x88" + return item, nil + }) + if err != nil { + t.Fatal(err) + } + + err = verifyCollection(c, Item{ID: item1.ID, Name: "name", ExtID: "x88"}) + if err != nil { + t.Fatal(err) + } + }) + + run("update concurrent different items", func(t *testing.T, c *Collection[Item]) { + items := make([]Item, 10) + for i := range items { + item, err := c.Insert(Item{ID: c.NextID(), Name: randString()}) + if err != nil { + t.Fatal(err) + } + items[i] = item + } + + wg := sync.WaitGroup{} + wg.Add(10) + for i := range items { + item := items[i] + go func() { + defer wg.Done() + for x := 0; x < 100; x++ { + err 
:= c.Update(item.ID, func(i Item) (Item, error) { + i.Name = randString() + return i, nil + }) + if err != nil { + panic(err) + } + } + }() + } + + wg.Wait() + }) + + run("update concurrent same item", func(t *testing.T, c *Collection[Item]) { + item, err := c.Insert(Item{ID: c.NextID(), Name: randString()}) + if err != nil { + t.Fatal(err) + } + + wg := sync.WaitGroup{} + wg.Add(10) + for i := 0; i < 10; i++ { + go func() { + defer wg.Done() + for x := 0; x < 100; x++ { + err := c.Update(item.ID, func(i Item) (Item, error) { + i.Name = randString() + return i, nil + }) + if err != nil { + panic(err) + } + } + }() + } + + wg.Wait() + }) + + run("update not found", func(t *testing.T, c *Collection[Item]) { + item, err := c.Insert(Item{ID: c.NextID(), Name: randString()}) + if err != nil { + t.Fatal(err) + } + + err = c.Update(item.ID+1, func(i Item) (Item, error) { + i.Name = randString() + return i, nil + }) + if !errors.Is(err, ErrNotFound) { + t.Fatal(err) + } + }) + + run("update mismatched IDs", func(t *testing.T, c *Collection[Item]) { + item, err := c.Insert(Item{ID: c.NextID(), Name: randString()}) + if err != nil { + t.Fatal(err) + } + + err = c.Update(item.ID, func(i Item) (Item, error) { + i.ID++ + return i, nil + }) + + if !errors.Is(err, ErrMismatchedIDs) { + t.Fatal(err) + } + }) + + run("update invalid", func(t *testing.T, c *Collection[Item]) { + item, err := c.Insert(Item{ID: c.NextID(), Name: randString()}) + if err != nil { + t.Fatal(err) + } + + err = c.Update(item.ID, func(i Item) (Item, error) { + i.ExtID = "a" + return i, nil + }) + if !errors.Is(err, ErrInvalidExtID) { + t.Fatal(err) + } + }) + + run("delete", func(t *testing.T, c *Collection[Item]) { + item1, err := c.Insert(Item{c.NextID(), "name1", "x1"}) + if err != nil { + t.Fatal(err) + } + item2, err := c.Insert(Item{c.NextID(), "name2", "x2"}) + if err != nil { + t.Fatal(err) + } + item3, err := c.Insert(Item{c.NextID(), "name3", "x3"}) + if err != nil { + t.Fatal(err) + } + + c.Delete(item2.ID) + if err := verifyCollection(c, item1, item3); err != nil { + t.Fatal(err) + } + }) + + run("delete not found", func(t *testing.T, c *Collection[Item]) { + item1, err := c.Insert(Item{c.NextID(), "name1", "x1"}) + if err != nil { + t.Fatal(err) + } + item2, err := c.Insert(Item{c.NextID(), "name2", "x2"}) + if err != nil { + t.Fatal(err) + } + item3, err := c.Insert(Item{c.NextID(), "name3", "x3"}) + if err != nil { + t.Fatal(err) + } + + c.Delete(c.NextID()) + if err := verifyCollection(c, item1, item2, item3); err != nil { + t.Fatal(err) + } + }) +} + +func BenchmarkLoad(b *testing.B) { + type Item struct { + ID uint64 + Name string + } + + root := filepath.Join("test-files", randString()) + db := NewPrimary(root) + getID := func(item *Item) uint64 { return item.ID } + + c := NewCollection(db, "items", getID) + + for i := 0; i < b.N; i++ { + item := Item{ID: c.NextID(), Name: fmt.Sprintf("Name %04d", i)} + c.Insert(item) + } + + b.ResetTimer() + + c2 := NewCollection(db, "items", getID) + log.Print(len(c2.items.m)) + if len(c2.items.m) != b.N { + panic("What?") + } +} diff --git a/database.go b/database.go new file mode 100644 index 0000000..25e2ee7 --- /dev/null +++ b/database.go @@ -0,0 +1,84 @@ +package mdb + +import ( + "net" + "os" + "sync" + + "git.crumpington.com/private/mdb/kvstore" +) + +type Database struct { + root string + kv *kvstore.KV + + collections map[string]dbCollection +} + +func NewPrimary(root string) *Database { + must(os.MkdirAll(root, 0700)) + db := &Database{ + root: root, + collections: 
map[string]dbCollection{}, + } + db.kv = kvstore.NewPrimary(root) + return db +} + +func NewSecondary(root string) *Database { + must(os.MkdirAll(root, 0700)) + db := &Database{ + root: root, + collections: map[string]dbCollection{}, + } + db.kv = kvstore.NewSecondary(root, db.onStore, db.onDelete) + return db +} + +func (db *Database) Start() { + wg := sync.WaitGroup{} + for _, c := range db.collections { + wg.Add(1) + go func(c dbCollection) { + defer wg.Done() + c.loadData() + }(c) + } + wg.Wait() +} + +func (db *Database) WALStatus() (ws WALStatus) { + ws.MaxID = db.kv.WALMaxSeqNum() + ws.MaxAppliedID = db.kv.MaxSeqNum() + return +} + +func (db *Database) Close() { + db.kv.Close() +} + +// ---------------------------------------------------------------------------- + +func (db *Database) onStore(collection string, id uint64, data []byte) { + c, ok := db.collections[collection] + if ok { + c.onStore(collection, id, data) + } +} + +func (db *Database) onDelete(collection string, id uint64) { + c, ok := db.collections[collection] + if ok { + c.onDelete(collection, id) + } +} + +// ---------------------------------------------------------------------------- + +func (db *Database) SyncSend(conn net.Conn) { + db.kv.SyncSend(conn) +} + +func (db *Database) SyncRecv(conn net.Conn) { + db.kv.SyncRecv(conn) +} diff --git a/database_test.go b/database_test.go new file mode 100644 index 0000000..6c8376d --- /dev/null +++ b/database_test.go @@ -0,0 +1,13 @@ +package mdb + +import "time" + +func (db *Database) waitForWAL() { + for { + status := db.WALStatus() + if status.MaxAppliedID == status.MaxID { + return + } + time.Sleep(100 * time.Millisecond) + } +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..7541dca --- /dev/null +++ b/errors.go @@ -0,0 +1,13 @@ +package mdb + +import "errors" + +var ( + ErrMismatchedIDs = errors.New("MismatchedIDs") + ErrDuplicate = errors.New("Duplicate") + ErrNotFound = errors.New("NotFound") + ErrReadOnly = errors.New("ReadOnly") + + // Return in update function to abort changes. + ErrAbortUpdate = errors.New("AbortUpdate") +) diff --git a/itemmap.go b/itemmap.go new file mode 100644 index 0000000..1d561f6 --- /dev/null +++ b/itemmap.go @@ -0,0 +1,106 @@ +package mdb + +import ( + "math/rand" + + "sync" + "sync/atomic" + + "git.crumpington.com/private/mdb/keyedmutex" + "git.crumpington.com/private/mdb/kvstore" +) + +// Implements ItemMap and ItemUniqueIndex interfaces. +type itemMap[T any] struct { + primary bool + kv *kvstore.KV + collection string + idLock keyedmutex.KeyedMutex[uint64] + mapLock sync.Mutex + m map[uint64]*T + getID func(*T) uint64 + maxID uint64 +} + +func newItemMap[T any](kv *kvstore.KV, collection string, getID func(*T) uint64) *itemMap[T] { + m := &itemMap[T]{ + primary: kv.Primary(), + kv: kv, + collection: collection, + idLock: keyedmutex.New[uint64](), + m: map[uint64]*T{}, + getID: getID, + } + + kv.Iterate(collection, func(id uint64, data []byte) { + item := decode[T](data) + if id > m.maxID { + m.maxID = id + } + m.m[id] = item + }) + + return m +} + +func (m *itemMap[T]) load(src map[uint64]*T) error { + // No-op: The itemmap is the source for loading all other indices. + return nil +} + +// ---------------------------------------------------------------------------- + +func (m *itemMap[T]) Get(id uint64) (*T, bool) { + return m.mapGet(id) +} + +// Should hold item lock when calling. 
+func (idx *itemMap[T]) insert(item *T) { + id := idx.getID(item) + if idx.primary { + idx.kv.Store(idx.collection, id, encode(item)) + } + idx.mapSet(id, item) +} + +// Should hold item lock when calling. old and new MUST have the same ID. +func (idx *itemMap[T]) update(old, new *T) { + idx.insert(new) +} + +// Should hold item lock when calling. +func (idx *itemMap[T]) delete(item *T) { + id := idx.getID(item) + if idx.primary { + idx.kv.Delete(idx.collection, id) + } + idx.mapDelete(id) +} + +// ---------------------------------------------------------------------------- + +func (idx *itemMap[T]) mapSet(id uint64, item *T) { + idx.mapLock.Lock() + idx.m[id] = item + idx.mapLock.Unlock() +} + +func (idx *itemMap[T]) mapDelete(id uint64) { + idx.mapLock.Lock() + delete(idx.m, id) + idx.mapLock.Unlock() +} + +func (idx *itemMap[T]) mapGet(id uint64) (*T, bool) { + idx.mapLock.Lock() + item, ok := idx.m[id] + idx.mapLock.Unlock() + return item, ok +} + +// ---------------------------------------------------------------------------- + +func (idx *itemMap[T]) nextID() uint64 { + n := rand.Int63n(256) + return atomic.AddUint64(&idx.maxID, uint64(n)) +} diff --git a/mapindex.go b/mapindex.go new file mode 100644 index 0000000..9fe01e4 --- /dev/null +++ b/mapindex.go @@ -0,0 +1,173 @@ +package mdb + +import ( + "sync" + + "git.crumpington.com/private/mdb/keyedmutex" +) + +type MapIndex[K comparable, T any] struct { + c *Collection[T] + _name string + keyLock keyedmutex.KeyedMutex[K] + mapLock sync.Mutex + m map[K]*T + getID func(*T) uint64 + getKey func(*T) K + include func(*T) bool +} + +func NewMapIndex[K comparable, T any]( + c *Collection[T], + name string, + getKey func(*T) K, + include func(*T) bool, +) *MapIndex[K, T] { + m := &MapIndex[K, T]{ + c: c, + _name: name, + keyLock: keyedmutex.New[K](), + m: map[K]*T{}, + getID: c.getID, + getKey: getKey, + include: include, + } + + c.indices = append(c.indices, m) + c.uniqueIndices = append(c.uniqueIndices, m) + return m +} + +// ---------------------------------------------------------------------------- + +func (m *MapIndex[K, T]) load(src map[uint64]*T) error { + for _, item := range src { + if m.include != nil && !m.include(item) { + continue + } + + k := m.getKey(item) + if _, ok := m.m[k]; ok { + return ErrDuplicate + } + + m.m[k] = item + } + + return nil +} + +func (m *MapIndex[K, T]) Get(k K) (t T, ok bool) { + if item, ok := m.mapGet(k); ok { + return *item, true + } + + return t, false +} + +func (m *MapIndex[K, T]) Update(k K, update func(T) (T, error)) error { + wrapped := func(item T) (T, error) { + new, err := update(item) + if err != nil { + return item, err + } + + if m.getKey(&new) != k { + return item, ErrMismatchedIDs + } + + return new, nil + } + + if item, ok := m.mapGet(k); ok { + return m.c.Update(m.getID(item), wrapped) + } + + return ErrNotFound +} + +func (m *MapIndex[K, T]) Delete(k K) { + if item, ok := m.mapGet(k); ok { + m.c.Delete(m.getID(item)) + } +} + +// ---------------------------------------------------------------------------- + +func (m *MapIndex[K, T]) insert(item *T) { + if m.include == nil || m.include(item) { + m.mapSet(m.getKey(item), item) + } +} + +func (m *MapIndex[K, T]) update(old, new *T) { + if m.include == nil || m.include(old) { + m.mapDelete(m.getKey(old)) + } + if m.include == nil || m.include(new) { + m.mapSet(m.getKey(new), new) + } +} + +func (m *MapIndex[K, T]) delete(item *T) { + if m.include == nil || m.include(item) { + m.mapDelete(m.getKey(item)) + } +} + +// 
---------------------------------------------------------------------------- + +func (idx *MapIndex[K, T]) name() string { + return idx._name +} + +func (idx *MapIndex[K, T]) lock(item *T) { + if idx.include == nil || idx.include(item) { + idx.keyLock.Lock(idx.getKey(item)) + } +} + +func (idx *MapIndex[K, T]) unlock(item *T) { + if idx.include == nil || idx.include(item) { + idx.keyLock.Unlock(idx.getKey(item)) + } +} + +// Should hold item lock when calling. +func (idx *MapIndex[K, T]) insertConflict(new *T) bool { + if idx.include != nil && !idx.include(new) { + return false + } + _, ok := idx.Get(idx.getKey(new)) + return ok +} + +// Should hold item lock when calling. +func (idx *MapIndex[K, T]) updateConflict(new *T) bool { + if idx.include != nil && !idx.include(new) { + return false + } + val, ok := idx.mapGet(idx.getKey(new)) + return ok && idx.getID(val) != idx.getID(new) +} + +// ---------------------------------------------------------------------------- + +func (idx *MapIndex[K, T]) mapSet(k K, t *T) { + idx.mapLock.Lock() + idx.m[k] = t + idx.mapLock.Unlock() +} + +func (idx *MapIndex[K, T]) mapDelete(k K) { + idx.mapLock.Lock() + delete(idx.m, k) + idx.mapLock.Unlock() +} + +func (idx *MapIndex[K, T]) mapGet(k K) (*T, bool) { + idx.mapLock.Lock() + t, ok := idx.m[k] + idx.mapLock.Unlock() + return t, ok +} diff --git a/mapindex_test.go b/mapindex_test.go new file mode 100644 index 0000000..d78ce91 --- /dev/null +++ b/mapindex_test.go @@ -0,0 +1,330 @@ +package mdb + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "reflect" + "testing" +) + +func TestMapIndex(t *testing.T) { + type Item struct { + ID uint64 + Name string + ExtID string + } + + checkIdxOne := func(idx *MapIndex[string, Item], expectedList ...Item) error { + expected := make(map[string]Item, len(expectedList)) + for _, i := range expectedList { + expected[i.ExtID] = i + } + + if len(expected) != len(idx.m) { + return fmt.Errorf("Expected %d items, but got %d.", len(expected), len(idx.m)) + } + + for _, e := range expected { + i, ok := idx.Get(e.ExtID) + if !ok { + return fmt.Errorf("Missing item: %v", e) + } + if !reflect.DeepEqual(i, e) { + return fmt.Errorf("Items not equal: %v != %v", i, e) + } + } + return nil + } + + checkIdx := func(idx *MapIndex[string, Item], expectedList ...Item) error { + idx.c.db.waitForWAL() + + if err := checkIdxOne(idx, expectedList...); err != nil { + return fmt.Errorf("%w: original", err) + } + + // Reload the database, collection, and index and re-test. + db := NewPrimary(idx.c.db.root) + c := NewCollection(db, "collection", func(i *Item) uint64 { return i.ID }) + idx = NewMapIndex(c, + "ExtID", + func(i *Item) string { return i.ExtID }, + func(i *Item) bool { return i.ExtID != "" }) + db.Start() + return checkIdxOne(idx, expectedList...) 
+ } + + run := func(name string, inner func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item])) { + t.Run(name, func(t *testing.T) { + root := filepath.Join(os.TempDir(), randString()) + defer os.RemoveAll(root) + + db := NewPrimary(root) + defer db.Close() + c := NewCollection(db, "collection", func(i *Item) uint64 { return i.ID }) + idx := NewMapIndex(c, + "ExtID", + func(i *Item) string { return i.ExtID }, + func(i *Item) bool { return i.ExtID != "" }) + db.Start() + inner(t, c, idx) + }) + } + + run("insert item not in index", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + item := Item{4, "4", ""} + c.Insert(item) + if err := checkIdx(idx); err != nil { + t.Fatal(err) + } + }) + + run("insert item in index", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + item1 := Item{4, "4", ""} + item2 := Item{5, "5", "abcd"} + c.Insert(item1) + c.Insert(item2) + if err := checkIdx(idx, item2); err != nil { + t.Fatal(err) + } + }) + + run("insert several items", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + item1 := Item{4, "4", ""} + item2 := Item{5, "5", "abcd"} + item3 := Item{6, "6", ""} + item4 := Item{7, "7", "xyz"} + item5 := Item{8, "8", ""} + item6 := Item{9, "9", "mmm"} + c.Insert(item1) + c.Insert(item2) + c.Insert(item3) + c.Insert(item4) + c.Insert(item5) + c.Insert(item6) + if err := checkIdx(idx, item2, item4, item6); err != nil { + t.Fatal(err) + } + }) + + run("insert with conflict", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + item1 := Item{1, "1", "one"} + item2 := Item{2, "2", "one"} + c.Insert(item1) + if _, err := c.Insert(item2); !errors.Is(err, ErrDuplicate) { + t.Fatal(err) + } + }) + + run("update into index", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + item1 := Item{1, "1", ""} + c.Insert(item1) + + if err := checkIdx(idx); err != nil { + t.Fatal(err) + } + + err := c.Update(1, func(i Item) (Item, error) { + i.ExtID = "xx" + return i, nil + }) + if err != nil { + t.Fatal(err) + } + + item1.ExtID = "xx" + if err := checkIdx(idx, item1); err != nil { + t.Fatal(err) + } + }) + + run("update out of index", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + c.Insert(Item{1, "1", ""}) + c.Insert(Item{2, "2", "two"}) // In index. + c.Insert(Item{3, "3", ""}) + + err := c.Update(1, func(in Item) (Item, error) { + in.Name = "ONE" + return in, nil + }) + if err != nil { + t.Fatal(err) + } + + if err := checkIdx(idx, Item{2, "2", "two"}); err != nil { + t.Fatal(err) + } + }) + + run("update out of index conflict", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + c.Insert(Item{1, "1", ""}) + c.Insert(Item{2, "2", "two"}) // In index. + c.Insert(Item{3, "3", ""}) + + err := c.Update(1, func(in Item) (Item, error) { + in.ExtID = "two" + return in, nil + }) + if !errors.Is(err, ErrDuplicate) { + t.Fatal(err) + } + }) + + run("update within index", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + c.Insert(Item{1, "1", "one"}) + c.Insert(Item{2, "2", "two"}) // In index. 
+ c.Insert(Item{3, "3", ""}) + + err := c.Update(2, func(in Item) (Item, error) { + in.ExtID = "TWO" + return in, nil + }) + if err != nil { + t.Fatal(err) + } + + if err := checkIdx(idx, Item{1, "1", "one"}, Item{2, "2", "TWO"}); err != nil { + t.Fatal(err) + } + }) + + run("update using index", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + c.Insert(Item{1, "1", "one"}) + c.Insert(Item{2, "2", "two"}) // In index. + c.Insert(Item{3, "3", ""}) + + err := idx.Update("one", func(in Item) (Item, error) { + in.Name = "_1_" + return in, nil + }) + if err != nil { + t.Fatal(err) + } + + if err := checkIdx(idx, Item{1, "_1_", "one"}, Item{2, "2", "two"}); err != nil { + t.Fatal(err) + } + }) + + run("update using index not found", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + c.Insert(Item{1, "1", "one"}) + c.Insert(Item{3, "3", ""}) + + err := idx.Update("onex", func(in Item) (Item, error) { + in.Name = "_1_" + return in, nil + }) + if !errors.Is(err, ErrNotFound) { + t.Fatal(err) + } + }) + + run("update using index caller error", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + c.Insert(Item{1, "1", "one"}) + c.Insert(Item{3, "3", ""}) + + myErr := errors.New("Mine") + + err := idx.Update("one", func(in Item) (Item, error) { + in.Name = "_1_" + return in, myErr + }) + if !errors.Is(err, myErr) { + t.Fatal(err) + } + }) + + run("update using index mismatched IDs", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + c.Insert(Item{1, "1", "one"}) + + err := idx.Update("one", func(in Item) (Item, error) { + in.ExtID = "onex" + return in, nil + }) + if !errors.Is(err, ErrMismatchedIDs) { + t.Fatal(err) + } + }) + + run("delete out of index", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + c.Insert(Item{1, "1", "one"}) + c.Insert(Item{2, "2", "two"}) // In index. + c.Insert(Item{3, "3", ""}) + c.Delete(3) + + if err := checkIdx(idx, Item{1, "1", "one"}, Item{2, "2", "two"}); err != nil { + t.Fatal(err) + } + }) + + run("delete from index", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + c.Insert(Item{1, "1", "one"}) + c.Insert(Item{2, "2", "two"}) // In index. + c.Insert(Item{3, "3", ""}) + c.Delete(2) + + if err := checkIdx(idx, Item{1, "1", "one"}); err != nil { + t.Fatal(err) + } + }) + + run("delete using index", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + c.Insert(Item{1, "1", "one"}) + c.Insert(Item{2, "2", "two"}) // In index. + c.Insert(Item{3, "3", ""}) + idx.Delete("two") + + if err := checkIdx(idx, Item{1, "1", "one"}); err != nil { + t.Fatal(err) + } + }) + + run("delete using index not found", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + c.Insert(Item{1, "1", "one"}) + c.Insert(Item{2, "2", "two"}) // In index. 
+ c.Insert(Item{3, "3", ""}) + idx.Delete("onex") + + if err := checkIdx(idx, Item{1, "1", "one"}, Item{2, "2", "two"}); err != nil { + t.Fatal(err) + } + }) + + run("check name", func(t *testing.T, c *Collection[Item], idx *MapIndex[string, Item]) { + if idx.name() != "ExtID" { + t.Fatal(idx.name()) + } + }) +} + +func TestMapIndexLoadError(t *testing.T) { + type Item struct { + ID uint64 + Name string + ExtID string + } + + root := filepath.Join(os.TempDir(), randString()) + defer os.RemoveAll(root) + + db := NewPrimary(root) + c := NewCollection(db, "collection", func(i *Item) uint64 { return i.ID }) + db.Start() + defer db.Close() + + c.Insert(Item{1, "one", "x"}) + c.Insert(Item{2, "two", "x"}) + c.Insert(Item{3, "three", "y"}) + c.Insert(Item{4, "x", ""}) + + idx := NewMapIndex(c, + "ExtID", + func(i *Item) string { return i.ExtID }, + func(i *Item) bool { return i.ExtID != "" }) + err := idx.load(c.items.m) + if !errors.Is(err, ErrDuplicate) { + t.Fatal(err) + } +} diff --git a/shipping_test.go b/shipping_test.go new file mode 100644 index 0000000..18563e6 --- /dev/null +++ b/shipping_test.go @@ -0,0 +1,81 @@ +package mdb + +import ( + "net" + "os" + "path/filepath" + "testing" +) + +func TestLogShip(t *testing.T) { + type Item struct { + ID uint64 + Name string + } + + newDB := func(root string, primary bool) (*Database, *Collection[Item]) { + var db *Database + if primary { + db = NewPrimary(root) + } else { + db = NewSecondary(root) + } + c := NewCollection(db, "collection", func(i *Item) uint64 { return i.ID }) + NewBTreeIndex(c, + func(i, j *Item) bool { return i.Name < j.Name }, + func(i *Item) bool { return i.Name != "" }) + return db, c + } + + root1 := filepath.Join(os.TempDir(), randString()) + root2 := filepath.Join(os.TempDir(), randString()) + //log.Print(root1, " --> ", root2) + defer os.RemoveAll(root1) + defer os.RemoveAll(root2) + + dbLeader, colLeader := newDB(root1, true) + dbLeader.Start() + defer dbLeader.Close() + + dbFollower, _ := newDB(root2, false) + dbFollower.Start() + defer dbFollower.Close() + + c1, c2 := net.Pipe() + go dbLeader.SyncSend(c1) + go dbFollower.SyncRecv(c2) + + item1 := Item{1, "one"} + item2 := Item{2, ""} + item3 := Item{3, "three"} + item4 := Item{4, ""} + item5 := Item{5, "five"} + + item1, _ = colLeader.Insert(item1) + item2, _ = colLeader.Insert(item2) + item3, _ = colLeader.Insert(item3) + item4, _ = colLeader.Insert(item4) + item5, _ = colLeader.Insert(item5) + colLeader.Delete(item2.ID) + + colLeader.Update(item4.ID, func(old Item) (Item, error) { + old.Name = "UPDATED" + return old, nil + }) + + dbLeader.waitForWAL() + dbFollower.waitForWAL() + + dbLeader, colLeader = newDB(root1, true) + dbLeader.Start() + dbFollower, colFollower := newDB(root2, false) + dbFollower.Start() + + m1 := colLeader.items.m + m2 := colFollower.items.m + + if len(m1) != len(m2) { + t.Fatal(m1, m2) + } + +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..3600684 --- /dev/null +++ b/types.go @@ -0,0 +1,28 @@ +package mdb + +type WALStatus struct { + MaxID uint64 // TODO: WALMaxSeqNum + MaxAppliedID uint64 // TODO: KVMaxSeqNum +} + +type itemIndex[T any] interface { + load(m map[uint64]*T) error + insert(*T) + update(old, new *T) // Old and new MUST have the same ID. 
+ delete(*T) +} + +type itemUniqueIndex[T any] interface { + load(m map[uint64]*T) error + name() string + lock(*T) + unlock(*T) + insertConflict(*T) bool + updateConflict(*T) bool +} + +type dbCollection interface { + loadData() + onStore(string, uint64, []byte) // For WAL following. + onDelete(string, uint64) // For WAL following. +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..13a87ec --- /dev/null +++ b/util.go @@ -0,0 +1,7 @@ +package mdb + +func must(err error) { + if err != nil { + panic(err) + } +} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..08fc1e7 --- /dev/null +++ b/util_test.go @@ -0,0 +1,15 @@ +package mdb + +import ( + "crypto/rand" + "encoding/hex" + mrand "math/rand" +) + +func randString() string { + buf := make([]byte, 1+mrand.Intn(10)) + if _, err := rand.Read(buf); err != nil { + panic(err) + } + return hex.EncodeToString(buf) +}
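
Reviewer note (not part of the change set): a minimal usage sketch assembled from the patterns in the tests above, showing how a primary Database, a Collection, a unique MapIndex, and an ordered BTreeIndex fit together. The User type, its fields, and the temp directory are hypothetical, and error handling is abbreviated; treat this as a sketch of the intended API, not a verified example. It would live in an example_test.go file in package mdb, like the existing tests.

package mdb

import (
	"fmt"
	"os"
	"path/filepath"
)

// Example_usage wires the pieces together the same way the tests do.
func Example_usage() {
	type User struct {
		ID    uint64
		Email string
	}

	root := filepath.Join(os.TempDir(), "mdb-example")
	defer os.RemoveAll(root)

	// A primary database accepts writes and persists them via its WAL.
	db := NewPrimary(root)
	defer db.Close()

	users := NewCollection(db, "users", func(u *User) uint64 { return u.ID })

	// Unique index on Email; the include function keeps empty emails out.
	byEmail := NewMapIndex(users, "Email",
		func(u *User) string { return u.Email },
		func(u *User) bool { return u.Email != "" })

	// Ordered index for range scans over Email (no include filter).
	ordered := NewBTreeIndex(users,
		func(a, b *User) bool { return a.Email < b.Email },
		nil)

	// Register all collections and indices before Start loads existing data.
	db.Start()

	u, err := users.Insert(User{ID: users.NextID(), Email: "a@example.com"})
	if err != nil {
		panic(err)
	}

	if got, ok := byEmail.Get("a@example.com"); ok {
		fmt.Println(got.ID == u.ID) // true
	}

	iter := ordered.Ascend()
	defer iter.Close()
	for iter.Next() {
		fmt.Println(iter.Value().Email) // a@example.com
	}
}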
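
Replication, as exercised by shipping_test.go: a secondary database registers the same collection names and applies the primary's WAL stream, which the primary ships over any net.Conn. A hedged sketch under the same assumptions as above (net.Pipe stands in for a real network connection; the Item type and directory arguments are hypothetical):

package mdb

import "net"

// replicationSketch mirrors the wiring in TestLogShip: writes go to the
// leader, and the follower applies them via SyncSend/SyncRecv.
func replicationSketch(leaderRoot, followerRoot string) {
	type Item struct {
		ID   uint64
		Name string
	}

	leader := NewPrimary(leaderRoot)
	defer leader.Close()
	leaderItems := NewCollection(leader, "items", func(i *Item) uint64 { return i.ID })
	leader.Start()

	follower := NewSecondary(followerRoot)
	defer follower.Close()
	// The follower registers the same collection name so shipped WAL records
	// can be routed to it through onStore/onDelete.
	NewCollection(follower, "items", func(i *Item) uint64 { return i.ID })
	follower.Start()

	// Ship the WAL: one goroutine sends, the other receives and applies.
	c1, c2 := net.Pipe()
	go leader.SyncSend(c1)
	go follower.SyncRecv(c2)

	// Writes are accepted only on the primary; Insert on the follower would
	// return ErrReadOnly.
	leaderItems.Insert(Item{ID: leaderItems.NextID(), Name: "replicated"})
}

WALStatus (MaxID vs. MaxAppliedID) can be polled on either side to confirm the follower has caught up, as database_test.go's waitForWAL does.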