what happened?

main v1.2.6
jdl 2023-11-15 12:29:33 +01:00
parent 6cf0ea3c5d
commit 5641a89fee
45 changed files with 4573 additions and 0 deletions

26
LICENSE Normal file

@ -0,0 +1,26 @@
Copyright 2022 John David Lee
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its contributors
may be used to endorse or promote products derived from this software without
specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.


@ -1,2 +1,3 @@
# mdb
An in-process, in-memory database for Go.
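A minimal usage sketch assembled only from the constructors and methods added in this commit (NewPrimary, NewCollection, NewBTreeIndex, Start, NextID, Insert, Ascend, Close); the User type, field names, and paths are illustrative and not part of the commit:

package main

import (
    "fmt"

    "git.crumpington.com/public/mdb"
)

// User is an illustrative record type, similar in shape to the one used by the tests below.
type User struct {
    ID    uint64
    Email string
    Name  string
}

func main() {
    // NewPrimary takes an exclusive flock on <root>/lock and opens the SQLite-backed store.
    db := mdb.NewPrimary("/tmp/mdb-example")
    defer db.Close()

    // Register a collection and a secondary index before calling Start.
    users := mdb.NewCollection(db, "users", func(u *User) uint64 { return u.ID })
    byEmail := mdb.NewBTreeIndex(users,
        func(a, b *User) bool { return a.Email < b.Email },
        nil) // nil include: every item is indexed

    // Start loads persisted items into the in-memory map and all registered indices.
    db.Start()

    if _, err := users.Insert(User{ID: users.NextID(), Email: "a@b.com", Name: "Ada"}); err != nil {
        panic(err)
    }

    // Iterate the index in email order; Close releases the producing goroutine.
    it := byEmail.Ascend()
    defer it.Close()
    for it.Next() {
        fmt.Println(it.Value().Email)
    }
}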

165
btreeindex.go Normal file

@ -0,0 +1,165 @@
package mdb
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"sync"
"sync/atomic"
"github.com/google/btree"
)
type BTreeIndex[T any] struct {
c *Collection[T]
modLock sync.Mutex
bt atomic.Value // *btree.BTreeG[*T]
getID func(*T) uint64
less func(*T, *T) bool
include func(*T) bool
}
func NewBTreeIndex[T any](
c *Collection[T],
less func(*T, *T) bool,
include func(*T) bool,
) *BTreeIndex[T] {
t := &BTreeIndex[T]{
c: c,
getID: c.getID,
less: less,
include: include,
}
btree := btree.NewG(32, less)
t.bt.Store(btree)
c.indices = append(c.indices, t)
return t
}
func (t *BTreeIndex[T]) load(m map[uint64]*T) error {
btree := btree.NewG(64, t.less)
t.bt.Store(btree)
for _, item := range m {
if t.include == nil || t.include(item) {
if x, _ := btree.ReplaceOrInsert(item); x != nil {
return ErrDuplicate
}
}
}
return nil
}
// ----------------------------------------------------------------------------
func (t *BTreeIndex[T]) Ascend() *BTreeIterator[T] {
iter := newBTreeIterator[T]()
go func() {
t.btree().Ascend(iter.each)
iter.done()
}()
return iter
}
func (t *BTreeIndex[T]) AscendAfter(pivot T) *BTreeIterator[T] {
iter := newBTreeIterator[T]()
go func() {
t.btree().AscendGreaterOrEqual(&pivot, iter.each)
iter.done()
}()
return iter
}
func (t *BTreeIndex[T]) Descend() *BTreeIterator[T] {
iter := newBTreeIterator[T]()
go func() {
t.btree().Descend(iter.each)
iter.done()
}()
return iter
}
func (t *BTreeIndex[T]) DescendAfter(pivot T) *BTreeIterator[T] {
iter := newBTreeIterator[T]()
go func() {
t.btree().DescendLessOrEqual(&pivot, iter.each)
iter.done()
}()
return iter
}
func (t *BTreeIndex[T]) Get(item T) (T, bool) {
ptr, ok := t.btree().Get(&item)
if !ok {
return item, false
}
return *ptr, true
}
func (t *BTreeIndex[T]) Min() (item T, ok bool) {
if ptr, ok := t.btree().Min(); ok {
return *ptr, ok
}
return item, false
}
func (t *BTreeIndex[T]) Max() (item T, ok bool) {
if ptr, ok := t.btree().Max(); ok {
return *ptr, ok
}
return item, false
}
func (t *BTreeIndex[T]) Len() int {
return t.btree().Len()
}
// ----------------------------------------------------------------------------
func (t *BTreeIndex[T]) insert(item *T) {
if t.include == nil || t.include(item) {
t.modify(func(bt *btree.BTreeG[*T]) {
bt.ReplaceOrInsert(item)
})
}
}
func (t *BTreeIndex[T]) update(old, new *T) {
t.modify(func(bt *btree.BTreeG[*T]) {
if t.include == nil || t.include(old) {
bt.Delete(old)
}
if t.include == nil || t.include(new) {
bt.ReplaceOrInsert(new)
}
})
}
func (t *BTreeIndex[T]) delete(item *T) {
if t.include == nil || t.include(item) {
t.modify(func(bt *btree.BTreeG[*T]) {
bt.Delete(item)
})
}
}
// ----------------------------------------------------------------------------
func (t *BTreeIndex[T]) btree() *btree.BTreeG[*T] {
return t.bt.Load().(*btree.BTreeG[*T])
}
func (t *BTreeIndex[T]) modify(mod func(clone *btree.BTreeG[*T])) {
t.modLock.Lock()
defer t.modLock.Unlock()
clone := t.btree().Clone()
mod(clone)
t.bt.Store(clone)
}
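Readers of BTreeIndex never observe a mutating tree: Get and the iterators load the current *btree.BTreeG from the atomic.Value, while modify clones the tree under modLock, applies the change to the clone, and publishes it. A standalone sketch of that copy-on-write pattern against the same github.com/google/btree v1.1.2 API (the int element type is illustrative):

package main

import (
    "fmt"
    "sync"
    "sync/atomic"

    "github.com/google/btree"
)

type cowTree struct {
    mu sync.Mutex   // serializes writers
    v  atomic.Value // holds *btree.BTreeG[int]; readers load it without locking
}

func newCOWTree() *cowTree {
    t := &cowTree{}
    t.v.Store(btree.NewG(32, func(a, b int) bool { return a < b }))
    return t
}

func (t *cowTree) insert(n int) {
    t.mu.Lock()
    defer t.mu.Unlock()
    clone := t.v.Load().(*btree.BTreeG[int]).Clone() // lazy copy-on-write of shared nodes
    clone.ReplaceOrInsert(n)
    t.v.Store(clone) // publish; in-flight readers keep their old snapshot
}

func (t *cowTree) ascend(f func(int) bool) {
    t.v.Load().(*btree.BTreeG[int]).Ascend(f)
}

func main() {
    t := newCOWTree()
    t.insert(2)
    t.insert(1)
    t.ascend(func(n int) bool { fmt.Println(n); return true })
}

The same trade-off applies in the index above: reads see a stable snapshot without locking, and each write pays for a Clone plus a single ReplaceOrInsert or Delete.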

144
btreeindex_ex_test.go Normal file

@ -0,0 +1,144 @@
package mdb
/*Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"fmt"
"reflect"
)
func (bt *BTreeIndex[T]) Equals(rhs *BTreeIndex[T]) error {
if bt.Len() != rhs.Len() {
return fmt.Errorf("Expected %d items, but found %d.", bt.Len(), rhs.Len())
}
it1 := bt.Ascend()
defer it1.Close()
it2 := rhs.Ascend()
defer it2.Close()
for it1.Next() {
it2.Next()
v1 := it1.Value()
v2 := it2.Value()
if !reflect.DeepEqual(v1, v2) {
return fmt.Errorf("Value mismatch: %v != %v", v1, v2)
}
}
return nil
}
func (bt *BTreeIndex[T]) EqualsList(data []*T) error {
if bt.Len() != len(data) {
return fmt.Errorf("Expected %d items, but found %d.", bt.Len(), len(data))
}
if len(data) == 0 {
return nil
}
// Ascend fully.
it := bt.Ascend()
for _, v1 := range data {
it.Next()
v2 := it.Value()
if !reflect.DeepEqual(*v1, v2) {
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
}
}
if it.Next() {
return fmt.Errorf("Next returned true after full ascend.")
}
it.Close()
// Descend fully.
it = bt.Descend()
dataList := data
for len(dataList) > 0 {
v1 := dataList[len(dataList)-1]
dataList = dataList[:len(dataList)-1]
it.Next()
v2 := it.Value()
if !reflect.DeepEqual(*v1, v2) {
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
}
}
if it.Next() {
return fmt.Errorf("Next returned true after full descend.")
}
it.Close()
// AscendAfter
dataList = data
for len(dataList) > 1 {
dataList = dataList[1:]
it = bt.AscendAfter(*dataList[0])
for _, v1 := range dataList {
it.Next()
v2 := it.Value()
if !reflect.DeepEqual(*v1, v2) {
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
}
}
if it.Next() {
return fmt.Errorf("Next returned true after partial ascend.")
}
it.Close()
}
// DescendAfter
dataList = data
for len(dataList) > 1 {
dataList = dataList[:len(dataList)-1]
it = bt.DescendAfter(*dataList[len(dataList)-1])
for i := len(dataList) - 1; i >= 0; i-- {
v1 := dataList[i]
it.Next()
v2 := it.Value()
if !reflect.DeepEqual(*v1, v2) {
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
}
}
if it.Next() {
return fmt.Errorf("Next returned true after partial descend: %#v", it.Value())
}
it.Close()
}
// Using Get.
for _, v1 := range data {
v2, ok := bt.Get(*v1)
if !ok || !reflect.DeepEqual(*v1, v2) {
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
}
}
// Min.
v1 := data[0]
v2, ok := bt.Min()
if !ok || !reflect.DeepEqual(*v1, v2) {
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
}
// Max.
v1 = data[len(data)-1]
v2, ok = bt.Max()
if !ok || !reflect.DeepEqual(*v1, v2) {
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
}
return nil
}

318
btreeindex_test.go Normal file

@ -0,0 +1,318 @@
package mdb
/*Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"errors"
"reflect"
"testing"
)
func TestFullBTreeIndex(t *testing.T) {
// Test against the email index.
run := func(name string, inner func(t *testing.T, db *DB) []*User) {
testWithDB(t, name, func(t *testing.T, db *DB) {
expected := inner(t, db)
if err := db.Users.emailBTree.EqualsList(expected); err != nil {
t.Fatal(err)
}
db.Close()
db = OpenDB(db.root, true)
if err := db.Users.emailBTree.EqualsList(expected); err != nil {
t.Fatal(err)
}
})
}
run("insert", func(t *testing.T, db *DB) (users []*User) {
users = append(users,
&User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
&User{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"})
for _, u := range users {
u2, err := db.Users.c.Insert(*u)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(u2, *u) {
t.Fatal(u2, *u)
}
}
return users
})
run("insert duplicate", func(t *testing.T, db *DB) (users []*User) {
users = append(users,
&User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
&User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "ccc"})
u1, err := db.Users.c.Insert(*users[0])
if err != nil {
t.Fatal(err)
}
u2, err := db.Users.c.Insert(*users[1])
if !errors.Is(err, ErrDuplicate) {
t.Fatal(u1, u2, err)
}
return users[:1]
})
run("update", func(t *testing.T, db *DB) (users []*User) {
users = append(users,
&User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
&User{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"},
&User{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"})
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
err := db.Users.c.Update(users[2].ID, func(u User) (User, error) {
u.Email = "g@h.com"
return u, nil
})
if err != nil {
t.Fatal(err)
}
users[2].Email = "g@h.com"
return users
})
run("delete", func(t *testing.T, db *DB) (users []*User) {
users = append(users,
&User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
&User{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"},
&User{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"})
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
db.Users.c.Delete(users[0].ID)
users = users[1:]
return users
})
run("get not found", func(t *testing.T, db *DB) (users []*User) {
users = append(users,
&User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
&User{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"},
&User{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"})
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
if u, ok := db.Users.emailBTree.Get(User{Email: "g@h.com"}); ok {
t.Fatal(u, ok)
}
return users
})
run("min/max empty", func(t *testing.T, db *DB) (users []*User) {
if u, ok := db.Users.emailBTree.Min(); ok {
t.Fatal(u, ok)
}
if u, ok := db.Users.emailBTree.Max(); ok {
t.Fatal(u, ok)
}
return users
})
}
func TestPartialBTreeIndex(t *testing.T) {
// Test against the extID btree index.
run := func(name string, inner func(t *testing.T, db *DB) []*User) {
testWithDB(t, name, func(t *testing.T, db *DB) {
expected := inner(t, db)
if err := db.Users.extIDBTree.EqualsList(expected); err != nil {
t.Fatal(err)
}
db.Close()
db = OpenDB(db.root, true)
if err := db.Users.extIDBTree.EqualsList(expected); err != nil {
t.Fatal(err)
}
})
}
run("insert out", func(t *testing.T, db *DB) []*User {
users := []*User{
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "xxx"},
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ggg"},
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "aaa"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
return []*User{}
})
run("insert in", func(t *testing.T, db *DB) []*User {
users := []*User{
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "xxx"},
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ggg", ExtID: "x"},
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "aaa"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
return []*User{users[1]}
})
run("update out to out", func(t *testing.T, db *DB) []*User {
users := []*User{
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc", ExtID: "A"},
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"},
{ID: db.Users.c.NextID(), Email: "g@h.com", Name: "ggg", ExtID: "B"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
err := db.Users.c.Update(users[0].ID, func(u User) (User, error) {
u.Name = "axa"
users[0].Name = "axa"
return u, nil
})
if err != nil {
t.Fatal(err)
}
return []*User{users[1], users[3]}
})
run("update in to in", func(t *testing.T, db *DB) []*User {
users := []*User{
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc", ExtID: "A"},
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"},
{ID: db.Users.c.NextID(), Email: "g@h.com", Name: "ggg", ExtID: "B"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
err := db.Users.c.Update(users[1].ID, func(u User) (User, error) {
u.ExtID = "C"
users[1].ExtID = "C"
return u, nil
})
if err != nil {
t.Fatal(err)
}
return []*User{users[3], users[1]}
})
run("update out to in", func(t *testing.T, db *DB) []*User {
users := []*User{
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc", ExtID: "A"},
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"},
{ID: db.Users.c.NextID(), Email: "g@h.com", Name: "ggg", ExtID: "B"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
err := db.Users.c.Update(users[2].ID, func(u User) (User, error) {
u.ExtID = "C"
users[2].ExtID = "C"
return u, nil
})
if err != nil {
t.Fatal(err)
}
return []*User{users[1], users[3], users[2]}
})
run("update in to out", func(t *testing.T, db *DB) []*User {
users := []*User{
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc", ExtID: "A"},
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"},
{ID: db.Users.c.NextID(), Email: "g@h.com", Name: "ggg", ExtID: "B"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
err := db.Users.c.Update(users[1].ID, func(u User) (User, error) {
u.ExtID = ""
users[1].ExtID = ""
return u, nil
})
if err != nil {
t.Fatal(err)
}
return []*User{users[3]}
})
}
func TestBTreeIndex_load_ErrDuplicate(t *testing.T) {
testWithDB(t, "", func(t *testing.T, db *DB) {
idx := NewBTreeIndex(
db.Users.c,
func(lhs, rhs *User) bool { return lhs.ExtID < rhs.ExtID },
nil)
users := map[uint64]*User{
1: {ID: 1, Email: "x@y.com", Name: "xx", ExtID: "x"},
2: {ID: 2, Email: "a@b.com", Name: "aa", ExtID: "x"},
}
if err := idx.load(users); err != ErrDuplicate {
t.Fatal(err)
}
})
}

51
btreeiterator.go Normal file

@ -0,0 +1,51 @@
package mdb
/*Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
type BTreeIterator[T any] struct {
out chan *T
close chan struct{}
current *T
}
func newBTreeIterator[T any]() *BTreeIterator[T] {
return &BTreeIterator[T]{
out: make(chan *T),
close: make(chan struct{}, 1),
}
}
func (iter *BTreeIterator[T]) each(i *T) bool {
select {
case iter.out <- i:
return true
case <-iter.close:
return false
}
}
func (iter *BTreeIterator[T]) done() {
close(iter.out)
}
func (iter *BTreeIterator[T]) Next() bool {
val, ok := <-iter.out
if ok {
iter.current = val
return true
}
return false
}
func (iter *BTreeIterator[T]) Close() {
iter.close <- struct{}{}
}
func (iter *BTreeIterator[T]) Value() T {
return *iter.current
}
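Each Ascend/Descend call starts a goroutine that streams items over the unbuffered out channel; Next receives the next item and Value returns it. A caller that stops early must call Close so the producer unblocks via the buffered close channel. A fragment continuing the illustrative sketch from the README section above (byEmail and User are the same assumed names):

// Find the first user at or after a pivot email, then stop.
it := byEmail.AscendAfter(User{Email: "a@b.com"})
defer it.Close() // without this, an early return would leave the producer blocked on its send
if it.Next() {
    fmt.Println(it.Value().Name)
}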

49
btreeiterator_test.go Normal file

@ -0,0 +1,49 @@
package mdb
/*Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"reflect"
"testing"
)
func TestBTreeIterator(t *testing.T) {
testWithDB(t, "min/max matches iterator", func(t *testing.T, db *DB) {
users := []User{
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "bbb"},
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(u); err != nil {
t.Fatal(err)
}
}
it := db.Users.nameBTree.Ascend()
if !it.Next() {
t.Fatal("No next")
}
u := it.Value()
if !reflect.DeepEqual(u, users[0]) {
t.Fatal(u, users[0])
}
it.Close()
it = db.Users.nameBTree.Descend()
if !it.Next() {
t.Fatal("No next")
}
u = it.Value()
if !reflect.DeepEqual(u, users[2]) {
t.Fatal(u, users[2])
}
it.Close()
})
}

27
codec.go Normal file

@ -0,0 +1,27 @@
package mdb
/*Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"bytes"
"encoding/json"
"git.crumpington.com/public/mdb/kvstore"
)
func decode[T any](data []byte) *T {
item := new(T)
must(json.Unmarshal(data, item))
return item
}
func encode(item any) []byte {
w := bytes.NewBuffer(kvstore.GetDataBuf(0)[:0])
must(json.NewEncoder(w).Encode(item))
return w.Bytes()
}
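encode serializes with encoding/json into a buffer borrowed from the kvstore buffer pool, and decode unmarshals into a freshly allocated value, panicking (via must) on invalid JSON. A small round-trip, which would have to live inside package mdb since both helpers are unexported (the point type is illustrative):

type point struct{ X, Y int }

func exampleRoundTrip() (int, int) {
    b := encode(point{X: 1, Y: 2}) // JSON bytes backed by a pooled buffer
    p := decode[point](b)          // *point
    return p.X, p.Y                // 1, 2
}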

212
collection.go Normal file

@ -0,0 +1,212 @@
package mdb
/*Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"fmt"
"git.crumpington.com/public/mdb/keyedmutex"
)
type Collection[T any] struct {
primary bool
db *Database
name string
idLock keyedmutex.KeyedMutex[uint64]
items *itemMap[T]
indices []itemIndex[T]
uniqueIndices []itemUniqueIndex[T]
getID func(*T) uint64
sanitize func(*T)
validate func(*T) error
}
func NewCollection[T any](db *Database, name string, getID func(*T) uint64) *Collection[T] {
items := newItemMap(db.kv, name, getID)
c := &Collection[T]{
primary: db.kv.Primary(),
db: db,
name: name,
idLock: keyedmutex.New[uint64](),
items: items,
indices: []itemIndex[T]{items},
uniqueIndices: []itemUniqueIndex[T]{},
getID: items.getID,
sanitize: func(*T) {},
validate: func(*T) error { return nil },
}
db.collections[name] = c
return c
}
func (c *Collection[T]) SetSanitize(sanitize func(*T)) {
c.sanitize = sanitize
}
func (c *Collection[T]) SetValidate(validate func(*T) error) {
c.validate = validate
}
// ----------------------------------------------------------------------------
func (c *Collection[T]) NextID() uint64 {
return c.items.nextID()
}
// ----------------------------------------------------------------------------
func (c *Collection[T]) Insert(item T) (T, error) {
if !c.primary {
return item, ErrReadOnly
}
c.sanitize(&item)
if err := c.validate(&item); err != nil {
return item, err
}
id := c.getID(&item)
c.idLock.Lock(id)
defer c.idLock.Unlock(id)
if _, ok := c.items.Get(id); ok {
return item, fmt.Errorf("%w: ID", ErrDuplicate)
}
// Acquire locks and check for insert conflicts.
for _, idx := range c.uniqueIndices {
idx.lock(&item)
defer idx.unlock(&item)
if idx.insertConflict(&item) {
return item, fmt.Errorf("%w: %s", ErrDuplicate, idx.name())
}
}
for _, idx := range c.indices {
idx.insert(&item)
}
return item, nil
}
func (c *Collection[T]) Update(id uint64, update func(T) (T, error)) error {
if !c.primary {
return ErrReadOnly
}
c.idLock.Lock(id)
defer c.idLock.Unlock(id)
old, ok := c.items.Get(id)
if !ok {
return ErrNotFound
}
newItem, err := update(*old)
if err != nil {
if err == ErrAbortUpdate {
return nil
}
return err
}
new := &newItem
if c.getID(new) != id {
return ErrMismatchedIDs
}
c.sanitize(new)
if err := c.validate(new); err != nil {
return err
}
// Acquire locks and check for update conflicts.
for _, idx := range c.uniqueIndices {
idx.lock(new)
defer idx.unlock(new)
if idx.updateConflict(new) {
return fmt.Errorf("%w: %s", ErrDuplicate, idx.name())
}
}
for _, idx := range c.indices {
idx.update(old, new)
}
return nil
}
func (c Collection[T]) Delete(id uint64) {
if !c.primary {
panic(ErrReadOnly)
}
c.idLock.Lock(id)
defer c.idLock.Unlock(id)
item, ok := c.items.Get(id)
if !ok {
return
}
// Acquire locks and check for insert conflicts.
for _, idx := range c.uniqueIndices {
idx.lock(item)
defer idx.unlock(item)
}
for _, idx := range c.indices {
idx.delete(item)
}
}
func (c Collection[T]) Get(id uint64) (t T, ok bool) {
ptr, ok := c.items.Get(id)
if !ok {
return t, false
}
return *ptr, true
}
// ----------------------------------------------------------------------------
func (c *Collection[T]) loadData() {
for _, idx := range c.indices {
must(idx.load(c.items.m))
}
}
func (c *Collection[T]) onStore(id uint64, data []byte) {
item := decode[T](data)
old, ok := c.items.Get(id)
if !ok {
// Insert.
for _, idx := range c.indices {
idx.insert(item)
}
} else {
// Otherwise update.
for _, idx := range c.indices {
idx.update(old, item)
}
}
}
func (c *Collection[T]) onDelete(id uint64) {
if item, ok := c.items.Get(id); ok {
for _, idx := range c.indices {
idx.delete(item)
}
}
}
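SetSanitize and SetValidate give a collection normalization and rejection hooks: Insert runs sanitize then validate before acquiring any locks, and Update runs both on the value returned by the update callback. An illustrative fragment continuing the User sketch above (the error text mirrors the ErrInvalidName sentinel used by the tests, which is defined in the test helpers rather than here):

users.SetSanitize(func(u *User) {
    u.Email = strings.ToLower(strings.TrimSpace(u.Email))
})
users.SetValidate(func(u *User) error {
    if u.Name == "" {
        return errors.New("InvalidName") // any non-nil error aborts the write
    }
    return nil
})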

153
collection_test.go Normal file

@ -0,0 +1,153 @@
package mdb
/*Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"errors"
"reflect"
"testing"
)
func TestCollection(t *testing.T) {
testWithDB(t, "insert on secondary", func(t *testing.T, db *DB) {
c := db.Users.c
c.primary = false
if _, err := c.Insert(User{}); !errors.Is(err, ErrReadOnly) {
t.Fatal(err)
}
})
testWithDB(t, "insert validation error", func(t *testing.T, db *DB) {
c := db.Users.c
user := User{
ID: c.NextID(),
Email: "a@b.com",
}
if _, err := c.Insert(user); !errors.Is(err, ErrInvalidName) {
t.Fatal(err)
}
})
testWithDB(t, "insert duplicate", func(t *testing.T, db *DB) {
c := db.Users.c
user := User{
ID: c.NextID(),
Name: "adsf",
Email: "a@b.com",
}
if _, err := c.Insert(user); err != nil {
t.Fatal(err)
}
if _, err := c.Insert(user); !errors.Is(err, ErrDuplicate) {
t.Fatal(err)
}
})
testWithDB(t, "update on secondary", func(t *testing.T, db *DB) {
c := db.Users.c
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
if _, err := c.Insert(user); err != nil {
t.Fatal(err)
}
c.primary = false
err := c.Update(user.ID, func(u User) (User, error) {
u.Name = "xxx"
return u, nil
})
if !errors.Is(err, ErrReadOnly) {
t.Fatal(err)
}
})
testWithDB(t, "update not found", func(t *testing.T, db *DB) {
c := db.Users.c
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
if _, err := c.Insert(user); err != nil {
t.Fatal(err)
}
err := c.Update(user.ID+1, func(u User) (User, error) {
u.Name = "xxx"
return u, nil
})
if !errors.Is(err, ErrNotFound) {
t.Fatal(err)
}
})
testWithDB(t, "update failed validation", func(t *testing.T, db *DB) {
c := db.Users.c
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
if _, err := c.Insert(user); err != nil {
t.Fatal(err)
}
err := c.Update(user.ID, func(u User) (User, error) {
u.Name = ""
return u, nil
})
if !errors.Is(err, ErrInvalidName) {
t.Fatal(err)
}
})
testWithDB(t, "delete on secondary", func(t *testing.T, db *DB) {
defer func() {
if err := recover(); err == nil {
t.Fatal("No panic")
}
}()
c := db.Users.c
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
if _, err := c.Insert(user); err != nil {
t.Fatal(err)
}
c.primary = false
c.Delete(1)
})
testWithDB(t, "delete not found", func(t *testing.T, db *DB) {
c := db.Users.c
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
if _, err := c.Insert(user); err != nil {
t.Fatal(err)
}
c.Delete(user.ID + 1) // Does nothing.
})
testWithDB(t, "get", func(t *testing.T, db *DB) {
c := db.Users.c
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
if _, err := c.Insert(user); err != nil {
t.Fatal(err)
}
u2, ok := c.Get(user.ID)
if !ok || !reflect.DeepEqual(user, u2) {
t.Fatal(ok, u2, user)
}
})
testWithDB(t, "get not found", func(t *testing.T, db *DB) {
c := db.Users.c
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
if _, err := c.Insert(user); err != nil {
t.Fatal(err)
}
u2, ok := c.Get(user.ID - 1)
if ok {
t.Fatal(ok, u2)
}
})
}

123
database.go Normal file

@ -0,0 +1,123 @@
package mdb
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"net"
"os"
"path/filepath"
"sync"
"git.crumpington.com/public/mdb/kvstore"
"golang.org/x/sys/unix"
)
type LogInfo = kvstore.LogInfo
type Database struct {
root string
lock *os.File
kv *kvstore.KV
collections map[string]dbCollection
}
func NewPrimary(root string) *Database {
return newDB(root, true)
}
func NewSecondary(root string) *Database {
return newDB(root, false)
}
func newDB(root string, primary bool) *Database {
must(os.MkdirAll(root, 0700))
lockPath := filepath.Join(root, "lock")
// Acquire the lock.
lock, err := os.OpenFile(lockPath, os.O_RDWR|os.O_CREATE, 0600)
must(err)
must(unix.Flock(int(lock.Fd()), unix.LOCK_EX))
db := &Database{
root: root,
collections: map[string]dbCollection{},
lock: lock,
}
if primary {
db.kv = kvstore.NewPrimary(root)
} else {
db.kv = kvstore.NewSecondary(root, db.onStore, db.onDelete)
}
return db
}
func (db *Database) Start() {
wg := sync.WaitGroup{}
for _, c := range db.collections {
wg.Add(1)
go func(c dbCollection) {
defer wg.Done()
c.loadData()
}(c)
}
wg.Wait()
}
func (db *Database) MaxSeqNum() uint64 {
return db.kv.MaxSeqNum()
}
func (db *Database) LogInfo() LogInfo {
return db.kv.LogInfo()
}
func (db *Database) Close() {
if db.kv != nil {
db.kv.Close()
db.kv = nil
}
if db.lock != nil {
db.lock.Close()
db.lock = nil
}
}
// ----------------------------------------------------------------------------
func (db *Database) onStore(collection string, id uint64, data []byte) {
if c, ok := db.collections[collection]; ok {
c.onStore(id, data)
}
}
func (db *Database) onDelete(collection string, id uint64) {
if c, ok := db.collections[collection]; ok {
c.onDelete(id)
}
}
// ----------------------------------------------------------------------------
func (db *Database) SyncSend(conn net.Conn) {
db.kv.SyncSend(conn)
}
func (db *Database) SyncRecv(conn net.Conn) {
db.kv.SyncRecv(conn)
}
// ----------------------------------------------------------------------------
// CleanBefore deletes log entries that are more than the given number of
// seconds old.
func (db *Database) CleanBefore(seconds int64) {
db.kv.CleanBefore(seconds)
}
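SyncSend and SyncRecv work on an already-established net.Conn, so listening, dialing, and reconnection policy stay with the caller. A hedged wiring sketch for a follower process (addresses, paths, and the retry loop are illustrative; the secondary must register the same collections as the primary before Start so onStore/onDelete can find them). On the primary, each accepted connection would be handed to db.SyncSend(conn) in its own goroutine.

package main

import (
    "net"
    "time"

    "git.crumpington.com/public/mdb"
)

func main() {
    db := mdb.NewSecondary("/data/follower")
    // ... register the same collections and indices as the primary here ...
    db.Start()
    defer db.Close()

    for {
        conn, err := net.Dial("tcp", "primary.internal:7001")
        if err != nil {
            time.Sleep(time.Second)
            continue
        }
        db.SyncRecv(conn) // streams log records until the connection drops, then returns
    }
}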

1
database_ex_test.go Normal file

@ -0,0 +1 @@
package mdb

45
database_test.go Normal file

@ -0,0 +1,45 @@
package mdb
/*Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"sync"
"testing"
"time"
)
func TestDatabase(t *testing.T) {
testWithDB(t, "multiple writers", func(t *testing.T, db *DB) {
wg := sync.WaitGroup{}
N := 64
wg.Add(64)
for i := 0; i < N; i++ {
go func() {
defer wg.Done()
for j := 0; j < 1024; j++ {
db.RandAction()
}
}()
}
wg.Wait()
})
testWithDB(t, "clean before", func(t *testing.T, db *DB) {
db.RandAction()
db.RandAction()
db.RandAction()
time.Sleep(2 * time.Second)
db.CleanBefore(0)
if db.MaxSeqNum() == 0 {
t.Fatal(db.MaxSeqNum())
}
})
}

5
dep-graph.sh Executable file

@ -0,0 +1,5 @@
#!/bin/bash
godepgraph -s . > .deps.dot && xdot .deps.dot
rm .deps.dot

21
errors.go Normal file

@ -0,0 +1,21 @@
package mdb
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import "errors"
var (
ErrMismatchedIDs = errors.New("MismatchedIDs")
ErrDuplicate = errors.New("Duplicate")
ErrNotFound = errors.New("NotFound")
ErrReadOnly = errors.New("ReadOnly")
// Return from an update function to abort the update without error.
ErrAbortUpdate = errors.New("AbortUpdate")
)
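Collection wraps these sentinels with fmt.Errorf("%w: ...") to add context such as the offending index name, so callers should match with errors.Is rather than ==. A fragment continuing the illustrative User sketch:

if _, err := users.Insert(u); errors.Is(err, mdb.ErrDuplicate) {
    // the ID or a unique index already holds this value
}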

9
go.mod Normal file

@ -0,0 +1,9 @@
module git.crumpington.com/public/mdb
go 1.18
require (
github.com/google/btree v1.1.2
github.com/mattn/go-sqlite3 v1.14.14
golang.org/x/sys v0.0.0-20220731174439-a90be440212d
)

6
go.sum Normal file

@ -0,0 +1,6 @@
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
github.com/mattn/go-sqlite3 v1.14.14 h1:qZgc/Rwetq+MtyE18WhzjokPD93dNqLGNT3QJuLvBGw=
github.com/mattn/go-sqlite3 v1.14.14/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80=
golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

114
itemmap.go Normal file

@ -0,0 +1,114 @@
package mdb
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"math/rand"
"sync"
"sync/atomic"
"git.crumpington.com/public/mdb/keyedmutex"
"git.crumpington.com/public/mdb/kvstore"
)
// itemMap implements the itemIndex interface and is the primary store that other indices load from.
type itemMap[T any] struct {
primary bool
kv *kvstore.KV
collection string
idLock keyedmutex.KeyedMutex[uint64]
mapLock sync.Mutex
m map[uint64]*T
getID func(*T) uint64
maxID uint64
}
func newItemMap[T any](kv *kvstore.KV, collection string, getID func(*T) uint64) *itemMap[T] {
m := &itemMap[T]{
primary: kv.Primary(),
kv: kv,
collection: collection,
idLock: keyedmutex.New[uint64](),
m: map[uint64]*T{},
getID: getID,
}
kv.Iterate(collection, func(id uint64, data []byte) {
item := decode[T](data)
if id > m.maxID {
m.maxID = id
}
m.m[id] = item
})
return m
}
func (m *itemMap[T]) load(src map[uint64]*T) error {
// No-op: The itemmap is the source for loading all other indices.
return nil
}
// ----------------------------------------------------------------------------
func (m *itemMap[T]) Get(id uint64) (*T, bool) {
return m.mapGet(id)
}
// Should hold item lock when calling.
func (idx *itemMap[T]) insert(item *T) {
id := idx.getID(item)
if idx.primary {
idx.kv.Store(idx.collection, id, encode(item))
}
idx.mapSet(id, item)
}
// Should hold item lock when calling. old and new MUST have the same ID.
func (idx *itemMap[T]) update(old, new *T) {
idx.insert(new)
}
// Should hold item lock when calling.
func (idx *itemMap[T]) delete(item *T) {
id := idx.getID(item)
if idx.primary {
idx.kv.Delete(idx.collection, id)
}
idx.mapDelete(id)
}
// ----------------------------------------------------------------------------
func (idx *itemMap[T]) mapSet(id uint64, item *T) {
idx.mapLock.Lock()
idx.m[id] = item
idx.mapLock.Unlock()
}
func (idx *itemMap[T]) mapDelete(id uint64) {
idx.mapLock.Lock()
delete(idx.m, id)
idx.mapLock.Unlock()
}
func (idx *itemMap[T]) mapGet(id uint64) (*T, bool) {
idx.mapLock.Lock()
item, ok := idx.m[id]
idx.mapLock.Unlock()
return item, ok
}
// ----------------------------------------------------------------------------
func (idx *itemMap[T]) nextID() uint64 {
n := 1 + rand.Int63n(256)
return atomic.AddUint64(&idx.maxID, uint64(n))
}

61
itemmap_ex_test.go Normal file

@ -0,0 +1,61 @@
package mdb
/*Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"fmt"
"reflect"
)
func (m *itemMap[T]) Equals(rhs *itemMap[T]) error {
return m.EqualsMap(rhs.m)
}
func (m *itemMap[T]) EqualsMap(data map[uint64]*T) error {
if len(data) != len(m.m) {
return fmt.Errorf("Expected %d items, but found %d.", len(data), len(m.m))
}
for key, exp := range data {
val, ok := m.m[key]
if !ok {
return fmt.Errorf("No value for %d. Expected: %v", key, *exp)
}
if !reflect.DeepEqual(*val, *exp) {
return fmt.Errorf("Value mismatch %d: %v != %v", key, *val, *exp)
}
}
return nil
}
func (m *itemMap[T]) EqualsKV() (err error) {
count := 0
m.kv.Iterate(m.collection, func(id uint64, data []byte) {
count++
if err != nil {
return
}
item := decode[T](data)
val, ok := m.m[id]
if !ok {
err = fmt.Errorf("Item %d not found in memory: %v", id, *item)
return
}
if !reflect.DeepEqual(*item, *val) {
err = fmt.Errorf("Items not equal %d: %v != %v", id, *item, *val)
return
}
})
if err == nil && count != len(m.m) {
err = fmt.Errorf("%d items on disk, but %d in memory", count, len(m.m))
}
return err
}

134
itemmap_test.go Normal file

@ -0,0 +1,134 @@
package mdb
/*Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"fmt"
"reflect"
"testing"
)
func TestItemMap(t *testing.T) {
// expected is a map of users.
run := func(name string, inner func(t *testing.T, db *DB) (expected map[uint64]*User)) {
testWithDB(t, name, func(t *testing.T, db *DB) {
expected := inner(t, db)
if err := db.Users.c.items.EqualsMap(expected); err != nil {
t.Fatal(err)
}
if err := db.Users.c.items.EqualsKV(); err != nil {
t.Fatal(err)
}
db.Close()
db = OpenDB(db.root, true)
if err := db.Users.c.items.EqualsMap(expected); err != nil {
t.Fatal(err)
}
if err := db.Users.c.items.EqualsKV(); err != nil {
t.Fatal(err)
}
})
}
run("simple", func(t *testing.T, db *DB) (expected map[uint64]*User) {
users := map[uint64]*User{}
c := db.Users.c
for i := uint64(1); i < 10; i++ {
id := c.NextID()
users[id] = &User{
ID: id,
Email: fmt.Sprintf("a.%d@c.com", i),
Name: fmt.Sprintf("name.%d", i),
ExtID: fmt.Sprintf("EXTID.%d", i),
}
user, err := c.Insert(*users[id])
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(user, *users[id]) {
t.Fatal(user, *users[id])
}
}
return users
})
run("insert and delete", func(t *testing.T, db *DB) (expected map[uint64]*User) {
users := map[uint64]*User{}
c := db.Users.c
for x := uint64(1); x < 10; x++ {
id := c.NextID()
users[id] = &User{
ID: id,
Email: fmt.Sprintf("a.%d@c.com", x),
Name: fmt.Sprintf("name.%d", x),
ExtID: fmt.Sprintf("EXTID.%d", x),
}
user, err := c.Insert(*users[id])
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(user, *users[id]) {
t.Fatal(user, *users[id])
}
}
var id uint64
for key := range users {
id = key
}
delete(users, id)
c.Delete(id)
return users
})
run("update", func(t *testing.T, db *DB) (expected map[uint64]*User) {
users := map[uint64]*User{}
c := db.Users.c
for x := uint64(1); x < 10; x++ {
id := c.NextID()
users[id] = &User{
ID: id,
Email: fmt.Sprintf("a.%d@c.com", x),
Name: fmt.Sprintf("name.%d", x),
ExtID: fmt.Sprintf("EXTID.%d", x),
}
user, err := c.Insert(*users[id])
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(user, *users[id]) {
t.Fatal(user, *users[id])
}
}
var id uint64
for key := range users {
id = key
}
err := c.Update(id, func(u User) (User, error) {
u.Name = "Hello"
return u, nil
})
if err != nil {
t.Fatal(err)
}
users[id].Name = "Hello"
return users
})
}

75
keyedmutex/mutex.go Normal file

@ -0,0 +1,75 @@
package keyedmutex
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"container/list"
"sync"
)
type KeyedMutex[K comparable] struct {
mu *sync.Mutex
waitList map[K]*list.List
}
func New[K comparable]() KeyedMutex[K] {
return KeyedMutex[K]{
mu: new(sync.Mutex),
waitList: map[K]*list.List{},
}
}
func (m KeyedMutex[K]) Lock(key K) {
if ch := m.lock(key); ch != nil {
<-ch
}
}
func (m KeyedMutex[K]) lock(key K) chan struct{} {
m.mu.Lock()
defer m.mu.Unlock()
if waitList, ok := m.waitList[key]; ok {
ch := make(chan struct{})
waitList.PushBack(ch)
return ch
}
m.waitList[key] = list.New()
return nil
}
func (m KeyedMutex[K]) TryLock(key K) bool {
m.mu.Lock()
defer m.mu.Unlock()
if _, ok := m.waitList[key]; ok {
return false
}
m.waitList[key] = list.New()
return true
}
func (m KeyedMutex[K]) Unlock(key K) {
m.mu.Lock()
defer m.mu.Unlock()
waitList, ok := m.waitList[key]
if !ok {
panic("unlock of unlocked mutex")
}
if waitList.Len() == 0 {
delete(m.waitList, key)
} else {
ch := waitList.Remove(waitList.Front()).(chan struct{})
ch <- struct{}{}
}
}
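KeyedMutex hands out an independent critical section per key: the first Lock on a key installs an empty wait list, later lockers park on a channel in FIFO order, and Unlock either wakes the next waiter or removes the key entirely. A self-contained usage sketch:

package main

import (
    "fmt"

    "git.crumpington.com/public/mdb/keyedmutex"
)

func main() {
    mu := keyedmutex.New[uint64]()

    mu.Lock(42) // other goroutines locking key 42 now queue; key 7 is unaffected
    fmt.Println("holding 42")
    mu.Unlock(42)

    if mu.TryLock(7) { // non-blocking attempt
        defer mu.Unlock(7)
        fmt.Println("holding 7")
    }
}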

116
kvstore/db-sql.go Normal file

@ -0,0 +1,116 @@
package kvstore
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
const sqlSchema = `
BEGIN IMMEDIATE;
CREATE TABLE IF NOT EXISTS data(
SeqNum INTEGER NOT NULL PRIMARY KEY,
Deleted INTEGER NOT NULL DEFAULT 0,
Data BLOB NOT NULL
) WITHOUT ROWID;
CREATE INDEX IF NOT EXISTS data_deleted_index ON data(Deleted,SeqNum);
CREATE TABLE IF NOT EXISTS log(
SeqNum INTEGER NOT NULL PRIMARY KEY,
CreatedAt INTEGER NOT NULL,
Collection TEXT NOT NULL,
ID INTEGER NOT NULL,
Store INTEGER NOT NULL
) WITHOUT ROWID;
CREATE INDEX IF NOT EXISTS log_created_at_index ON log(CreatedAt);
CREATE TABLE IF NOT EXISTS kv(
Collection TEXT NOT NULL,
ID INTEGER NOT NULL,
SeqNum INTEGER NOT NULL,
PRIMARY KEY (Collection, ID)
) WITHOUT ROWID;
CREATE VIEW IF NOT EXISTS kvdata AS
SELECT
kv.Collection,
kv.ID,
data.Data
FROM kv
JOIN data ON kv.SeqNum=data.SeqNum;
CREATE VIEW IF NOT EXISTS logdata AS
SELECT
log.SeqNum,
log.Collection,
log.ID,
log.Store,
data.data
FROM log
LEFT JOIN data on log.SeqNum=data.SeqNum;
CREATE TRIGGER IF NOT EXISTS deletedata AFTER UPDATE OF SeqNum ON kv
FOR EACH ROW
WHEN OLD.SeqNum != NEW.SeqNum
BEGIN
UPDATE data SET Deleted=1 WHERE SeqNum=OLD.SeqNum;
END;
COMMIT;`
// ----------------------------------------------------------------------------
const sqlInsertData = `INSERT INTO data(SeqNum,Data) VALUES(?,?)`
const sqlInsertKV = `INSERT INTO kv(Collection,ID,SeqNum) VALUES (?,?,?)
ON CONFLICT(Collection,ID) DO UPDATE SET SeqNum=excluded.SeqNum
WHERE ID=excluded.ID`
// ----------------------------------------------------------------------------
const sqlDeleteKV = `DELETE FROM kv WHERE Collection=? AND ID=?`
const sqlDeleteData = `UPDATE data SET Deleted=1
WHERE SeqNum=(
SELECT SeqNum FROM kv WHERE Collection=? AND ID=?)`
// ----------------------------------------------------------------------------
const sqlInsertLog = `INSERT INTO log(SeqNum,CreatedAt,Collection,ID,Store)
VALUES(?,?,?,?,?)`
// ----------------------------------------------------------------------------
const sqlKVIterate = `SELECT ID,Data FROM kvdata WHERE Collection=?`
const sqlLogIterate = `
SELECT SeqNum,Collection,ID,Store,Data
FROM logdata
WHERE SeqNum > ?
ORDER BY SeqNum ASC`
const sqlMaxSeqNumGet = `SELECT COALESCE(MAX(SeqNum),0) FROM log`
const sqlLogInfoGet = `SELECT
COALESCE(SeqNum, 0),
COALESCE(CreatedAt, 0)
FROM log
ORDER BY SeqNum DESC LIMIT 1`
const sqlCleanQuery = `
DELETE FROM
log
WHERE
CreatedAt < ? AND
SeqNum < (SELECT MAX(SeqNum) FROM log);
DELETE FROM
data
WHERE
Deleted != 0 AND
SeqNum < (SELECT MIN(SeqNum) FROM log);`

5
kvstore/dep-graph.sh Executable file

@ -0,0 +1,5 @@
#!/bin/bash
godepgraph -s . > .deps.dot && xdot .deps.dot
rm .deps.dot

40
kvstore/globals.go Normal file

@ -0,0 +1,40 @@
package kvstore
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"sync"
"time"
)
var (
connTimeout = 16 * time.Second
pollInterval = 500 * time.Millisecond
modQSize = 1024
poolBufSize = 8192
bufferPool = sync.Pool{
New: func() any {
return make([]byte, poolBufSize)
},
}
)
func GetDataBuf(size int) []byte {
if size > poolBufSize {
return make([]byte, size)
}
return bufferPool.Get().([]byte)[:size]
}
func RecycleDataBuf(b []byte) {
if cap(b) != poolBufSize {
return
}
bufferPool.Put(b)
}

21
kvstore/main_test.go Normal file

@ -0,0 +1,21 @@
package kvstore
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"math/rand"
"os"
"testing"
"time"
)
func TestMain(m *testing.M) {
rand.Seed(time.Now().UnixNano())
os.Exit(m.Run())
}

53
kvstore/shipping.go Normal file

@ -0,0 +1,53 @@
package kvstore
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import "encoding/binary"
const recHeaderSize = 22
func encodeRecordHeader(rec record, buf []byte) {
// SeqNum (8)
// ID (8)
// DataLen (4)
// Store (1)
// CollectionLen (1)
binary.LittleEndian.PutUint64(buf[:8], rec.SeqNum)
buf = buf[8:]
binary.LittleEndian.PutUint64(buf[:8], rec.ID)
buf = buf[8:]
if rec.Store {
binary.LittleEndian.PutUint32(buf[:4], uint32(len(rec.Data)))
buf[4] = 1
} else {
binary.LittleEndian.PutUint32(buf[:4], 0)
buf[4] = 0
}
buf = buf[5:]
buf[0] = byte(len(rec.Collection))
}
func decodeRecHeader(header []byte) (rec record, colLen, dataLen int) {
buf := header
rec.SeqNum = binary.LittleEndian.Uint64(buf[:8])
buf = buf[8:]
rec.ID = binary.LittleEndian.Uint64(buf[:8])
buf = buf[8:]
dataLen = int(binary.LittleEndian.Uint32(buf[:4]))
buf = buf[4:]
rec.Store = buf[0] != 0
buf = buf[1:]
colLen = int(buf[0])
return
}
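The wire header is a fixed 22 bytes: sequence number (8), ID (8), data length (4), store flag (1), and collection-name length (1); the collection name and payload follow it on the wire. A round-trip sketch that would live inside package kvstore, since both helpers are unexported:

buf := make([]byte, recHeaderSize)
encodeRecordHeader(record{
    SeqNum:     7,
    Collection: "users",
    ID:         3,
    Store:      true,
    Data:       []byte(`{"ID":3}`),
}, buf)

rec, colLen, dataLen := decodeRecHeader(buf)
// rec.SeqNum == 7, rec.ID == 3, rec.Store == true.
// colLen == 5 and dataLen == 8 tell the reader how many bytes of collection
// name and payload to pull off the wire next; rec.Collection and rec.Data
// are filled in by the caller after those reads (see sync-recv.go below).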

254
kvstore/shipping_test.go Normal file

@ -0,0 +1,254 @@
package kvstore
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"math/rand"
"os"
"sync"
"testing"
"time"
"git.crumpington.com/public/mdb/testconn"
)
// ----------------------------------------------------------------------------
// Stores info from secondary callbacks.
type callbacks struct {
lock sync.Mutex
m map[string]map[uint64]string
}
func (sc *callbacks) onStore(c string, id uint64, data []byte) {
sc.lock.Lock()
defer sc.lock.Unlock()
if _, ok := sc.m[c]; !ok {
sc.m[c] = map[uint64]string{}
}
sc.m[c][id] = string(data)
}
func (sc *callbacks) onDelete(c string, id uint64) {
sc.lock.Lock()
defer sc.lock.Unlock()
if _, ok := sc.m[c]; !ok {
return
}
delete(sc.m[c], id)
}
// ----------------------------------------------------------------------------
func TestShipping(t *testing.T) {
run := func(name string, inner func(
t *testing.T,
pDir string,
sDir string,
primary *KV,
secondary *KV,
cbs *callbacks,
nw *testconn.Network,
)) {
t.Run(name, func(t *testing.T) {
pDir, _ := os.MkdirTemp("", "")
defer os.RemoveAll(pDir)
sDir, _ := os.MkdirTemp("", "")
defer os.RemoveAll(sDir)
nw := testconn.NewNetwork()
defer func() {
nw.CloseServer()
nw.CloseClient()
}()
cbs := &callbacks{
m: map[string]map[uint64]string{},
}
prim := NewPrimary(pDir)
defer prim.Close()
sec := NewSecondary(sDir, cbs.onStore, cbs.onDelete)
defer sec.Close()
inner(t, pDir, sDir, prim, sec, cbs, nw)
})
}
run("simple", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) {
M := 10
N := 1000
wg := sync.WaitGroup{}
// Store M*N values in the background.
for i := 0; i < M; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < N; i++ {
time.Sleep(time.Millisecond)
prim.randAction()
}
}()
}
// Send in the background.
wg.Add(1)
go func() {
defer wg.Done()
conn := nw.Accept()
prim.SyncSend(conn)
}()
// Recv in the background.
wg.Add(1)
go func() {
defer wg.Done()
conn := nw.Dial()
sec.SyncRecv(conn)
}()
sec.waitForSeqNum(uint64(M * N))
nw.CloseServer()
nw.CloseClient()
wg.Wait()
if err := prim.equalsKV("a", sec); err != nil {
t.Fatal(err)
}
if err := prim.equalsKV("b", sec); err != nil {
t.Fatal(err)
}
if err := prim.equalsKV("c", sec); err != nil {
t.Fatal(err)
}
})
run("simple concurrent", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) {
M := 64
N := 128
wg := sync.WaitGroup{}
// Store M*N values in the background.
for i := 0; i < M; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < N; i++ {
time.Sleep(time.Millisecond)
prim.randAction()
}
}()
}
// Send in the background.
wg.Add(1)
go func() {
defer wg.Done()
conn := nw.Accept()
prim.SyncSend(conn)
}()
// Recv in the background.
wg.Add(1)
go func() {
defer wg.Done()
conn := nw.Dial()
sec.SyncRecv(conn)
}()
sec.waitForSeqNum(uint64(M * N))
nw.CloseServer()
nw.CloseClient()
wg.Wait()
if err := prim.equalsKV("a", sec); err != nil {
t.Fatal(err)
}
if err := prim.equalsKV("b", sec); err != nil {
t.Fatal(err)
}
if err := prim.equalsKV("c", sec); err != nil {
t.Fatal(err)
}
})
run("net failures", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) {
M := 10
N := 1000
sleepTime := time.Millisecond
wg := sync.WaitGroup{}
// Store M*N values in the background.
for i := 0; i < M; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < N; j++ {
time.Sleep(sleepTime)
prim.randAction()
}
}()
}
// Send in the background.
wg.Add(1)
go func() {
defer wg.Done()
for sec.MaxSeqNum() < uint64(M*N) {
if conn := nw.Accept(); conn != nil {
prim.SyncSend(conn)
}
}
}()
// Recv in the background.
wg.Add(1)
go func() {
defer wg.Done()
for sec.MaxSeqNum() < uint64(M*N) {
if conn := nw.Dial(); conn != nil {
sec.SyncRecv(conn)
}
}
}()
wg.Add(1)
go func() {
defer wg.Done()
for sec.MaxSeqNum() < uint64(M*N) {
time.Sleep(time.Duration(rand.Intn(10 * int(sleepTime))))
if rand.Float64() < 0.5 {
nw.CloseClient()
} else {
nw.CloseServer()
}
}
}()
sec.waitForSeqNum(prim.MaxSeqNum())
wg.Wait()
if err := prim.equalsKV("a", sec); err != nil {
t.Fatal(err)
}
if err := prim.equalsKV("b", sec); err != nil {
t.Fatal(err)
}
if err := prim.equalsKV("c", sec); err != nil {
t.Fatal(err)
}
})
}

149
kvstore/store.go Normal file

@ -0,0 +1,149 @@
package kvstore
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"database/sql"
"sync"
"time"
_ "github.com/mattn/go-sqlite3"
)
type LogInfo struct {
MaxSeqNum uint64
MaxCreatedAt uint64
}
type KV struct {
primary bool
dbPath string
db *sql.DB
maxSeqNumStmt *sql.Stmt
logInfoStmt *sql.Stmt
logIterateStmt *sql.Stmt
w *writer
onStore func(string, uint64, []byte)
onDelete func(string, uint64)
closeLock sync.Mutex
recvLock sync.Mutex
}
func newKV(
dir string,
primary bool,
onStore func(string, uint64, []byte),
onDelete func(string, uint64),
) *KV {
kv := &KV{
dbPath: dbPath(dir),
primary: primary,
onStore: onStore,
onDelete: onDelete,
}
opts := `?_journal=WAL`
db, err := sql.Open("sqlite3", kv.dbPath+opts)
must(err)
_, err = db.Exec(sqlSchema)
must(err)
kv.maxSeqNumStmt, err = db.Prepare(sqlMaxSeqNumGet)
must(err)
kv.logInfoStmt, err = db.Prepare(sqlLogInfoGet)
must(err)
kv.logIterateStmt, err = db.Prepare(sqlLogIterate)
must(err)
_, err = db.Exec(sqlSchema)
must(err)
kv.db = db
if kv.primary {
kv.w = newWriter(kv.db)
kv.w.Start(kv.MaxSeqNum())
}
return kv
}
func NewPrimary(dir string) *KV {
return newKV(dir, true, nil, nil)
}
func NewSecondary(
dir string,
onStore func(collection string, id uint64, data []byte),
onDelete func(collection string, id uint64),
) *KV {
return newKV(dir, false, onStore, onDelete)
}
func (kv *KV) Primary() bool {
return kv.primary
}
func (kv *KV) MaxSeqNum() (seqNum uint64) {
must(kv.maxSeqNumStmt.QueryRow().Scan(&seqNum))
return seqNum
}
func (kv *KV) LogInfo() (info LogInfo) {
must(kv.logInfoStmt.QueryRow().Scan(&info.MaxSeqNum, &info.MaxCreatedAt))
return
}
func (kv *KV) Iterate(collection string, each func(id uint64, data []byte)) {
rows, err := kv.db.Query(sqlKVIterate, collection)
must(err)
defer rows.Close()
var (
id uint64
data []byte
)
for rows.Next() {
must(rows.Scan(&id, &data))
each(id, data)
}
}
func (kv *KV) Close() {
kv.closeLock.Lock()
defer kv.closeLock.Unlock()
if kv.w != nil {
kv.w.Stop()
}
if kv.db != nil {
kv.db.Close()
kv.db = nil
}
}
func (kv *KV) Store(collection string, id uint64, data []byte) {
kv.w.Store(collection, id, data)
}
func (kv *KV) Delete(collection string, id uint64) {
kv.w.Delete(collection, id)
}
func (kv *KV) CleanBefore(seconds int64) {
_, err := kv.db.Exec(sqlCleanQuery, time.Now().Unix()-seconds)
must(err)
}

125
kvstore/store_test.go Normal file

@ -0,0 +1,125 @@
package kvstore
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"fmt"
"math/rand"
"os"
"reflect"
"testing"
"time"
)
// ----------------------------------------------------------------------------
func (kv *KV) waitForSeqNum(x uint64) {
for {
seqNum := kv.MaxSeqNum()
if seqNum >= x {
return
}
time.Sleep(100 * time.Millisecond)
}
}
func (kv *KV) dump(collection string) map[uint64]string {
m := map[uint64]string{}
kv.Iterate(collection, func(id uint64, data []byte) {
m[id] = string(data)
})
return m
}
func (kv *KV) equals(collection string, expected map[uint64]string) error {
m := kv.dump(collection)
if len(m) != len(expected) {
return fmt.Errorf("Expected %d values but found %d", len(expected), len(m))
}
for key, exp := range expected {
val, ok := m[key]
if !ok {
return fmt.Errorf("Value for %d not found.", key)
}
if val != exp {
return fmt.Errorf("%d: Expected %s but found %s.", key, exp, val)
}
}
return nil
}
func (kv *KV) equalsKV(collection string, rhs *KV) error {
l1 := []record{}
kv.replay(0, func(rec record) error {
l1 = append(l1, rec)
return nil
})
idx := -1
err := rhs.replay(0, func(rec record) error {
idx++
if !reflect.DeepEqual(rec, l1[idx]) {
return fmt.Errorf("Records not equal: %d %v %v", idx, rec, l1[idx])
}
return nil
})
if err != nil {
return err
}
return kv.equals(collection, rhs.dump(collection))
}
// Collection one of ("a", "b", "c").
// ID one of [1,16]
var (
randCollections = []string{"a", "b", "c"}
randIDs = []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
)
func (kv *KV) randAction() {
c := randCollections[rand.Intn(len(randCollections))]
id := randIDs[rand.Intn(len(randIDs))]
// Mostly stores.
if rand.Float64() < 0.9 {
kv.Store(c, id, randBytes())
} else {
kv.Delete(c, id)
}
}
// ----------------------------------------------------------------------------
func TestKV(t *testing.T) {
run := func(name string, inner func(t *testing.T, kv *KV)) {
dir, _ := os.MkdirTemp("", "")
defer os.RemoveAll(dir)
kv := NewPrimary(dir)
defer kv.Close()
inner(t, kv)
}
run("simple", func(t *testing.T, kv *KV) {
kv.Store("a", 1, _b("Hello"))
kv.Store("a", 2, _b("World"))
kv.waitForSeqNum(2)
err := kv.equals("a", map[uint64]string{
1: "Hello",
2: "World",
})
if err != nil {
t.Fatal(err)
}
})
}

103
kvstore/sync-recv.go Normal file

@ -0,0 +1,103 @@
package kvstore
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"encoding/binary"
"log"
"net"
"time"
)
func (kv *KV) SyncRecv(conn net.Conn) {
defer conn.Close()
if kv.primary {
panic("SyncRecv called on primary.")
}
if !kv.recvLock.TryLock() {
return
}
defer kv.recvLock.Unlock()
// It's important that we stop the writer so that all queued writes are
// committed to the database before syncing has a chance to restart.
w := newWriter(kv.db)
w.Start(kv.MaxSeqNum())
defer w.Stop()
headerBuf := make([]byte, recHeaderSize)
nameBuf := make([]byte, 32)
afterSeqNumBuf := make([]byte, 8)
afterSeqNum := kv.MaxSeqNum()
expectedSeqNum := afterSeqNum + 1
// Send afterSeqNum to the conn.
conn.SetWriteDeadline(time.Now().Add(connTimeout))
binary.LittleEndian.PutUint64(afterSeqNumBuf, afterSeqNum)
if _, err := conn.Write(afterSeqNumBuf); err != nil {
log.Printf("RecvWAL failed to send after sequence number: %v", err)
return
}
conn.SetWriteDeadline(time.Time{})
for {
conn.SetReadDeadline(time.Now().Add(connTimeout))
if _, err := io.ReadFull(conn, headerBuf); err != nil {
log.Printf("RecvWAL failed to read header: %v", err)
return
}
rec, nameLen, dataLen := decodeRecHeader(headerBuf)
// Heartbeat.
if rec.SeqNum == 0 {
continue
}
if rec.SeqNum != expectedSeqNum {
log.Printf("Expected sequence number %d but got %d.",
expectedSeqNum, rec.SeqNum)
return
}
expectedSeqNum++
if cap(nameBuf) < nameLen {
nameBuf = make([]byte, nameLen)
}
nameBuf = nameBuf[:nameLen]
if _, err := io.ReadFull(conn, nameBuf); err != nil {
log.Printf("RecvWAL failed to read collection name: %v", err)
return
}
rec.Collection = string(nameBuf)
// Note that it's OK to apply changes via onStore / onDelete before the log
// entries are written because all changes are idempotent. Even if the log
// is replayed from some point in the past, the final state will be
// consistent.
if rec.Store {
rec.Data = GetDataBuf(dataLen)
if _, err := io.ReadFull(conn, rec.Data); err != nil {
log.Printf("RecvWAL failed to read data: %v", err)
return
}
kv.onStore(rec.Collection, rec.ID, rec.Data)
w.StoreAsync(rec.Collection, rec.ID, rec.Data)
} else {
kv.onDelete(rec.Collection, rec.ID)
w.DeleteAsync(rec.Collection, rec.ID)
}
}
}

116
kvstore/sync-send.go Normal file

@ -0,0 +1,116 @@
package kvstore
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"encoding/binary"
"log"
"net"
"time"
)
func (kv *KV) SyncSend(conn net.Conn) {
defer conn.Close()
var (
seqNumBuf = make([]byte, 8)
headerBuf = make([]byte, recHeaderSize)
empty = make([]byte, recHeaderSize)
err error
)
// Read afterSeqNum from the conn.
conn.SetReadDeadline(time.Now().Add(connTimeout))
if _, err := io.ReadFull(conn, seqNumBuf[:8]); err != nil {
log.Printf("SyncSend failed to read afterSeqNum: %v", err)
return
}
afterSeqNum := binary.LittleEndian.Uint64(seqNumBuf[:8])
POLL:
for i := 0; i < 4; i++ {
if kv.MaxSeqNum() > afterSeqNum {
goto REPLAY
}
time.Sleep(pollInterval)
}
goto HEARTBEAT
HEARTBEAT:
conn.SetWriteDeadline(time.Now().Add(connTimeout))
if _, err := conn.Write(empty); err != nil {
log.Printf("SendWAL failed to send heartbeat: %v", err)
return
}
goto POLL
REPLAY:
err = kv.replay(afterSeqNum, func(rec record) error {
conn.SetWriteDeadline(time.Now().Add(connTimeout))
afterSeqNum = rec.SeqNum
encodeRecordHeader(rec, headerBuf)
if _, err := conn.Write(headerBuf); err != nil {
log.Printf("SendWAL failed to send header %v", err)
return err
}
if _, err := conn.Write([]byte(rec.Collection)); err != nil {
log.Printf("SendWAL failed to send collection name %v", err)
return err
}
if !rec.Store {
return nil
}
if _, err := conn.Write(rec.Data); err != nil {
log.Printf("SendWAL failed to send data %v", err)
return err
}
return nil
})
if err != nil {
return
}
goto POLL
}
func (kv *KV) replay(afterSeqNum uint64, each func(rec record) error) error {
rec := record{}
rows, err := kv.logIterateStmt.Query(afterSeqNum)
must(err)
defer rows.Close()
rec.Data = GetDataBuf(0)
defer RecycleDataBuf(rec.Data)
for rows.Next() {
must(rows.Scan(
&rec.SeqNum,
&rec.Collection,
&rec.ID,
&rec.Store,
&rec.Data))
if err = each(rec); err != nil {
return err
}
}
return nil
}

27
kvstore/types.go Normal file

@ -0,0 +1,27 @@
package kvstore
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import "sync"
type modJob struct {
Collection string
ID uint64
Store bool
Data []byte
Ready *sync.WaitGroup
}
type record struct {
SeqNum uint64
Collection string
ID uint64
Store bool
Data []byte
}

23
kvstore/util.go Normal file

@ -0,0 +1,23 @@
package kvstore
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"path/filepath"
)
func must(err error) {
if err != nil {
panic(err)
}
}
func dbPath(dir string) string {
return filepath.Join(dir, "db")
}

29
kvstore/util_test.go Normal file

@ -0,0 +1,29 @@
package kvstore
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"crypto/rand"
"encoding/hex"
mrand "math/rand"
)
func _b(in string) []byte {
return []byte(in)
}
func randString() string {
buf := make([]byte, 1+mrand.Intn(20))
rand.Read(buf)
return hex.EncodeToString(buf)
}
func randBytes() []byte {
return _b(randString())
}

194
kvstore/writer.go Normal file

@ -0,0 +1,194 @@
package kvstore
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"database/sql"
"sync"
"time"
)
type writer struct {
db *sql.DB
modQ chan modJob
wg sync.WaitGroup
}
func newWriter(db *sql.DB) *writer {
return &writer{
db: db,
modQ: make(chan modJob, modQSize),
}
}
func (w *writer) Start(maxSeqNum uint64) {
w.wg.Add(1)
go w.run(maxSeqNum)
}
func (w *writer) Stop() {
close(w.modQ)
w.wg.Wait()
}
// Takes ownership of the incoming data.
func (w *writer) Store(collection string, id uint64, data []byte) {
job := modJob{
Collection: collection,
ID: id,
Store: true,
Data: data,
Ready: &sync.WaitGroup{},
}
job.Ready.Add(1)
w.modQ <- job
job.Ready.Wait()
}
func (w *writer) Delete(collection string, id uint64) {
job := modJob{
Collection: collection,
ID: id,
Store: false,
Ready: &sync.WaitGroup{},
}
job.Ready.Add(1)
w.modQ <- job
job.Ready.Wait()
}
// Takes ownership of the incoming data.
func (w *writer) StoreAsync(collection string, id uint64, data []byte) {
w.modQ <- modJob{
Collection: collection,
ID: id,
Store: true,
Data: data,
}
}
func (w *writer) DeleteAsync(collection string, id uint64) {
w.modQ <- modJob{
Collection: collection,
ID: id,
Store: false,
}
}
func (w *writer) run(maxSeqNum uint64) {
defer w.wg.Done()
var (
mod modJob
ok bool
tx *sql.Tx
insertData *sql.Stmt
insertKV *sql.Stmt
deleteData *sql.Stmt
deleteKV *sql.Stmt
insertLog *sql.Stmt
err error
curSeqNum = maxSeqNum
now int64
wgs = make([]*sync.WaitGroup, 0, 10)
)
BEGIN:
insertData = nil
deleteData = nil
wgs = wgs[:0]
mod, ok = <-w.modQ
if !ok {
return
}
tx, err = w.db.Begin()
must(err)
now = time.Now().Unix()
insertLog, err = tx.Prepare(sqlInsertLog)
must(err)
LOOP:
if mod.Ready != nil {
wgs = append(wgs, mod.Ready)
}
curSeqNum++
if mod.Store {
goto STORE
} else {
goto DELETE
}
STORE:
if insertData == nil {
insertData, err = tx.Prepare(sqlInsertData)
must(err)
insertKV, err = tx.Prepare(sqlInsertKV)
must(err)
}
_, err = insertData.Exec(curSeqNum, mod.Data)
must(err)
_, err = insertKV.Exec(mod.Collection, mod.ID, curSeqNum)
must(err)
_, err = insertLog.Exec(curSeqNum, now, mod.Collection, mod.ID, true)
must(err)
RecycleDataBuf(mod.Data)
goto NEXT
DELETE:
if deleteData == nil {
deleteData, err = tx.Prepare(sqlDeleteData)
must(err)
deleteKV, err = tx.Prepare(sqlDeleteKV)
must(err)
}
_, err = deleteData.Exec(mod.Collection, mod.ID)
must(err)
_, err = deleteKV.Exec(mod.Collection, mod.ID)
must(err)
_, err = insertLog.Exec(curSeqNum, now, mod.Collection, mod.ID, false)
must(err)
goto NEXT
NEXT:
select {
case mod, ok = <-w.modQ:
if ok {
goto LOOP
}
default:
}
goto COMMIT
COMMIT:
must(tx.Commit())
for i := range wgs {
wgs[i].Done()
}
goto BEGIN
}
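A short usage sketch of the writer above (db, maxSeqNum, and data are placeholders): the synchronous calls block until the batch that contains them commits, while the Async variants only enqueue the job.

w := newWriter(db)
w.Start(maxSeqNum)

// Blocks until the transaction containing this job commits; the writer
// takes ownership of data and recycles it after the insert.
w.Store("users", 1, data)

// Enqueued only; returns without waiting for a commit.
w.DeleteAsync("users", 2)

// Closes the queue and waits for the run loop to drain and exit.
w.Stop()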

32
main_test.go Normal file
View File

@ -0,0 +1,32 @@
package mdb
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"math/rand"
"os"
"testing"
"time"
)
func TestMain(m *testing.M) {
rand.Seed(time.Now().UnixNano())
os.Exit(m.Run())
}
func testWithDB(t *testing.T, name string, inner func(t *testing.T, db *DB)) {
t.Run(name, func(t *testing.T) {
root, err := os.MkdirTemp("", "")
must(err)
defer os.RemoveAll(root)
db := OpenDB(root, true)
defer db.Close()
inner(t, db)
})
}

177
mapindex.go Normal file
View File

@ -0,0 +1,177 @@
package mdb
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"sync"
"git.crumpington.com/public/mdb/keyedmutex"
)
type MapIndex[K comparable, T any] struct {
c *Collection[T]
_name string
keyLock keyedmutex.KeyedMutex[K]
mapLock sync.Mutex
m map[K]*T
getID func(*T) uint64
getKey func(*T) K
include func(*T) bool
}
func NewMapIndex[K comparable, T any](
c *Collection[T],
name string,
getKey func(*T) K,
include func(*T) bool,
) *MapIndex[K, T] {
m := &MapIndex[K, T]{
c: c,
_name: name,
keyLock: keyedmutex.New[K](),
m: map[K]*T{},
getID: c.getID,
getKey: getKey,
include: include,
}
c.indices = append(c.indices, m)
c.uniqueIndices = append(c.uniqueIndices, m)
return m
}
// ----------------------------------------------------------------------------
func (m *MapIndex[K, T]) load(src map[uint64]*T) error {
for _, item := range src {
if m.include != nil && !m.include(item) {
continue
}
k := m.getKey(item)
if _, ok := m.m[k]; ok {
return ErrDuplicate
}
m.m[k] = item
}
return nil
}
func (m *MapIndex[K, T]) Get(k K) (t T, ok bool) {
if item, ok := m.mapGet(k); ok {
return *item, true
}
return t, false
}
func (m *MapIndex[K, T]) Update(k K, update func(T) (T, error)) error {
wrapped := func(item T) (T, error) {
new, err := update(item)
if err != nil {
return item, err
}
return new, nil
}
if item, ok := m.mapGet(k); ok {
return m.c.Update(m.getID(item), wrapped)
}
return ErrNotFound
}
func (m *MapIndex[K, T]) Delete(k K) {
if item, ok := m.mapGet(k); ok {
m.c.Delete(m.getID(item))
}
}
// ----------------------------------------------------------------------------
func (m *MapIndex[K, T]) insert(item *T) {
if m.include == nil || m.include(item) {
m.mapSet(m.getKey(item), item)
}
}
func (m *MapIndex[K, T]) update(old, new *T) {
if m.include == nil || m.include(old) {
m.mapDelete(m.getKey(old))
}
if m.include == nil || m.include(new) {
m.mapSet(m.getKey(new), new)
}
}
func (m *MapIndex[K, T]) delete(item *T) {
if m.include == nil || m.include(item) {
m.mapDelete(m.getKey(item))
}
}
// ----------------------------------------------------------------------------
func (idx *MapIndex[K, T]) name() string {
return idx._name
}
func (idx *MapIndex[K, T]) lock(item *T) {
if idx.include == nil || idx.include(item) {
idx.keyLock.Lock(idx.getKey(item))
}
}
func (idx *MapIndex[K, T]) unlock(item *T) {
if idx.include == nil || idx.include(item) {
idx.keyLock.Unlock(idx.getKey(item))
}
}
// Should hold item lock when calling.
func (idx *MapIndex[K, T]) insertConflict(new *T) bool {
if idx.include != nil && !idx.include(new) {
return false
}
_, ok := idx.Get(idx.getKey(new))
return ok
}
// Should hold item lock when calling.
func (idx *MapIndex[K, T]) updateConflict(new *T) bool {
if idx.include != nil && !idx.include(new) {
return false
}
cur, ok := idx.mapGet(idx.getKey(new))
return ok && idx.getID(cur) != idx.getID(new)
}
// ----------------------------------------------------------------------------
func (idx *MapIndex[K, T]) mapSet(k K, t *T) {
idx.mapLock.Lock()
idx.m[k] = t
idx.mapLock.Unlock()
}
func (idx *MapIndex[K, T]) mapDelete(k K) {
idx.mapLock.Lock()
delete(idx.m, k)
idx.mapLock.Unlock()
}
func (idx *MapIndex[K, T]) mapGet(k K) (*T, bool) {
idx.mapLock.Lock()
t, ok := idx.m[k]
idx.mapLock.Unlock()
return t, ok
}
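As a usage sketch, mirroring the test fixtures later in this commit, a partial unique index over a User collection can be declared and queried roughly as follows (users.c and the "ext-1" key are placeholders):

extIDMap := NewMapIndex(
	users.c,
	"extID",
	func(u *User) string { return u.ExtID },
	func(u *User) bool { return u.ExtID != "" }) // index only users with a non-empty ExtID

if u, ok := extIDMap.Get("ext-1"); ok {
	_ = u // a copy of the indexed item
}

if err := extIDMap.Update("ext-1", func(u User) (User, error) {
	u.Name = "renamed"
	return u, nil
}); err != nil {
	// handle ErrNotFound, ErrDuplicate, or the error returned by the callback
}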

35
mapindex_ex_test.go Normal file
View File

@ -0,0 +1,35 @@
package mdb
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"fmt"
"reflect"
)
func (m *MapIndex[K, T]) Equals(rhs *MapIndex[K, T]) error {
return m.EqualsMap(rhs.m)
}
func (m *MapIndex[K, T]) EqualsMap(data map[K]*T) error {
if len(m.m) != len(data) {
return fmt.Errorf("Expected %d items, but found %d.", len(data), len(m.m))
}
for key, exp := range data {
val, ok := m.Get(key)
if !ok {
return fmt.Errorf("No value for %v. Expected: %v", key, *exp)
}
if !reflect.DeepEqual(val, *exp) {
return fmt.Errorf("Value mismatch %v: %v != %v", key, val, *exp)
}
}
return nil
}

692
mapindex_test.go Normal file
View File

@ -0,0 +1,692 @@
package mdb
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"errors"
"fmt"
"reflect"
"testing"
)
func TestFullMapIndex(t *testing.T) {
// Test against the emailMap index.
run := func(name string, inner func(t *testing.T, db *DB) map[string]*User) {
testWithDB(t, name, func(t *testing.T, db *DB) {
expected := inner(t, db)
if err := db.Users.emailMap.EqualsMap(expected); err != nil {
t.Fatal(err)
}
db.Close()
db = OpenDB(db.root, true)
defer db.Close()
if err := db.Users.emailMap.EqualsMap(expected); err != nil {
t.Fatal(err)
}
})
}
run("insert", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{}
for i := uint64(1); i < 10; i++ {
user := &User{
ID: db.Users.c.NextID(),
Email: fmt.Sprintf("a.%d@c.com", i),
Name: fmt.Sprintf("name.%d", i),
ExtID: fmt.Sprintf("EXTID.%d", i),
}
user2, err := db.Users.c.Insert(*user)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(*user, user2) {
t.Fatal(*user, user2)
}
users[user.Email] = user
}
return users
})
run("delete", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{}
for i := uint64(1); i < 10; i++ {
user := &User{
ID: db.Users.c.NextID(),
Email: fmt.Sprintf("a.%d@c.com", i),
Name: fmt.Sprintf("name.%d", i),
ExtID: fmt.Sprintf("EXTID.%d", i),
}
user2, err := db.Users.c.Insert(*user)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(*user, user2) {
t.Fatal(*user, user2)
}
users[user.Email] = user
}
var email string
for key := range users {
email = key
break
}
delete(users, email)
db.Users.emailMap.Delete(email)
return users
})
run("update non-indexed field", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{}
for i := uint64(1); i < 10; i++ {
user := &User{
ID: db.Users.c.NextID(),
Email: fmt.Sprintf("a.%d@c.com", i),
Name: fmt.Sprintf("name.%d", i),
ExtID: fmt.Sprintf("EXTID.%d", i),
}
user2, err := db.Users.c.Insert(*user)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(*user, user2) {
t.Fatal(*user, user2)
}
users[user.Email] = user
}
var email string
for key := range users {
email = key
break
}
err := db.Users.emailMap.Update(email, func(u User) (User, error) {
u.Name = "UPDATED"
return u, nil
})
if err != nil {
t.Fatal(err)
}
users[email].Name = "UPDATED"
return users
})
run("update indexed field", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{}
for i := uint64(1); i < 10; i++ {
user := &User{
ID: db.Users.c.NextID(),
Email: fmt.Sprintf("a.%d@c.com", i),
Name: fmt.Sprintf("name.%d", i),
ExtID: fmt.Sprintf("EXTID.%d", i),
}
user2, err := db.Users.c.Insert(*user)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(*user, user2) {
t.Fatal(*user, user2)
}
users[user.Email] = user
}
var id uint64
var email string
for key := range users {
email = key
id = users[key].ID
break
}
err := db.Users.c.Update(id, func(u User) (User, error) {
u.Email = "test@x.com"
return u, nil
})
if err != nil {
t.Fatal(err)
}
user := users[email]
user.Email = "test@x.com"
delete(users, email)
users[user.Email] = user
return users
})
run("update change id error", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{}
for i := uint64(1); i < 10; i++ {
user := &User{
ID: db.Users.c.NextID(),
Email: fmt.Sprintf("a.%d@c.com", i),
Name: fmt.Sprintf("name.%d", i),
ExtID: fmt.Sprintf("EXTID.%d", i),
}
if _, err := db.Users.c.Insert(*user); err != nil {
t.Fatal(err)
}
users[user.Email] = user
}
var email string
for key := range users {
email = key
break
}
err := db.Users.emailMap.Update(email, func(u User) (User, error) {
u.ID++
return u, nil
})
if err != ErrMismatchedIDs {
t.Fatal(err)
}
return users
})
run("update function error", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{}
for i := uint64(1); i < 10; i++ {
user := &User{
ID: db.Users.c.NextID(),
Email: fmt.Sprintf("a.%d@c.com", i),
Name: fmt.Sprintf("name.%d", i),
ExtID: fmt.Sprintf("EXTID.%d", i),
}
if _, err := db.Users.c.Insert(*user); err != nil {
t.Fatal(err)
}
users[user.Email] = user
}
var email string
for key := range users {
email = key
break
}
myErr := errors.New("hello")
err := db.Users.emailMap.Update(email, func(u User) (User, error) {
return u, myErr
})
if err != myErr {
t.Fatal(err)
}
return users
})
run("update ErrAbortUpdate", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{}
for i := uint64(1); i < 10; i++ {
user := &User{
ID: db.Users.c.NextID(),
Email: fmt.Sprintf("a.%d@c.com", i),
Name: fmt.Sprintf("name.%d", i),
ExtID: fmt.Sprintf("EXTID.%d", i),
}
if _, err := db.Users.c.Insert(*user); err != nil {
t.Fatal(err)
}
users[user.Email] = user
}
var email string
for key := range users {
email = key
break
}
err := db.Users.emailMap.Update(email, func(u User) (User, error) {
return u, ErrAbortUpdate
})
if err != nil {
t.Fatal(err)
}
return users
})
run("update ErrNotFound", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{}
for i := uint64(1); i < 10; i++ {
user := &User{
ID: db.Users.c.NextID(),
Email: fmt.Sprintf("a.%d@c.com", i),
Name: fmt.Sprintf("name.%d", i),
ExtID: fmt.Sprintf("EXTID.%d", i),
}
if _, err := db.Users.c.Insert(*user); err != nil {
t.Fatal(err)
}
users[user.Email] = user
}
var email string
for key := range users {
email = key
break
}
err := db.Users.emailMap.Update(email+"x", func(u User) (User, error) {
return u, nil
})
if err != ErrNotFound {
t.Fatal(err)
}
return users
})
run("insert conflict", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{}
user := &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "a", ExtID: ""}
if _, err := db.Users.c.Insert(*user); err != nil {
t.Fatal(err)
}
users[user.Email] = user
user2 := User{
ID: db.Users.c.NextID(),
Email: "a@b.com",
Name: "someone else",
ExtID: "123",
}
_, err := db.Users.c.Insert(user2)
if !errors.Is(err, ErrDuplicate) {
t.Fatal(err)
}
return users
})
run("update conflict", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{}
user1 := &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "a"}
user2 := &User{ID: db.Users.c.NextID(), Email: "x@y.com", Name: "x"}
users[user1.Email] = user1
users[user2.Email] = user2
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
err := db.Users.c.Update(user2.ID, func(u User) (User, error) {
u.Email = "a@b.com"
return u, nil
})
if !errors.Is(err, ErrDuplicate) {
t.Fatal(err)
}
err = db.Users.emailMap.Update("x@y.com", func(u User) (User, error) {
u.Email = "a@b.com"
return u, nil
})
if !errors.Is(err, ErrDuplicate) {
t.Fatal(err)
}
return users
})
}
func TestPartialMapIndex(t *testing.T) {
// Test against the extID map index.
run := func(name string, inner func(t *testing.T, db *DB) map[string]*User) {
testWithDB(t, name, func(t *testing.T, db *DB) {
expected := inner(t, db)
if err := db.Users.extIDMap.EqualsMap(expected); err != nil {
t.Fatal(err)
}
db.Close()
db = OpenDB(db.root, true)
defer db.Close()
if err := db.Users.extIDMap.EqualsMap(expected); err != nil {
t.Fatal(err)
}
})
}
run("insert", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "x", ExtID: "x"},
}
user1 := &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "a"}
if _, err := db.Users.c.Insert(*users["x"]); err != nil {
t.Fatal(err)
}
if _, err := db.Users.c.Insert(*user1); err != nil {
t.Fatal(err)
}
return users
})
run("insert with conflict", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "x", ExtID: "x"},
}
user1 := &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "y", ExtID: "x"}
if _, err := db.Users.c.Insert(*users["x"]); err != nil {
t.Fatal(err)
}
if _, err := db.Users.c.Insert(*user1); !errors.Is(err, ErrDuplicate) {
t.Fatal(err)
}
return users
})
run("insert and delete in index", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
"z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc", ExtID: "z"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
// Delete from index and from collection.
db.Users.extIDMap.Delete("x")
db.Users.c.Delete(users["z"].ID)
delete(users, "x")
delete(users, "z")
return users
})
run("insert and delete outside index", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
"z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
// Delete from index and from collection.
db.Users.extIDMap.Delete("x")
db.Users.c.Delete(users["z"].ID)
delete(users, "x")
delete(users, "z")
return users
})
run("update outside index", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
"z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
err := db.Users.c.Update(users["z"].ID, func(u User) (User, error) {
u.Name = "Whatever"
return u, nil
})
if err != nil {
t.Fatal(err)
}
delete(users, "z") // No ExtID => not in index.
return users
})
run("update into index", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
"z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
err := db.Users.c.Update(users["z"].ID, func(u User) (User, error) {
u.ExtID = "z"
return u, nil
})
if err != nil {
t.Fatal(err)
}
users["z"].ExtID = "z"
return users
})
run("update out of index", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
"z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
err := db.Users.extIDMap.Update("y", func(u User) (User, error) {
u.ExtID = ""
return u, nil
})
if err != nil {
t.Fatal(err)
}
err = db.Users.c.Update(users["x"].ID, func(u User) (User, error) {
u.ExtID = ""
return u, nil
})
if err != nil {
t.Fatal(err)
}
delete(users, "x")
delete(users, "z")
delete(users, "y")
return users
})
run("update function error", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
myErr := errors.New("hello")
err := db.Users.extIDMap.Update("y", func(u User) (User, error) {
u.Email = "blah"
return u, myErr
})
if err != myErr {
t.Fatal(err)
}
return users
})
run("update ErrAbortUpdate", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
err := db.Users.extIDMap.Update("y", func(u User) (User, error) {
u.Email = "blah"
return u, ErrAbortUpdate
})
if err != nil {
t.Fatal(err)
}
return users
})
run("update ErrNotFound", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
err := db.Users.extIDMap.Update("z", func(u User) (User, error) {
u.Name = "blah"
return u, nil
})
if !errors.Is(err, ErrNotFound) {
t.Fatal(err)
}
return users
})
run("update conflict in to in", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
"z": {ID: db.Users.c.NextID(), Email: "z@z.com", Name: "zz", ExtID: "z"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
err := db.Users.extIDMap.Update("z", func(u User) (User, error) {
u.ExtID = "x"
return u, nil
})
if !errors.Is(err, ErrDuplicate) {
t.Fatal(err)
}
return users
})
run("update conflict out to in", func(t *testing.T, db *DB) map[string]*User {
users := map[string]*User{
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
"z": {ID: db.Users.c.NextID(), Email: "z@z.com", Name: "zz"},
}
for _, u := range users {
if _, err := db.Users.c.Insert(*u); err != nil {
t.Fatal(err)
}
}
err := db.Users.c.Update(users["z"].ID, func(u User) (User, error) {
u.ExtID = "x"
return u, nil
})
if !errors.Is(err, ErrDuplicate) {
t.Fatal(err)
}
delete(users, "z")
return users
})
}
func TestMapIndex_load_ErrDuplicate(t *testing.T) {
testWithDB(t, "", func(t *testing.T, db *DB) {
idx := NewMapIndex(
db.Users.c,
"broken",
func(u *User) string { return u.Name },
nil)
users := map[uint64]*User{
1: {ID: 1, Email: "x@y.com", Name: "aa", ExtID: "x"},
2: {ID: 2, Email: "b@c.com", Name: "aa", ExtID: "y"},
}
if err := idx.load(users); err != ErrDuplicate {
t.Fatal(err)
}
})
}

186
shipping_test.go Normal file
View File

@ -0,0 +1,186 @@
package mdb
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"math/rand"
"os"
"sync"
"testing"
"time"
"git.crumpington.com/public/mdb/testconn"
)
func TestShipping(t *testing.T) {
run := func(name string, inner func(t *testing.T, db1 *DB, db2 *DB, network *testconn.Network)) {
t.Run(name, func(t *testing.T) {
root1, err := os.MkdirTemp("", "")
must(err)
defer os.RemoveAll(root1)
root2, err := os.MkdirTemp("", "")
must(err)
defer os.RemoveAll(root2)
db1 := OpenDB(root1, true)
defer db1.Close()
db2 := OpenDB(root2, false)
defer db2.Close()
inner(t, db1, db2, testconn.NewNetwork())
})
}
run("simple", func(t *testing.T, db, db2 *DB, network *testconn.Network) {
wg := sync.WaitGroup{}
wg.Add(2)
// Send in background.
go func() {
defer wg.Done()
conn := network.Accept()
db.SyncSend(conn)
}()
// Recv in background.
go func() {
defer wg.Done()
conn := network.Dial()
db2.SyncRecv(conn)
}()
for i := 0; i < 100; i++ {
db.RandAction()
}
db.WaitForSync(db2)
network.CloseClient()
wg.Wait()
if err := db.Equals(db2); err != nil {
t.Fatal(err)
}
})
run("simple multiple writers", func(t *testing.T, db, db2 *DB, network *testconn.Network) {
wg := sync.WaitGroup{}
// Send in background.
wg.Add(1)
go func() {
defer wg.Done()
conn := network.Accept()
db.SyncSend(conn)
}()
// Recv in background.
wg.Add(1)
go func() {
defer wg.Done()
conn := network.Dial()
db2.SyncRecv(conn)
}()
updateWG := sync.WaitGroup{}
updateWG.Add(64)
for i := 0; i < 64; i++ {
go func() {
defer updateWG.Done()
for j := 0; j < 1024; j++ {
db.RandAction()
}
}()
}
updateWG.Wait()
db.WaitForSync(db2)
network.CloseClient()
wg.Wait()
if err := db.Equals(db2); err != nil {
t.Fatal(err)
}
})
run("unstable network", func(t *testing.T, db, db2 *DB, network *testconn.Network) {
sleepTimeout := time.Millisecond
updateWG := sync.WaitGroup{}
updateWG.Add(64)
for i := 0; i < 64; i++ {
go func() {
defer updateWG.Done()
for j := 0; j < 4096; j++ {
time.Sleep(sleepTimeout)
db.RandAction()
}
}()
}
updating := &atomicBool{}
updating.Set(true)
go func() {
updateWG.Wait()
updating.Set(false)
}()
// Recv in background.
recving := &atomicBool{}
recving.Set(true)
go func() {
for {
// Stop when no longer updating and WAL files match.
if !updating.Get() {
if db.MaxSeqNum() == db2.MaxSeqNum() {
recving.Set(false)
return
}
}
if conn := network.Dial(); conn != nil {
db2.SyncRecv(conn)
}
}
}()
// Send in background.
sending := &atomicBool{}
sending.Set(true)
go func() {
for {
// Stop when no longer updating and WAL files match.
if !updating.Get() {
if db.MaxSeqNum() == db2.MaxSeqNum() {
sending.Set(false)
return
}
}
if conn := network.Accept(); conn != nil {
db.SyncSend(conn)
}
}
}()
// Interrupt network periodically as long as sending or receiving.
for sending.Get() || recving.Get() {
time.Sleep(time.Duration(rand.Intn(10 * int(sleepTimeout))))
if rand.Float64() < 0.5 {
network.CloseClient()
} else {
network.CloseServer()
}
}
if err := db.Equals(db2); err != nil {
t.Fatal(err)
}
})
}

87
testconn/net.go Normal file
View File

@ -0,0 +1,87 @@
package testconn
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"net"
"sync"
"time"
)
type Network struct {
lock sync.Mutex
// Current connections.
cConn net.Conn
sConn net.Conn
acceptQ chan net.Conn
}
func NewNetwork() *Network {
return &Network{
acceptQ: make(chan net.Conn, 1),
}
}
func (n *Network) Dial() net.Conn {
cc, sc := net.Pipe()
func() {
n.lock.Lock()
defer n.lock.Unlock()
if n.cConn != nil {
n.cConn.Close()
n.cConn = nil
}
select {
case n.acceptQ <- sc:
n.cConn = cc
default:
cc = nil
}
}()
return cc
}
func (n *Network) Accept() net.Conn {
var sc net.Conn
select {
case sc = <-n.acceptQ:
case <-time.After(time.Second):
return nil
}
func() {
n.lock.Lock()
defer n.lock.Unlock()
if n.sConn != nil {
n.sConn.Close()
n.sConn = nil
}
n.sConn = sc
}()
return sc
}
func (n *Network) CloseClient() {
n.lock.Lock()
defer n.lock.Unlock()
if n.cConn != nil {
n.cConn.Close()
n.cConn = nil
}
}
func (n *Network) CloseServer() {
n.lock.Lock()
defer n.lock.Unlock()
if n.sConn != nil {
n.sConn.Close()
n.sConn = nil
}
}
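The tests in this commit drive this network roughly as follows; Accept blocks for up to one second and returns nil on timeout, while Dial returns nil when the previously dialed pipe has not yet been accepted (primary and secondary are placeholder databases):

network := testconn.NewNetwork()

// Server end: hand the accepted pipe to the primary's WAL sender.
go func() {
	if conn := network.Accept(); conn != nil {
		primary.SyncSend(conn)
	}
}()

// Client end: the secondary follows the primary over the dialed pipe.
if conn := network.Dial(); conn != nil {
	secondary.SyncRecv(conn)
}

network.CloseClient() // simulate the client side of the link dropping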

272
testdb_test.go Normal file
View File

@ -0,0 +1,272 @@
package mdb
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"errors"
"fmt"
"math/rand"
"net/mail"
"strings"
"time"
)
// ----------------------------------------------------------------------------
// Validate errors.
// ----------------------------------------------------------------------------
var ErrInvalidName = errors.New("invalid name")
// ----------------------------------------------------------------------------
// User Collection
// ----------------------------------------------------------------------------
type User struct {
ID uint64
Email string
Name string
ExtID string
}
type Users struct {
c *Collection[User]
emailMap *MapIndex[string, User] // Full map index.
emailBTree *BTreeIndex[User] // Full btree index.
nameBTree *BTreeIndex[User] // Full btree with duplicates.
extIDMap *MapIndex[string, User] // Partial map index.
extIDBTree *BTreeIndex[User] // Partial btree index.
}
func userGetID(u *User) uint64 { return u.ID }
func userSanitize(u *User) {
u.Name = strings.TrimSpace(u.Name)
e, err := mail.ParseAddress(strings.ToLower(strings.TrimSpace(u.Email)))
if err == nil {
u.Email = e.Address
}
}
func userValidate(u *User) error {
if len(u.Name) == 0 {
return ErrInvalidName
}
return nil
}
// ----------------------------------------------------------------------------
// Account Collection
// ----------------------------------------------------------------------------
type Account struct {
ID uint64
Name string
}
type Accounts struct {
c *Collection[Account]
nameMap *MapIndex[string, Account]
}
func accountGetID(a *Account) uint64 { return a.ID }
// ----------------------------------------------------------------------------
// Database
// ----------------------------------------------------------------------------
type DB struct {
*Database
root string
Users Users
Accounts Accounts
}
func OpenDB(root string, primary bool) *DB {
db := &DB{root: root}
if primary {
db.Database = NewPrimary(root)
} else {
db.Database = NewSecondary(root)
}
db.Users = Users{}
db.Users.c = NewCollection(db.Database, "users", userGetID)
db.Users.c.SetSanitize(userSanitize)
db.Users.c.SetValidate(userValidate)
db.Users.emailMap = NewMapIndex(
db.Users.c,
"email",
func(u *User) string { return u.Email },
nil)
db.Users.emailBTree = NewBTreeIndex(
db.Users.c,
func(lhs, rhs *User) bool { return lhs.Email < rhs.Email },
nil)
db.Users.nameBTree = NewBTreeIndex(
db.Users.c,
func(lhs, rhs *User) bool {
if lhs.Name != rhs.Name {
return lhs.Name < rhs.Name
}
return lhs.ID < rhs.ID
},
nil)
db.Users.extIDMap = NewMapIndex(
db.Users.c,
"extID",
func(u *User) string { return u.ExtID },
func(u *User) bool { return u.ExtID != "" })
db.Users.extIDBTree = NewBTreeIndex(
db.Users.c,
func(lhs, rhs *User) bool { return lhs.ExtID < rhs.ExtID },
func(u *User) bool { return u.ExtID != "" })
db.Accounts = Accounts{}
db.Accounts.c = NewCollection(db.Database, "accounts", accountGetID)
db.Accounts.nameMap = NewMapIndex(
db.Accounts.c,
"name",
func(a *Account) string { return a.Name },
nil)
db.Start()
return db
}
func (db *DB) Equals(rhs *DB) error {
db.WaitForSync(rhs)
// Users: itemMap.
if err := db.Users.c.items.Equals(rhs.Users.c.items); err != nil {
return fmt.Errorf("%w: Users.c.items not equal", err)
}
// Users: emailMap
if err := db.Users.emailMap.Equals(rhs.Users.emailMap); err != nil {
return fmt.Errorf("%w: Users.emailMap not equal", err)
}
// Users: emailBTree
if err := db.Users.emailBTree.Equals(rhs.Users.emailBTree); err != nil {
return fmt.Errorf("%w: Users.emailBTree not equal", err)
}
// Users: nameBTree
if err := db.Users.nameBTree.Equals(rhs.Users.nameBTree); err != nil {
return fmt.Errorf("%w: Users.nameBTree not equal", err)
}
// Users: extIDMap
if err := db.Users.extIDMap.Equals(rhs.Users.extIDMap); err != nil {
return fmt.Errorf("%w: Users.extIDMap not equal", err)
}
// Users: extIDBTree
if err := db.Users.extIDBTree.Equals(rhs.Users.extIDBTree); err != nil {
return fmt.Errorf("%w: Users.extIDBTree not equal", err)
}
// Accounts: itemMap
if err := db.Accounts.c.items.Equals(rhs.Accounts.c.items); err != nil {
return fmt.Errorf("%w: Accounts.c.items not equal", err)
}
// Accounts: nameMap
if err := db.Accounts.nameMap.Equals(rhs.Accounts.nameMap); err != nil {
return fmt.Errorf("%w: Accounts.nameMap not equal", err)
}
return nil
}
// Wait for two databases to become synchronized.
func (db *DB) WaitForSync(rhs *DB) {
for {
if db.MaxSeqNum() == rhs.MaxSeqNum() {
return
}
time.Sleep(100 * time.Millisecond)
}
}
var (
randIDs = []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
)
func (db *DB) RandAction() {
if rand.Float32() < 0.3 {
db.randActionAccount()
} else {
db.randActionUser()
}
}
func (db *DB) randActionAccount() {
id := randIDs[rand.Intn(len(randIDs))]
f := rand.Float32()
_, exists := db.Accounts.c.Get(id)
if !exists {
db.Accounts.c.Insert(Account{
ID: id,
Name: randString(),
})
return
}
if f < 0.05 {
db.Accounts.c.Delete(id)
return
}
db.Accounts.c.Update(id, func(a Account) (Account, error) {
a.Name = randString()
return a, nil
})
}
func (db *DB) randActionUser() {
id := randIDs[rand.Intn(len(randIDs))]
f := rand.Float32()
_, exists := db.Users.c.Get(id)
if !exists {
user := User{
ID: id,
Email: randStringShort() + "@domain.com",
Name: randString(),
}
if f < 0.1 {
user.ExtID = randString()
}
db.Users.c.Insert(user)
return
}
if f < 0.05 {
db.Users.c.Delete(id)
return
}
db.Users.c.Update(id, func(a User) (User, error) {
a.Name = randString()
if f < 0.1 {
a.ExtID = randString()
} else {
a.ExtID = ""
}
a.Email = randStringShort() + "@domain.com"
return a, nil
})
}

36
types.go Normal file
View File

@ -0,0 +1,36 @@
package mdb
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
type WALStatus struct {
MaxSeqNumKV uint64
MaxSeqNumWAL uint64
}
type itemIndex[T any] interface {
load(m map[uint64]*T) error
insert(*T)
update(old, new *T) // Old and new MUST have the same ID.
delete(*T)
}
type itemUniqueIndex[T any] interface {
load(m map[uint64]*T) error
name() string
lock(*T)
unlock(*T)
insertConflict(*T) bool
updateConflict(*T) bool
}
type dbCollection interface {
loadData()
onStore(uint64, []byte) // For WAL following.
onDelete(uint64) // For WAL following.
}
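For orientation: MapIndex registers itself with both c.indices and c.uniqueIndices, so it must satisfy both interfaces above, and BTreeIndex is assumed to satisfy itemIndex as well. A compile-time sketch of those relationships (it would live in a test file, since User is a test fixture):

var _ itemIndex[User] = (*MapIndex[string, User])(nil)
var _ itemUniqueIndex[User] = (*MapIndex[string, User])(nil)
var _ itemIndex[User] = (*BTreeIndex[User])(nil) // assumed; BTreeIndex registers only as a plain index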

14
util.go Normal file
View File

@ -0,0 +1,14 @@
package mdb
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
func must(err error) {
if err != nil {
panic(err)
}
}

47
util_test.go Normal file
View File

@ -0,0 +1,47 @@
package mdb
/*
Copyright (c) 2022, John David Lee
All rights reserved.
This source code is licensed under the BSD-style license found in the
LICENSE file in the root directory of this source tree.
*/
import (
"crypto/rand"
"encoding/hex"
mrand "math/rand"
"sync/atomic"
)
func randStringShort() string {
s := randString()
if len(s) > 2 {
return s[:2]
}
return s
}
func randString() string {
buf := make([]byte, 1+mrand.Intn(10))
if _, err := rand.Read(buf); err != nil {
panic(err)
}
return hex.EncodeToString(buf)
}
type atomicBool struct {
i int64
}
func (a *atomicBool) Get() bool {
return atomic.LoadInt64(&a.i) == 1
}
func (a *atomicBool) Set(b bool) {
if b {
atomic.StoreInt64(&a.i, 1)
} else {
atomic.StoreInt64(&a.i, 0)
}
}