parent
6cf0ea3c5d
commit
5641a89fee
|
@ -0,0 +1,26 @@
|
|||
Copyright 2022 John David Lee
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its contributors
|
||||
may be used to endorse or promote products derived from this software without
|
||||
specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
|
@ -0,0 +1,165 @@
|
|||
package mdb
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/google/btree"
|
||||
)
|
||||
|
||||
type BTreeIndex[T any] struct {
|
||||
c *Collection[T]
|
||||
modLock sync.Mutex
|
||||
bt atomic.Value // *btree.BTreeG[*T]
|
||||
getID func(*T) uint64
|
||||
less func(*T, *T) bool
|
||||
include func(*T) bool
|
||||
}
|
||||
|
||||
func NewBTreeIndex[T any](
|
||||
c *Collection[T],
|
||||
less func(*T, *T) bool,
|
||||
include func(*T) bool,
|
||||
) *BTreeIndex[T] {
|
||||
|
||||
t := &BTreeIndex[T]{
|
||||
c: c,
|
||||
getID: c.getID,
|
||||
less: less,
|
||||
include: include,
|
||||
}
|
||||
|
||||
btree := btree.NewG(32, less)
|
||||
t.bt.Store(btree)
|
||||
c.indices = append(c.indices, t)
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *BTreeIndex[T]) load(m map[uint64]*T) error {
|
||||
btree := btree.NewG(64, t.less)
|
||||
t.bt.Store(btree)
|
||||
for _, item := range m {
|
||||
if t.include == nil || t.include(item) {
|
||||
if x, _ := btree.ReplaceOrInsert(item); x != nil {
|
||||
return ErrDuplicate
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (t *BTreeIndex[T]) Ascend() *BTreeIterator[T] {
|
||||
iter := newBTreeIterator[T]()
|
||||
go func() {
|
||||
t.btree().Ascend(iter.each)
|
||||
iter.done()
|
||||
}()
|
||||
return iter
|
||||
}
|
||||
|
||||
func (t *BTreeIndex[T]) AscendAfter(pivot T) *BTreeIterator[T] {
|
||||
iter := newBTreeIterator[T]()
|
||||
go func() {
|
||||
t.btree().AscendGreaterOrEqual(&pivot, iter.each)
|
||||
iter.done()
|
||||
}()
|
||||
return iter
|
||||
}
|
||||
|
||||
func (t *BTreeIndex[T]) Descend() *BTreeIterator[T] {
|
||||
iter := newBTreeIterator[T]()
|
||||
go func() {
|
||||
t.btree().Descend(iter.each)
|
||||
iter.done()
|
||||
}()
|
||||
return iter
|
||||
}
|
||||
|
||||
func (t *BTreeIndex[T]) DescendAfter(pivot T) *BTreeIterator[T] {
|
||||
iter := newBTreeIterator[T]()
|
||||
go func() {
|
||||
t.btree().DescendLessOrEqual(&pivot, iter.each)
|
||||
iter.done()
|
||||
}()
|
||||
return iter
|
||||
}
|
||||
|
||||
func (t *BTreeIndex[T]) Get(item T) (T, bool) {
|
||||
ptr, ok := t.btree().Get(&item)
|
||||
if !ok {
|
||||
return item, false
|
||||
}
|
||||
return *ptr, true
|
||||
}
|
||||
|
||||
func (t *BTreeIndex[T]) Min() (item T, ok bool) {
|
||||
if ptr, ok := t.btree().Min(); ok {
|
||||
return *ptr, ok
|
||||
}
|
||||
return item, false
|
||||
}
|
||||
|
||||
func (t *BTreeIndex[T]) Max() (item T, ok bool) {
|
||||
if ptr, ok := t.btree().Max(); ok {
|
||||
return *ptr, ok
|
||||
}
|
||||
return item, false
|
||||
}
|
||||
|
||||
func (t *BTreeIndex[T]) Len() int {
|
||||
return t.btree().Len()
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (t *BTreeIndex[T]) insert(item *T) {
|
||||
if t.include == nil || t.include(item) {
|
||||
t.modify(func(bt *btree.BTreeG[*T]) {
|
||||
bt.ReplaceOrInsert(item)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (t *BTreeIndex[T]) update(old, new *T) {
|
||||
t.modify(func(bt *btree.BTreeG[*T]) {
|
||||
if t.include == nil || t.include(old) {
|
||||
bt.Delete(old)
|
||||
}
|
||||
if t.include == nil || t.include(new) {
|
||||
bt.ReplaceOrInsert(new)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func (t *BTreeIndex[T]) delete(item *T) {
|
||||
if t.include == nil || t.include(item) {
|
||||
t.modify(func(bt *btree.BTreeG[*T]) {
|
||||
bt.Delete(item)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (t *BTreeIndex[T]) btree() *btree.BTreeG[*T] {
|
||||
return t.bt.Load().(*btree.BTreeG[*T])
|
||||
}
|
||||
|
||||
func (t *BTreeIndex[T]) modify(mod func(clone *btree.BTreeG[*T])) {
|
||||
t.modLock.Lock()
|
||||
defer t.modLock.Unlock()
|
||||
clone := t.btree().Clone()
|
||||
mod(clone)
|
||||
t.bt.Store(clone)
|
||||
}
|
|
@ -0,0 +1,144 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func (bt *BTreeIndex[T]) Equals(rhs *BTreeIndex[T]) error {
|
||||
if bt.Len() != rhs.Len() {
|
||||
return fmt.Errorf("Expected %d items, but found %d.", bt.Len(), rhs.Len())
|
||||
}
|
||||
|
||||
it1 := bt.Ascend()
|
||||
defer it1.Close()
|
||||
|
||||
it2 := rhs.Ascend()
|
||||
defer it2.Close()
|
||||
|
||||
for it1.Next() {
|
||||
it2.Next()
|
||||
|
||||
v1 := it1.Value()
|
||||
v2 := it2.Value()
|
||||
|
||||
if !reflect.DeepEqual(v1, v2) {
|
||||
return fmt.Errorf("Value mismatch: %v != %v", v1, v2)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (bt *BTreeIndex[T]) EqualsList(data []*T) error {
|
||||
if bt.Len() != len(data) {
|
||||
return fmt.Errorf("Expected %d items, but found %d.", bt.Len(), len(data))
|
||||
}
|
||||
|
||||
if len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ascend fully.
|
||||
it := bt.Ascend()
|
||||
for _, v1 := range data {
|
||||
it.Next()
|
||||
v2 := it.Value()
|
||||
|
||||
if !reflect.DeepEqual(*v1, v2) {
|
||||
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
|
||||
}
|
||||
}
|
||||
if it.Next() {
|
||||
return fmt.Errorf("Next returned true after full ascend.")
|
||||
}
|
||||
it.Close()
|
||||
|
||||
// Descend fully.
|
||||
it = bt.Descend()
|
||||
dataList := data
|
||||
for len(dataList) > 0 {
|
||||
v1 := dataList[len(dataList)-1]
|
||||
dataList = dataList[:len(dataList)-1]
|
||||
it.Next()
|
||||
v2 := it.Value()
|
||||
|
||||
if !reflect.DeepEqual(*v1, v2) {
|
||||
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
|
||||
}
|
||||
}
|
||||
if it.Next() {
|
||||
return fmt.Errorf("Next returned true after full descend.")
|
||||
}
|
||||
it.Close()
|
||||
|
||||
// AscendAfter
|
||||
dataList = data
|
||||
for len(dataList) > 1 {
|
||||
dataList = dataList[1:]
|
||||
it = bt.AscendAfter(*dataList[0])
|
||||
|
||||
for _, v1 := range dataList {
|
||||
it.Next()
|
||||
v2 := it.Value()
|
||||
if !reflect.DeepEqual(*v1, v2) {
|
||||
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
|
||||
}
|
||||
}
|
||||
if it.Next() {
|
||||
return fmt.Errorf("Next returned true after partial ascend.")
|
||||
}
|
||||
it.Close()
|
||||
}
|
||||
|
||||
// DescendAfter
|
||||
dataList = data
|
||||
for len(dataList) > 1 {
|
||||
dataList = dataList[:len(dataList)-1]
|
||||
it = bt.DescendAfter(*dataList[len(dataList)-1])
|
||||
|
||||
for i := len(dataList) - 1; i >= 0; i-- {
|
||||
v1 := dataList[i]
|
||||
it.Next()
|
||||
v2 := it.Value()
|
||||
if !reflect.DeepEqual(*v1, v2) {
|
||||
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
|
||||
}
|
||||
}
|
||||
if it.Next() {
|
||||
return fmt.Errorf("Next returned true after partial descend: %#v", it.Value())
|
||||
}
|
||||
it.Close()
|
||||
}
|
||||
|
||||
// Using Get.
|
||||
for _, v1 := range data {
|
||||
v2, ok := bt.Get(*v1)
|
||||
if !ok || !reflect.DeepEqual(*v1, v2) {
|
||||
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
|
||||
}
|
||||
}
|
||||
|
||||
// Min.
|
||||
v1 := data[0]
|
||||
v2, ok := bt.Min()
|
||||
if !ok || !reflect.DeepEqual(*v1, v2) {
|
||||
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
|
||||
}
|
||||
|
||||
// Max.
|
||||
v1 = data[len(data)-1]
|
||||
v2, ok = bt.Max()
|
||||
if !ok || !reflect.DeepEqual(*v1, v2) {
|
||||
return fmt.Errorf("Value mismatch: %v != %v", *v1, v2)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,318 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFullBTreeIndex(t *testing.T) {
|
||||
|
||||
// Test against the email index.
|
||||
run := func(name string, inner func(t *testing.T, db *DB) []*User) {
|
||||
testWithDB(t, name, func(t *testing.T, db *DB) {
|
||||
expected := inner(t, db)
|
||||
|
||||
if err := db.Users.emailBTree.EqualsList(expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.Close()
|
||||
db = OpenDB(db.root, true)
|
||||
|
||||
if err := db.Users.emailBTree.EqualsList(expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
run("insert", func(t *testing.T, db *DB) (users []*User) {
|
||||
users = append(users,
|
||||
&User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
|
||||
&User{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"})
|
||||
|
||||
for _, u := range users {
|
||||
u2, err := db.Users.c.Insert(*u)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(u2, *u) {
|
||||
t.Fatal(u2, *u)
|
||||
}
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("insert duplicate", func(t *testing.T, db *DB) (users []*User) {
|
||||
users = append(users,
|
||||
&User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
|
||||
&User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "ccc"})
|
||||
|
||||
u1, err := db.Users.c.Insert(*users[0])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
u2, err := db.Users.c.Insert(*users[1])
|
||||
if !errors.Is(err, ErrDuplicate) {
|
||||
t.Fatal(u1, u2, err)
|
||||
}
|
||||
|
||||
return users[:1]
|
||||
})
|
||||
|
||||
run("update", func(t *testing.T, db *DB) (users []*User) {
|
||||
users = append(users,
|
||||
&User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
|
||||
&User{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"},
|
||||
&User{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"})
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := db.Users.c.Update(users[2].ID, func(u User) (User, error) {
|
||||
u.Email = "g@h.com"
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
users[2].Email = "g@h.com"
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("delete", func(t *testing.T, db *DB) (users []*User) {
|
||||
users = append(users,
|
||||
&User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
|
||||
&User{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"},
|
||||
&User{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"})
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
db.Users.c.Delete(users[0].ID)
|
||||
users = users[1:]
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("get not found", func(t *testing.T, db *DB) (users []*User) {
|
||||
users = append(users,
|
||||
&User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
|
||||
&User{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"},
|
||||
&User{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"})
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if u, ok := db.Users.emailBTree.Get(User{Email: "g@h.com"}); ok {
|
||||
t.Fatal(u, ok)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("min/max empty", func(t *testing.T, db *DB) (users []*User) {
|
||||
|
||||
if u, ok := db.Users.emailBTree.Min(); ok {
|
||||
t.Fatal(u, ok)
|
||||
}
|
||||
if u, ok := db.Users.emailBTree.Max(); ok {
|
||||
t.Fatal(u, ok)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
}
|
||||
|
||||
func TestPartialBTreeIndex(t *testing.T) {
|
||||
|
||||
// Test against the extID btree index.
|
||||
run := func(name string, inner func(t *testing.T, db *DB) []*User) {
|
||||
testWithDB(t, name, func(t *testing.T, db *DB) {
|
||||
expected := inner(t, db)
|
||||
|
||||
if err := db.Users.extIDBTree.EqualsList(expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.Close()
|
||||
db = OpenDB(db.root, true)
|
||||
|
||||
if err := db.Users.extIDBTree.EqualsList(expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
run("insert out", func(t *testing.T, db *DB) []*User {
|
||||
users := []*User{
|
||||
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "xxx"},
|
||||
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ggg"},
|
||||
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "aaa"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
return []*User{}
|
||||
})
|
||||
|
||||
run("insert in", func(t *testing.T, db *DB) []*User {
|
||||
users := []*User{
|
||||
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "xxx"},
|
||||
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ggg", ExtID: "x"},
|
||||
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "aaa"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
return []*User{users[1]}
|
||||
})
|
||||
|
||||
run("update out to out", func(t *testing.T, db *DB) []*User {
|
||||
users := []*User{
|
||||
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
|
||||
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc", ExtID: "A"},
|
||||
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"},
|
||||
{ID: db.Users.c.NextID(), Email: "g@h.com", Name: "ggg", ExtID: "B"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := db.Users.c.Update(users[0].ID, func(u User) (User, error) {
|
||||
u.Name = "axa"
|
||||
users[0].Name = "axa"
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return []*User{users[1], users[3]}
|
||||
})
|
||||
|
||||
run("update in to in", func(t *testing.T, db *DB) []*User {
|
||||
users := []*User{
|
||||
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
|
||||
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc", ExtID: "A"},
|
||||
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"},
|
||||
{ID: db.Users.c.NextID(), Email: "g@h.com", Name: "ggg", ExtID: "B"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := db.Users.c.Update(users[1].ID, func(u User) (User, error) {
|
||||
u.ExtID = "C"
|
||||
users[1].ExtID = "C"
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return []*User{users[3], users[1]}
|
||||
})
|
||||
|
||||
run("update out to in", func(t *testing.T, db *DB) []*User {
|
||||
users := []*User{
|
||||
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
|
||||
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc", ExtID: "A"},
|
||||
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"},
|
||||
{ID: db.Users.c.NextID(), Email: "g@h.com", Name: "ggg", ExtID: "B"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := db.Users.c.Update(users[2].ID, func(u User) (User, error) {
|
||||
u.ExtID = "C"
|
||||
users[2].ExtID = "C"
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return []*User{users[1], users[3], users[2]}
|
||||
})
|
||||
|
||||
run("update in to out", func(t *testing.T, db *DB) []*User {
|
||||
users := []*User{
|
||||
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
|
||||
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc", ExtID: "A"},
|
||||
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"},
|
||||
{ID: db.Users.c.NextID(), Email: "g@h.com", Name: "ggg", ExtID: "B"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := db.Users.c.Update(users[1].ID, func(u User) (User, error) {
|
||||
u.ExtID = ""
|
||||
users[1].ExtID = ""
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return []*User{users[3]}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBTreeIndex_load_ErrDuplicate(t *testing.T) {
|
||||
testWithDB(t, "", func(t *testing.T, db *DB) {
|
||||
idx := NewBTreeIndex(
|
||||
db.Users.c,
|
||||
func(lhs, rhs *User) bool { return lhs.ExtID < rhs.ExtID },
|
||||
nil)
|
||||
|
||||
users := map[uint64]*User{
|
||||
1: {ID: 1, Email: "x@y.com", Name: "xx", ExtID: "x"},
|
||||
2: {ID: 2, Email: "a@b.com", Name: "aa", ExtID: "x"},
|
||||
}
|
||||
|
||||
if err := idx.load(users); err != ErrDuplicate {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
type BTreeIterator[T any] struct {
|
||||
out chan *T
|
||||
close chan struct{}
|
||||
current *T
|
||||
}
|
||||
|
||||
func newBTreeIterator[T any]() *BTreeIterator[T] {
|
||||
return &BTreeIterator[T]{
|
||||
out: make(chan *T),
|
||||
close: make(chan struct{}, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (iter *BTreeIterator[T]) each(i *T) bool {
|
||||
select {
|
||||
case iter.out <- i:
|
||||
return true
|
||||
case <-iter.close:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (iter *BTreeIterator[T]) done() {
|
||||
close(iter.out)
|
||||
}
|
||||
|
||||
func (iter *BTreeIterator[T]) Next() bool {
|
||||
val, ok := <-iter.out
|
||||
if ok {
|
||||
iter.current = val
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (iter *BTreeIterator[T]) Close() {
|
||||
iter.close <- struct{}{}
|
||||
}
|
||||
|
||||
func (iter *BTreeIterator[T]) Value() T {
|
||||
return *iter.current
|
||||
}
|
|
@ -0,0 +1,49 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBTreeIterator(t *testing.T) {
|
||||
testWithDB(t, "min/max matches iterator", func(t *testing.T, db *DB) {
|
||||
users := []User{
|
||||
{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"},
|
||||
{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "bbb"},
|
||||
{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"},
|
||||
}
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
it := db.Users.nameBTree.Ascend()
|
||||
if !it.Next() {
|
||||
t.Fatal("No next")
|
||||
}
|
||||
u := it.Value()
|
||||
if !reflect.DeepEqual(u, users[0]) {
|
||||
t.Fatal(u, users[0])
|
||||
}
|
||||
it.Close()
|
||||
|
||||
it = db.Users.nameBTree.Descend()
|
||||
if !it.Next() {
|
||||
t.Fatal("No next")
|
||||
}
|
||||
u = it.Value()
|
||||
if !reflect.DeepEqual(u, users[2]) {
|
||||
t.Fatal(u, users[2])
|
||||
}
|
||||
it.Close()
|
||||
|
||||
})
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
"git.crumpington.com/public/mdb/kvstore"
|
||||
)
|
||||
|
||||
func decode[T any](data []byte) *T {
|
||||
item := new(T)
|
||||
must(json.Unmarshal(data, item))
|
||||
return item
|
||||
}
|
||||
|
||||
func encode(item any) []byte {
|
||||
w := bytes.NewBuffer(kvstore.GetDataBuf(0)[:0])
|
||||
must(json.NewEncoder(w).Encode(item))
|
||||
return w.Bytes()
|
||||
}
|
|
@ -0,0 +1,212 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"git.crumpington.com/public/mdb/keyedmutex"
|
||||
)
|
||||
|
||||
type Collection[T any] struct {
|
||||
primary bool
|
||||
db *Database
|
||||
name string
|
||||
idLock keyedmutex.KeyedMutex[uint64]
|
||||
items *itemMap[T]
|
||||
indices []itemIndex[T]
|
||||
uniqueIndices []itemUniqueIndex[T]
|
||||
getID func(*T) uint64
|
||||
sanitize func(*T)
|
||||
validate func(*T) error
|
||||
}
|
||||
|
||||
func NewCollection[T any](db *Database, name string, getID func(*T) uint64) *Collection[T] {
|
||||
items := newItemMap(db.kv, name, getID)
|
||||
c := &Collection[T]{
|
||||
primary: db.kv.Primary(),
|
||||
db: db,
|
||||
name: name,
|
||||
idLock: keyedmutex.New[uint64](),
|
||||
items: items,
|
||||
indices: []itemIndex[T]{items},
|
||||
uniqueIndices: []itemUniqueIndex[T]{},
|
||||
getID: items.getID,
|
||||
sanitize: func(*T) {},
|
||||
validate: func(*T) error { return nil },
|
||||
}
|
||||
|
||||
db.collections[name] = c
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Collection[T]) SetSanitize(sanitize func(*T)) {
|
||||
c.sanitize = sanitize
|
||||
}
|
||||
|
||||
func (c *Collection[T]) SetValidate(validate func(*T) error) {
|
||||
c.validate = validate
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (c *Collection[T]) NextID() uint64 {
|
||||
return c.items.nextID()
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (c *Collection[T]) Insert(item T) (T, error) {
|
||||
if !c.primary {
|
||||
return item, ErrReadOnly
|
||||
}
|
||||
|
||||
c.sanitize(&item)
|
||||
|
||||
if err := c.validate(&item); err != nil {
|
||||
return item, err
|
||||
}
|
||||
|
||||
id := c.getID(&item)
|
||||
c.idLock.Lock(id)
|
||||
defer c.idLock.Unlock(id)
|
||||
|
||||
if _, ok := c.items.Get(id); ok {
|
||||
return item, fmt.Errorf("%w: ID", ErrDuplicate)
|
||||
}
|
||||
|
||||
// Acquire locks and check for insert conflicts.
|
||||
for _, idx := range c.uniqueIndices {
|
||||
idx.lock(&item)
|
||||
defer idx.unlock(&item)
|
||||
|
||||
if idx.insertConflict(&item) {
|
||||
return item, fmt.Errorf("%w: %s", ErrDuplicate, idx.name())
|
||||
}
|
||||
}
|
||||
|
||||
for _, idx := range c.indices {
|
||||
idx.insert(&item)
|
||||
}
|
||||
|
||||
return item, nil
|
||||
}
|
||||
|
||||
func (c *Collection[T]) Update(id uint64, update func(T) (T, error)) error {
|
||||
if !c.primary {
|
||||
return ErrReadOnly
|
||||
}
|
||||
|
||||
c.idLock.Lock(id)
|
||||
defer c.idLock.Unlock(id)
|
||||
|
||||
old, ok := c.items.Get(id)
|
||||
if !ok {
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
newItem, err := update(*old)
|
||||
if err != nil {
|
||||
if err == ErrAbortUpdate {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
new := &newItem
|
||||
|
||||
if c.getID(new) != id {
|
||||
return ErrMismatchedIDs
|
||||
}
|
||||
|
||||
c.sanitize(new)
|
||||
|
||||
if err := c.validate(new); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Acquire locks and check for update conflicts.
|
||||
for _, idx := range c.uniqueIndices {
|
||||
idx.lock(new)
|
||||
defer idx.unlock(new)
|
||||
|
||||
if idx.updateConflict(new) {
|
||||
return fmt.Errorf("%w: %s", ErrDuplicate, idx.name())
|
||||
}
|
||||
}
|
||||
|
||||
for _, idx := range c.indices {
|
||||
idx.update(old, new)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Collection[T]) Delete(id uint64) {
|
||||
if !c.primary {
|
||||
panic(ErrReadOnly)
|
||||
}
|
||||
|
||||
c.idLock.Lock(id)
|
||||
defer c.idLock.Unlock(id)
|
||||
|
||||
item, ok := c.items.Get(id)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// Acquire locks and check for insert conflicts.
|
||||
for _, idx := range c.uniqueIndices {
|
||||
idx.lock(item)
|
||||
defer idx.unlock(item)
|
||||
}
|
||||
|
||||
for _, idx := range c.indices {
|
||||
idx.delete(item)
|
||||
}
|
||||
}
|
||||
|
||||
func (c Collection[T]) Get(id uint64) (t T, ok bool) {
|
||||
ptr, ok := c.items.Get(id)
|
||||
if !ok {
|
||||
return t, false
|
||||
}
|
||||
return *ptr, true
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (c *Collection[T]) loadData() {
|
||||
for _, idx := range c.indices {
|
||||
must(idx.load(c.items.m))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Collection[T]) onStore(id uint64, data []byte) {
|
||||
item := decode[T](data)
|
||||
old, ok := c.items.Get(id)
|
||||
if !ok {
|
||||
// Insert.
|
||||
for _, idx := range c.indices {
|
||||
idx.insert(item)
|
||||
}
|
||||
} else {
|
||||
// Otherwise update.
|
||||
for _, idx := range c.indices {
|
||||
idx.update(old, item)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Collection[T]) onDelete(id uint64) {
|
||||
if item, ok := c.items.Get(id); ok {
|
||||
for _, idx := range c.indices {
|
||||
idx.delete(item)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,153 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCollection(t *testing.T) {
|
||||
testWithDB(t, "insert on secondary", func(t *testing.T, db *DB) {
|
||||
c := db.Users.c
|
||||
c.primary = false
|
||||
if _, err := c.Insert(User{}); !errors.Is(err, ErrReadOnly) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
testWithDB(t, "insert validation error", func(t *testing.T, db *DB) {
|
||||
c := db.Users.c
|
||||
user := User{
|
||||
ID: c.NextID(),
|
||||
Email: "a@b.com",
|
||||
}
|
||||
if _, err := c.Insert(user); !errors.Is(err, ErrInvalidName) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
testWithDB(t, "insert duplicate", func(t *testing.T, db *DB) {
|
||||
c := db.Users.c
|
||||
user := User{
|
||||
ID: c.NextID(),
|
||||
Name: "adsf",
|
||||
Email: "a@b.com",
|
||||
}
|
||||
if _, err := c.Insert(user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := c.Insert(user); !errors.Is(err, ErrDuplicate) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
testWithDB(t, "update on secondary", func(t *testing.T, db *DB) {
|
||||
c := db.Users.c
|
||||
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
|
||||
if _, err := c.Insert(user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c.primary = false
|
||||
|
||||
err := c.Update(user.ID, func(u User) (User, error) {
|
||||
u.Name = "xxx"
|
||||
return u, nil
|
||||
})
|
||||
if !errors.Is(err, ErrReadOnly) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
testWithDB(t, "update not found", func(t *testing.T, db *DB) {
|
||||
c := db.Users.c
|
||||
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
|
||||
if _, err := c.Insert(user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err := c.Update(user.ID+1, func(u User) (User, error) {
|
||||
u.Name = "xxx"
|
||||
return u, nil
|
||||
})
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
testWithDB(t, "update failed validation", func(t *testing.T, db *DB) {
|
||||
c := db.Users.c
|
||||
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
|
||||
if _, err := c.Insert(user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err := c.Update(user.ID, func(u User) (User, error) {
|
||||
u.Name = ""
|
||||
return u, nil
|
||||
})
|
||||
if !errors.Is(err, ErrInvalidName) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
testWithDB(t, "delete on secondary", func(t *testing.T, db *DB) {
|
||||
defer func() {
|
||||
if err := recover(); err == nil {
|
||||
t.Fatal("No panic")
|
||||
}
|
||||
}()
|
||||
|
||||
c := db.Users.c
|
||||
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
|
||||
if _, err := c.Insert(user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
c.primary = false
|
||||
|
||||
c.Delete(1)
|
||||
})
|
||||
|
||||
testWithDB(t, "delete not found", func(t *testing.T, db *DB) {
|
||||
c := db.Users.c
|
||||
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
|
||||
if _, err := c.Insert(user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c.Delete(user.ID + 1) // Does nothing.
|
||||
})
|
||||
|
||||
testWithDB(t, "get", func(t *testing.T, db *DB) {
|
||||
c := db.Users.c
|
||||
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
|
||||
if _, err := c.Insert(user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
u2, ok := c.Get(user.ID)
|
||||
if !ok || !reflect.DeepEqual(user, u2) {
|
||||
t.Fatal(ok, u2, user)
|
||||
}
|
||||
})
|
||||
|
||||
testWithDB(t, "get not found", func(t *testing.T, db *DB) {
|
||||
c := db.Users.c
|
||||
user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"}
|
||||
if _, err := c.Insert(user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
u2, ok := c.Get(user.ID - 1)
|
||||
if ok {
|
||||
t.Fatal(ok, u2)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
package mdb
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"git.crumpington.com/public/mdb/kvstore"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
type LogInfo = kvstore.LogInfo
|
||||
|
||||
type Database struct {
|
||||
root string
|
||||
lock *os.File
|
||||
kv *kvstore.KV
|
||||
|
||||
collections map[string]dbCollection
|
||||
}
|
||||
|
||||
func NewPrimary(root string) *Database {
|
||||
return newDB(root, true)
|
||||
}
|
||||
|
||||
func NewSecondary(root string) *Database {
|
||||
return newDB(root, false)
|
||||
}
|
||||
|
||||
func newDB(root string, primary bool) *Database {
|
||||
must(os.MkdirAll(root, 0700))
|
||||
|
||||
lockPath := filepath.Join(root, "lock")
|
||||
|
||||
// Acquire the lock.
|
||||
lock, err := os.OpenFile(lockPath, os.O_RDWR|os.O_CREATE, 0600)
|
||||
must(err)
|
||||
must(unix.Flock(int(lock.Fd()), unix.LOCK_EX))
|
||||
|
||||
db := &Database{
|
||||
root: root,
|
||||
collections: map[string]dbCollection{},
|
||||
lock: lock,
|
||||
}
|
||||
if primary {
|
||||
db.kv = kvstore.NewPrimary(root)
|
||||
} else {
|
||||
db.kv = kvstore.NewSecondary(root, db.onStore, db.onDelete)
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *Database) Start() {
|
||||
wg := sync.WaitGroup{}
|
||||
for _, c := range db.collections {
|
||||
wg.Add(1)
|
||||
go func(c dbCollection) {
|
||||
defer wg.Done()
|
||||
c.loadData()
|
||||
}(c)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func (db *Database) MaxSeqNum() uint64 {
|
||||
return db.kv.MaxSeqNum()
|
||||
}
|
||||
|
||||
func (db *Database) LogInfo() LogInfo {
|
||||
return db.kv.LogInfo()
|
||||
}
|
||||
|
||||
func (db *Database) Close() {
|
||||
if db.kv != nil {
|
||||
db.kv.Close()
|
||||
db.kv = nil
|
||||
}
|
||||
if db.lock != nil {
|
||||
db.lock.Close()
|
||||
db.lock = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (db *Database) onStore(collection string, id uint64, data []byte) {
|
||||
if c, ok := db.collections[collection]; ok {
|
||||
c.onStore(id, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (db *Database) onDelete(collection string, id uint64) {
|
||||
if c, ok := db.collections[collection]; ok {
|
||||
c.onDelete(id)
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (db *Database) SyncSend(conn net.Conn) {
|
||||
db.kv.SyncSend(conn)
|
||||
}
|
||||
|
||||
func (db *Database) SyncRecv(conn net.Conn) {
|
||||
db.kv.SyncRecv(conn)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// CleanBefore deletes log entries that are more than the given number of
|
||||
// seconds old.
|
||||
func (db *Database) CleanBefore(seconds int64) {
|
||||
db.kv.CleanBefore(seconds)
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
package mdb
|
|
@ -0,0 +1,45 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDatabase(t *testing.T) {
|
||||
testWithDB(t, "multiple writers", func(t *testing.T, db *DB) {
|
||||
wg := sync.WaitGroup{}
|
||||
N := 64
|
||||
wg.Add(64)
|
||||
for i := 0; i < N; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < 1024; j++ {
|
||||
db.RandAction()
|
||||
}
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
})
|
||||
|
||||
testWithDB(t, "clean before", func(t *testing.T, db *DB) {
|
||||
db.RandAction()
|
||||
db.RandAction()
|
||||
db.RandAction()
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
db.CleanBefore(0)
|
||||
|
||||
if db.MaxSeqNum() == 0 {
|
||||
t.Fatal(db.MaxSeqNum())
|
||||
}
|
||||
})
|
||||
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
#!/bin/bash
|
||||
|
||||
godepgraph -s . > .deps.dot && xdot .deps.dot
|
||||
|
||||
rm .deps.dot
|
|
@ -0,0 +1,21 @@
|
|||
package mdb
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrMismatchedIDs = errors.New("MismatchedIDs")
|
||||
ErrDuplicate = errors.New("Duplicate")
|
||||
ErrNotFound = errors.New("NotFound")
|
||||
ErrReadOnly = errors.New("ReadOnly")
|
||||
|
||||
// Return in update function to abort changes.
|
||||
ErrAbortUpdate = errors.New("AbortUpdate")
|
||||
)
|
|
@ -0,0 +1,9 @@
|
|||
module git.crumpington.com/public/mdb
|
||||
|
||||
go 1.18
|
||||
|
||||
require (
|
||||
github.com/google/btree v1.1.2
|
||||
github.com/mattn/go-sqlite3 v1.14.14
|
||||
golang.org/x/sys v0.0.0-20220731174439-a90be440212d
|
||||
)
|
|
@ -0,0 +1,6 @@
|
|||
github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU=
|
||||
github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4=
|
||||
github.com/mattn/go-sqlite3 v1.14.14 h1:qZgc/Rwetq+MtyE18WhzjokPD93dNqLGNT3QJuLvBGw=
|
||||
github.com/mattn/go-sqlite3 v1.14.14/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||
golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80=
|
||||
golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
|
@ -0,0 +1,114 @@
|
|||
package mdb
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"git.crumpington.com/public/mdb/keyedmutex"
|
||||
"git.crumpington.com/public/mdb/kvstore"
|
||||
)
|
||||
|
||||
// Implements ItemMap and ItemUniqueIndex interfaces.
|
||||
type itemMap[T any] struct {
|
||||
primary bool
|
||||
kv *kvstore.KV
|
||||
collection string
|
||||
idLock keyedmutex.KeyedMutex[uint64]
|
||||
mapLock sync.Mutex
|
||||
m map[uint64]*T
|
||||
getID func(*T) uint64
|
||||
maxID uint64
|
||||
}
|
||||
|
||||
func newItemMap[T any](kv *kvstore.KV, collection string, getID func(*T) uint64) *itemMap[T] {
|
||||
m := &itemMap[T]{
|
||||
primary: kv.Primary(),
|
||||
kv: kv,
|
||||
collection: collection,
|
||||
idLock: keyedmutex.New[uint64](),
|
||||
m: map[uint64]*T{},
|
||||
getID: getID,
|
||||
}
|
||||
|
||||
kv.Iterate(collection, func(id uint64, data []byte) {
|
||||
item := decode[T](data)
|
||||
if id > m.maxID {
|
||||
m.maxID = id
|
||||
}
|
||||
m.m[id] = item
|
||||
})
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *itemMap[T]) load(src map[uint64]*T) error {
|
||||
// No-op: The itemmap is the source for loading all other indices.
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (m *itemMap[T]) Get(id uint64) (*T, bool) {
|
||||
return m.mapGet(id)
|
||||
}
|
||||
|
||||
// Should hold item lock when calling.
|
||||
func (idx *itemMap[T]) insert(item *T) {
|
||||
id := idx.getID(item)
|
||||
if idx.primary {
|
||||
idx.kv.Store(idx.collection, id, encode(item))
|
||||
}
|
||||
idx.mapSet(id, item)
|
||||
}
|
||||
|
||||
// Should hold item lock when calling. old and new MUST have the same ID.
|
||||
func (idx *itemMap[T]) update(old, new *T) {
|
||||
idx.insert(new)
|
||||
}
|
||||
|
||||
// Should hold item lock when calling.
|
||||
func (idx *itemMap[T]) delete(item *T) {
|
||||
id := idx.getID(item)
|
||||
if idx.primary {
|
||||
idx.kv.Delete(idx.collection, id)
|
||||
}
|
||||
idx.mapDelete(id)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (idx *itemMap[T]) mapSet(id uint64, item *T) {
|
||||
idx.mapLock.Lock()
|
||||
idx.m[id] = item
|
||||
idx.mapLock.Unlock()
|
||||
}
|
||||
|
||||
func (idx *itemMap[T]) mapDelete(id uint64) {
|
||||
idx.mapLock.Lock()
|
||||
delete(idx.m, id)
|
||||
idx.mapLock.Unlock()
|
||||
}
|
||||
|
||||
func (idx *itemMap[T]) mapGet(id uint64) (*T, bool) {
|
||||
idx.mapLock.Lock()
|
||||
item, ok := idx.m[id]
|
||||
idx.mapLock.Unlock()
|
||||
return item, ok
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (idx *itemMap[T]) nextID() uint64 {
|
||||
n := 1 + rand.Int63n(256)
|
||||
return atomic.AddUint64(&idx.maxID, uint64(n))
|
||||
}
|
|
@ -0,0 +1,61 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func (m *itemMap[T]) Equals(rhs *itemMap[T]) error {
|
||||
return m.EqualsMap(rhs.m)
|
||||
}
|
||||
|
||||
func (m *itemMap[T]) EqualsMap(data map[uint64]*T) error {
|
||||
if len(data) != len(m.m) {
|
||||
return fmt.Errorf("Expected %d items, but found %d.", len(data), len(m.m))
|
||||
}
|
||||
|
||||
for key, exp := range data {
|
||||
val, ok := m.m[key]
|
||||
if !ok {
|
||||
return fmt.Errorf("No value for %d. Expected: %v", key, *exp)
|
||||
}
|
||||
if !reflect.DeepEqual(*val, *exp) {
|
||||
return fmt.Errorf("Value mismatch %d: %v != %v", key, *val, *exp)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *itemMap[T]) EqualsKV() (err error) {
|
||||
count := 0
|
||||
m.kv.Iterate(m.collection, func(id uint64, data []byte) {
|
||||
count++
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
item := decode[T](data)
|
||||
val, ok := m.m[id]
|
||||
if !ok {
|
||||
err = fmt.Errorf("Item %d not found in memory: %v", id, *item)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(*item, *val) {
|
||||
err = fmt.Errorf("Items not equal %d: %v != %v", id, *item, *val)
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
if err == nil && count != len(m.m) {
|
||||
err = fmt.Errorf("%d items on disk, but %d in memory", count, len(m.m))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
|
@ -0,0 +1,134 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestItemMap(t *testing.T) {
|
||||
|
||||
// expected is a map of users.
|
||||
run := func(name string, inner func(t *testing.T, db *DB) (expected map[uint64]*User)) {
|
||||
testWithDB(t, name, func(t *testing.T, db *DB) {
|
||||
expected := inner(t, db)
|
||||
|
||||
if err := db.Users.c.items.EqualsMap(expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.Users.c.items.EqualsKV(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.Close()
|
||||
db = OpenDB(db.root, true)
|
||||
|
||||
if err := db.Users.c.items.EqualsMap(expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := db.Users.c.items.EqualsKV(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
run("simple", func(t *testing.T, db *DB) (expected map[uint64]*User) {
|
||||
users := map[uint64]*User{}
|
||||
c := db.Users.c
|
||||
for i := uint64(1); i < 10; i++ {
|
||||
id := c.NextID()
|
||||
users[id] = &User{
|
||||
ID: id,
|
||||
Email: fmt.Sprintf("a.%d@c.com", i),
|
||||
Name: fmt.Sprintf("name.%d", i),
|
||||
ExtID: fmt.Sprintf("EXTID.%d", i),
|
||||
}
|
||||
user, err := c.Insert(*users[id])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(user, *users[id]) {
|
||||
t.Fatal(user, *users[id])
|
||||
}
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("insert and delete", func(t *testing.T, db *DB) (expected map[uint64]*User) {
|
||||
users := map[uint64]*User{}
|
||||
c := db.Users.c
|
||||
for x := uint64(1); x < 10; x++ {
|
||||
id := c.NextID()
|
||||
users[id] = &User{
|
||||
ID: id,
|
||||
Email: fmt.Sprintf("a.%d@c.com", x),
|
||||
Name: fmt.Sprintf("name.%d", x),
|
||||
ExtID: fmt.Sprintf("EXTID.%d", x),
|
||||
}
|
||||
user, err := c.Insert(*users[id])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(user, *users[id]) {
|
||||
t.Fatal(user, *users[id])
|
||||
}
|
||||
}
|
||||
|
||||
var id uint64
|
||||
for key := range users {
|
||||
id = key
|
||||
}
|
||||
|
||||
delete(users, id)
|
||||
c.Delete(id)
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update", func(t *testing.T, db *DB) (expected map[uint64]*User) {
|
||||
users := map[uint64]*User{}
|
||||
c := db.Users.c
|
||||
for x := uint64(1); x < 10; x++ {
|
||||
id := c.NextID()
|
||||
users[id] = &User{
|
||||
ID: id,
|
||||
Email: fmt.Sprintf("a.%d@c.com", x),
|
||||
Name: fmt.Sprintf("name.%d", x),
|
||||
ExtID: fmt.Sprintf("EXTID.%d", x),
|
||||
}
|
||||
user, err := c.Insert(*users[id])
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(user, *users[id]) {
|
||||
t.Fatal(user, *users[id])
|
||||
}
|
||||
}
|
||||
|
||||
var id uint64
|
||||
for key := range users {
|
||||
id = key
|
||||
}
|
||||
|
||||
err := c.Update(id, func(u User) (User, error) {
|
||||
u.Name = "Hello"
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
users[id].Name = "Hello"
|
||||
|
||||
return users
|
||||
})
|
||||
}
|
|
@ -0,0 +1,75 @@
|
|||
package keyedmutex
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type KeyedMutex[K comparable] struct {
|
||||
mu *sync.Mutex
|
||||
waitList map[K]*list.List
|
||||
}
|
||||
|
||||
func New[K comparable]() KeyedMutex[K] {
|
||||
return KeyedMutex[K]{
|
||||
mu: new(sync.Mutex),
|
||||
waitList: map[K]*list.List{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m KeyedMutex[K]) Lock(key K) {
|
||||
if ch := m.lock(key); ch != nil {
|
||||
<-ch
|
||||
}
|
||||
}
|
||||
|
||||
func (m KeyedMutex[K]) lock(key K) chan struct{} {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if waitList, ok := m.waitList[key]; ok {
|
||||
ch := make(chan struct{})
|
||||
waitList.PushBack(ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
m.waitList[key] = list.New()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m KeyedMutex[K]) TryLock(key K) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, ok := m.waitList[key]; ok {
|
||||
return false
|
||||
}
|
||||
|
||||
m.waitList[key] = list.New()
|
||||
return true
|
||||
}
|
||||
|
||||
func (m KeyedMutex[K]) Unlock(key K) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
waitList, ok := m.waitList[key]
|
||||
if !ok {
|
||||
panic("unlock of unlocked mutex")
|
||||
}
|
||||
|
||||
if waitList.Len() == 0 {
|
||||
delete(m.waitList, key)
|
||||
} else {
|
||||
ch := waitList.Remove(waitList.Front()).(chan struct{})
|
||||
ch <- struct{}{}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
package kvstore
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
const sqlSchema = `
|
||||
BEGIN IMMEDIATE;
|
||||
|
||||
CREATE TABLE IF NOT EXISTS data(
|
||||
SeqNum INTEGER NOT NULL PRIMARY KEY,
|
||||
Deleted INTEGER NOT NULL DEFAULT 0,
|
||||
Data BLOB NOT NULL
|
||||
) WITHOUT ROWID;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS data_deleted_index ON data(Deleted,SeqNum);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS log(
|
||||
SeqNum INTEGER NOT NULL PRIMARY KEY,
|
||||
CreatedAt INTEGER NOT NULL,
|
||||
Collection TEXT NOT NULL,
|
||||
ID INTEGER NOT NULL,
|
||||
Store INTEGER NOT NULL
|
||||
) WITHOUT ROWID;
|
||||
|
||||
CREATE INDEX IF NOT EXISTS log_created_at_index ON log(CreatedAt);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS kv(
|
||||
Collection TEXT NOT NULL,
|
||||
ID INTEGER NOT NULL,
|
||||
SeqNum INTEGER NOT NULL,
|
||||
PRIMARY KEY (Collection, ID)
|
||||
) WITHOUT ROWID;
|
||||
|
||||
CREATE VIEW IF NOT EXISTS kvdata AS
|
||||
SELECT
|
||||
kv.Collection,
|
||||
kv.ID,
|
||||
data.Data
|
||||
FROM kv
|
||||
JOIN data ON kv.SeqNum=data.SeqNum;
|
||||
|
||||
CREATE VIEW IF NOT EXISTS logdata AS
|
||||
SELECT
|
||||
log.SeqNum,
|
||||
log.Collection,
|
||||
log.ID,
|
||||
log.Store,
|
||||
data.data
|
||||
FROM log
|
||||
LEFT JOIN data on log.SeqNum=data.SeqNum;
|
||||
|
||||
CREATE TRIGGER IF NOT EXISTS deletedata AFTER UPDATE OF SeqNum ON kv
|
||||
FOR EACH ROW
|
||||
WHEN OLD.SeqNum != NEW.SeqNum
|
||||
BEGIN
|
||||
UPDATE data SET Deleted=1 WHERE SeqNum=OLD.SeqNum;
|
||||
END;
|
||||
|
||||
COMMIT;`
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
const sqlInsertData = `INSERT INTO data(SeqNum,Data) VALUES(?,?)`
|
||||
|
||||
const sqlInsertKV = `INSERT INTO kv(Collection,ID,SeqNum) VALUES (?,?,?)
|
||||
ON CONFLICT(Collection,ID) DO UPDATE SET SeqNum=excluded.SeqNum
|
||||
WHERE ID=excluded.ID`
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
const sqlDeleteKV = `DELETE FROM kv WHERE Collection=? AND ID=?`
|
||||
|
||||
const sqlDeleteData = `UPDATE data SET Deleted=1
|
||||
WHERE SeqNum=(
|
||||
SELECT SeqNum FROM kv WHERE Collection=? AND ID=?)`
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
const sqlInsertLog = `INSERT INTO log(SeqNum,CreatedAt,Collection,ID,Store)
|
||||
VALUES(?,?,?,?,?)`
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
const sqlKVIterate = `SELECT ID,Data FROM kvdata WHERE Collection=?`
|
||||
|
||||
const sqlLogIterate = `
|
||||
SELECT SeqNum,Collection,ID,Store,Data
|
||||
FROM logdata
|
||||
WHERE SeqNum > ?
|
||||
ORDER BY SeqNum ASC`
|
||||
|
||||
const sqlMaxSeqNumGet = `SELECT COALESCE(MAX(SeqNum),0) FROM log`
|
||||
|
||||
const sqlLogInfoGet = `SELECT
|
||||
COALESCE(SeqNum, 0),
|
||||
COALESCE(CreatedAt, 0)
|
||||
FROM log
|
||||
ORDER BY SeqNum DESC LIMIT 1`
|
||||
|
||||
const sqlCleanQuery = `
|
||||
DELETE FROM
|
||||
log
|
||||
WHERE
|
||||
CreatedAt < ? AND
|
||||
SeqNum < (SELECT MAX(SeqNum) FROM log);
|
||||
|
||||
DELETE FROM
|
||||
data
|
||||
WHERE
|
||||
Deleted != 0 AND
|
||||
SeqNum < (SELECT MIN(SeqNum) FROM log);`
|
|
@ -0,0 +1,5 @@
|
|||
#!/bin/bash
|
||||
|
||||
godepgraph -s . > .deps.dot && xdot .deps.dot
|
||||
|
||||
rm .deps.dot
|
|
@ -0,0 +1,40 @@
|
|||
package kvstore
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
connTimeout = 16 * time.Second
|
||||
pollInterval = 500 * time.Millisecond
|
||||
modQSize = 1024
|
||||
poolBufSize = 8192
|
||||
bufferPool = sync.Pool{
|
||||
New: func() any {
|
||||
return make([]byte, poolBufSize)
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
func GetDataBuf(size int) []byte {
|
||||
if size > poolBufSize {
|
||||
return make([]byte, size)
|
||||
}
|
||||
return bufferPool.Get().([]byte)[:size]
|
||||
}
|
||||
|
||||
func RecycleDataBuf(b []byte) {
|
||||
if cap(b) != poolBufSize {
|
||||
return
|
||||
}
|
||||
bufferPool.Put(b)
|
||||
}
|
|
@ -0,0 +1,21 @@
|
|||
package kvstore
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
os.Exit(m.Run())
|
||||
}
|
|
@ -0,0 +1,53 @@
|
|||
package kvstore
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
const recHeaderSize = 22
|
||||
|
||||
func encodeRecordHeader(rec record, buf []byte) {
|
||||
// SeqNum (8)
|
||||
// ID (8)
|
||||
// DataLen (4)
|
||||
// Store (1)
|
||||
// CollectionLen (1)
|
||||
|
||||
binary.LittleEndian.PutUint64(buf[:8], rec.SeqNum)
|
||||
buf = buf[8:]
|
||||
binary.LittleEndian.PutUint64(buf[:8], rec.ID)
|
||||
buf = buf[8:]
|
||||
|
||||
if rec.Store {
|
||||
binary.LittleEndian.PutUint32(buf[:4], uint32(len(rec.Data)))
|
||||
buf[4] = 1
|
||||
} else {
|
||||
binary.LittleEndian.PutUint32(buf[:4], 0)
|
||||
buf[4] = 0
|
||||
}
|
||||
buf = buf[5:]
|
||||
|
||||
buf[0] = byte(len(rec.Collection))
|
||||
}
|
||||
|
||||
func decodeRecHeader(header []byte) (rec record, colLen, dataLen int) {
|
||||
buf := header
|
||||
|
||||
rec.SeqNum = binary.LittleEndian.Uint64(buf[:8])
|
||||
buf = buf[8:]
|
||||
rec.ID = binary.LittleEndian.Uint64(buf[:8])
|
||||
buf = buf[8:]
|
||||
dataLen = int(binary.LittleEndian.Uint32(buf[:4]))
|
||||
buf = buf[4:]
|
||||
rec.Store = buf[0] != 0
|
||||
buf = buf[1:]
|
||||
colLen = int(buf[0])
|
||||
|
||||
return
|
||||
}
|
|
@ -0,0 +1,254 @@
|
|||
package kvstore
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.crumpington.com/public/mdb/testconn"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
// Stores info from secondary callbacks.
|
||||
type callbacks struct {
|
||||
lock sync.Mutex
|
||||
m map[string]map[uint64]string
|
||||
}
|
||||
|
||||
func (sc *callbacks) onStore(c string, id uint64, data []byte) {
|
||||
sc.lock.Lock()
|
||||
defer sc.lock.Unlock()
|
||||
if _, ok := sc.m[c]; !ok {
|
||||
sc.m[c] = map[uint64]string{}
|
||||
}
|
||||
sc.m[c][id] = string(data)
|
||||
}
|
||||
|
||||
func (sc *callbacks) onDelete(c string, id uint64) {
|
||||
sc.lock.Lock()
|
||||
defer sc.lock.Unlock()
|
||||
if _, ok := sc.m[c]; !ok {
|
||||
return
|
||||
}
|
||||
delete(sc.m[c], id)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestShipping(t *testing.T) {
|
||||
run := func(name string, inner func(
|
||||
t *testing.T,
|
||||
pDir string,
|
||||
sDir string,
|
||||
primary *KV,
|
||||
secondary *KV,
|
||||
cbs *callbacks,
|
||||
nw *testconn.Network,
|
||||
)) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
pDir, _ := os.MkdirTemp("", "")
|
||||
defer os.RemoveAll(pDir)
|
||||
sDir, _ := os.MkdirTemp("", "")
|
||||
defer os.RemoveAll(sDir)
|
||||
|
||||
nw := testconn.NewNetwork()
|
||||
defer func() {
|
||||
nw.CloseServer()
|
||||
nw.CloseClient()
|
||||
}()
|
||||
|
||||
cbs := &callbacks{
|
||||
m: map[string]map[uint64]string{},
|
||||
}
|
||||
|
||||
prim := NewPrimary(pDir)
|
||||
defer prim.Close()
|
||||
sec := NewSecondary(sDir, cbs.onStore, cbs.onDelete)
|
||||
defer sec.Close()
|
||||
|
||||
inner(t, pDir, sDir, prim, sec, cbs, nw)
|
||||
})
|
||||
}
|
||||
|
||||
run("simple", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) {
|
||||
M := 10
|
||||
N := 1000
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
// Store M*N values in the background.
|
||||
for i := 0; i < M; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < N; i++ {
|
||||
time.Sleep(time.Millisecond)
|
||||
prim.randAction()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Send in the background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := nw.Accept()
|
||||
prim.SyncSend(conn)
|
||||
}()
|
||||
|
||||
// Recv in the background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := nw.Dial()
|
||||
sec.SyncRecv(conn)
|
||||
}()
|
||||
|
||||
sec.waitForSeqNum(uint64(M * N))
|
||||
|
||||
nw.CloseServer()
|
||||
nw.CloseClient()
|
||||
wg.Wait()
|
||||
|
||||
if err := prim.equalsKV("a", sec); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := prim.equalsKV("b", sec); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := prim.equalsKV("c", sec); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
run("simple concurrent", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) {
|
||||
M := 64
|
||||
N := 128
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
// Store M*N values in the background.
|
||||
for i := 0; i < M; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < N; i++ {
|
||||
time.Sleep(time.Millisecond)
|
||||
prim.randAction()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Send in the background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := nw.Accept()
|
||||
prim.SyncSend(conn)
|
||||
}()
|
||||
|
||||
// Recv in the background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := nw.Dial()
|
||||
sec.SyncRecv(conn)
|
||||
}()
|
||||
|
||||
sec.waitForSeqNum(uint64(M * N))
|
||||
|
||||
nw.CloseServer()
|
||||
nw.CloseClient()
|
||||
wg.Wait()
|
||||
|
||||
if err := prim.equalsKV("a", sec); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := prim.equalsKV("b", sec); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := prim.equalsKV("c", sec); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
run("net failures", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) {
|
||||
M := 10
|
||||
N := 1000
|
||||
sleepTime := time.Millisecond
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
// Store M*N values in the background.
|
||||
for i := 0; i < M; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for j := 0; j < N; j++ {
|
||||
time.Sleep(sleepTime)
|
||||
prim.randAction()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Send in the background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for sec.MaxSeqNum() < uint64(M*N) {
|
||||
if conn := nw.Accept(); conn != nil {
|
||||
prim.SyncSend(conn)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Recv in the background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for sec.MaxSeqNum() < uint64(M*N) {
|
||||
if conn := nw.Dial(); conn != nil {
|
||||
sec.SyncRecv(conn)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for sec.MaxSeqNum() < uint64(M*N) {
|
||||
time.Sleep(time.Duration(rand.Intn(10 * int(sleepTime))))
|
||||
if rand.Float64() < 0.5 {
|
||||
nw.CloseClient()
|
||||
} else {
|
||||
nw.CloseServer()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
sec.waitForSeqNum(prim.MaxSeqNum())
|
||||
wg.Wait()
|
||||
|
||||
if err := prim.equalsKV("a", sec); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := prim.equalsKV("b", sec); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := prim.equalsKV("c", sec); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
|
@ -0,0 +1,149 @@
|
|||
package kvstore
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
type LogInfo struct {
|
||||
MaxSeqNum uint64
|
||||
MaxCreatedAt uint64
|
||||
}
|
||||
|
||||
type KV struct {
|
||||
primary bool
|
||||
dbPath string
|
||||
|
||||
db *sql.DB
|
||||
|
||||
maxSeqNumStmt *sql.Stmt
|
||||
logInfoStmt *sql.Stmt
|
||||
logIterateStmt *sql.Stmt
|
||||
|
||||
w *writer
|
||||
|
||||
onStore func(string, uint64, []byte)
|
||||
onDelete func(string, uint64)
|
||||
|
||||
closeLock sync.Mutex
|
||||
recvLock sync.Mutex
|
||||
}
|
||||
|
||||
func newKV(
|
||||
dir string,
|
||||
primary bool,
|
||||
onStore func(string, uint64, []byte),
|
||||
onDelete func(string, uint64),
|
||||
) *KV {
|
||||
kv := &KV{
|
||||
dbPath: dbPath(dir),
|
||||
primary: primary,
|
||||
onStore: onStore,
|
||||
onDelete: onDelete,
|
||||
}
|
||||
|
||||
opts := `?_journal=WAL`
|
||||
db, err := sql.Open("sqlite3", kv.dbPath+opts)
|
||||
must(err)
|
||||
|
||||
_, err = db.Exec(sqlSchema)
|
||||
must(err)
|
||||
|
||||
kv.maxSeqNumStmt, err = db.Prepare(sqlMaxSeqNumGet)
|
||||
must(err)
|
||||
kv.logInfoStmt, err = db.Prepare(sqlLogInfoGet)
|
||||
must(err)
|
||||
kv.logIterateStmt, err = db.Prepare(sqlLogIterate)
|
||||
must(err)
|
||||
|
||||
_, err = db.Exec(sqlSchema)
|
||||
must(err)
|
||||
|
||||
kv.db = db
|
||||
|
||||
if kv.primary {
|
||||
kv.w = newWriter(kv.db)
|
||||
kv.w.Start(kv.MaxSeqNum())
|
||||
}
|
||||
return kv
|
||||
}
|
||||
|
||||
func NewPrimary(dir string) *KV {
|
||||
return newKV(dir, true, nil, nil)
|
||||
}
|
||||
|
||||
func NewSecondary(
|
||||
dir string,
|
||||
onStore func(collection string, id uint64, data []byte),
|
||||
onDelete func(collection string, id uint64),
|
||||
) *KV {
|
||||
return newKV(dir, false, onStore, onDelete)
|
||||
}
|
||||
|
||||
func (kv *KV) Primary() bool {
|
||||
return kv.primary
|
||||
}
|
||||
|
||||
func (kv *KV) MaxSeqNum() (seqNum uint64) {
|
||||
must(kv.maxSeqNumStmt.QueryRow().Scan(&seqNum))
|
||||
return seqNum
|
||||
}
|
||||
|
||||
func (kv *KV) LogInfo() (info LogInfo) {
|
||||
must(kv.logInfoStmt.QueryRow().Scan(&info.MaxSeqNum, &info.MaxCreatedAt))
|
||||
return
|
||||
}
|
||||
|
||||
func (kv *KV) Iterate(collection string, each func(id uint64, data []byte)) {
|
||||
rows, err := kv.db.Query(sqlKVIterate, collection)
|
||||
must(err)
|
||||
defer rows.Close()
|
||||
|
||||
var (
|
||||
id uint64
|
||||
data []byte
|
||||
)
|
||||
|
||||
for rows.Next() {
|
||||
must(rows.Scan(&id, &data))
|
||||
each(id, data)
|
||||
}
|
||||
}
|
||||
|
||||
func (kv *KV) Close() {
|
||||
kv.closeLock.Lock()
|
||||
defer kv.closeLock.Unlock()
|
||||
|
||||
if kv.w != nil {
|
||||
kv.w.Stop()
|
||||
}
|
||||
|
||||
if kv.db != nil {
|
||||
kv.db.Close()
|
||||
kv.db = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (kv *KV) Store(collection string, id uint64, data []byte) {
|
||||
kv.w.Store(collection, id, data)
|
||||
}
|
||||
|
||||
func (kv *KV) Delete(collection string, id uint64) {
|
||||
kv.w.Delete(collection, id)
|
||||
}
|
||||
|
||||
func (kv *KV) CleanBefore(seconds int64) {
|
||||
_, err := kv.db.Exec(sqlCleanQuery, time.Now().Unix()-seconds)
|
||||
must(err)
|
||||
}
|
|
@ -0,0 +1,125 @@
|
|||
package kvstore
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (kv *KV) waitForSeqNum(x uint64) {
|
||||
for {
|
||||
seqNum := kv.MaxSeqNum()
|
||||
if seqNum >= x {
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func (kv *KV) dump(collection string) map[uint64]string {
|
||||
m := map[uint64]string{}
|
||||
kv.Iterate(collection, func(id uint64, data []byte) {
|
||||
m[id] = string(data)
|
||||
})
|
||||
return m
|
||||
}
|
||||
|
||||
func (kv *KV) equals(collection string, expected map[uint64]string) error {
|
||||
m := kv.dump(collection)
|
||||
if len(m) != len(expected) {
|
||||
return fmt.Errorf("Expected %d values but found %d", len(expected), len(m))
|
||||
}
|
||||
|
||||
for key, exp := range expected {
|
||||
val, ok := m[key]
|
||||
if !ok {
|
||||
return fmt.Errorf("Value for %d not found.", key)
|
||||
}
|
||||
if val != exp {
|
||||
return fmt.Errorf("%d: Expected %s but found %s.", key, exp, val)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (kv *KV) equalsKV(collection string, rhs *KV) error {
|
||||
l1 := []record{}
|
||||
kv.replay(0, func(rec record) error {
|
||||
l1 = append(l1, rec)
|
||||
return nil
|
||||
})
|
||||
|
||||
idx := -1
|
||||
err := rhs.replay(0, func(rec record) error {
|
||||
idx++
|
||||
if !reflect.DeepEqual(rec, l1[idx]) {
|
||||
return fmt.Errorf("Records not equal: %d %v %v", idx, rec, l1[idx])
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return kv.equals(collection, rhs.dump(collection))
|
||||
}
|
||||
|
||||
// Collection one of ("a", "b", "c").
|
||||
// ID one of [1,10]
|
||||
var (
|
||||
randCollections = []string{"a", "b", "c"}
|
||||
randIDs = []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16}
|
||||
)
|
||||
|
||||
func (kv *KV) randAction() {
|
||||
c := randCollections[rand.Intn(len(randCollections))]
|
||||
id := randIDs[rand.Intn(len(randIDs))]
|
||||
|
||||
// Mostly stores.
|
||||
if rand.Float64() < 0.9 {
|
||||
kv.Store(c, id, randBytes())
|
||||
} else {
|
||||
kv.Delete(c, id)
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func TestKV(t *testing.T) {
|
||||
run := func(name string, inner func(t *testing.T, kv *KV)) {
|
||||
dir, _ := os.MkdirTemp("", "")
|
||||
defer os.RemoveAll(dir)
|
||||
kv := NewPrimary(dir)
|
||||
defer kv.Close()
|
||||
|
||||
inner(t, kv)
|
||||
}
|
||||
|
||||
run("simple", func(t *testing.T, kv *KV) {
|
||||
kv.Store("a", 1, _b("Hello"))
|
||||
kv.Store("a", 2, _b("World"))
|
||||
|
||||
kv.waitForSeqNum(2)
|
||||
|
||||
err := kv.equals("a", map[uint64]string{
|
||||
1: "Hello",
|
||||
2: "World",
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,103 @@
|
|||
package kvstore
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (kv *KV) SyncRecv(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
if kv.primary {
|
||||
panic("SyncRecv called on primary.")
|
||||
}
|
||||
|
||||
if !kv.recvLock.TryLock() {
|
||||
return
|
||||
}
|
||||
defer kv.recvLock.Unlock()
|
||||
|
||||
// It's important that we stop the writer so that all queued writes are
|
||||
// committed to the database before syncing has a chance to restart.
|
||||
w := newWriter(kv.db)
|
||||
w.Start(kv.MaxSeqNum())
|
||||
defer w.Stop()
|
||||
|
||||
headerBuf := make([]byte, recHeaderSize)
|
||||
nameBuf := make([]byte, 32)
|
||||
afterSeqNumBuf := make([]byte, 8)
|
||||
|
||||
afterSeqNum := kv.MaxSeqNum()
|
||||
expectedSeqNum := afterSeqNum + 1
|
||||
|
||||
// Send fromID to the conn.
|
||||
conn.SetWriteDeadline(time.Now().Add(connTimeout))
|
||||
binary.LittleEndian.PutUint64(afterSeqNumBuf, afterSeqNum)
|
||||
if _, err := conn.Write(afterSeqNumBuf); err != nil {
|
||||
log.Printf("RecvWAL failed to send after sequence number: %v", err)
|
||||
return
|
||||
}
|
||||
conn.SetWriteDeadline(time.Time{})
|
||||
|
||||
for {
|
||||
conn.SetReadDeadline(time.Now().Add(connTimeout))
|
||||
|
||||
if _, err := conn.Read(headerBuf); err != nil {
|
||||
log.Printf("RecvWAL failed to read header: %v", err)
|
||||
return
|
||||
}
|
||||
rec, nameLen, dataLen := decodeRecHeader(headerBuf)
|
||||
|
||||
// Heartbeat.
|
||||
if rec.SeqNum == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if rec.SeqNum != expectedSeqNum {
|
||||
log.Printf("Expected sequence number %d but got %d.",
|
||||
expectedSeqNum, rec.SeqNum)
|
||||
return
|
||||
}
|
||||
expectedSeqNum++
|
||||
|
||||
if cap(nameBuf) < nameLen {
|
||||
nameBuf = make([]byte, nameLen)
|
||||
}
|
||||
nameBuf = nameBuf[:nameLen]
|
||||
|
||||
if _, err := conn.Read(nameBuf); err != nil {
|
||||
log.Printf("RecvWAL failed to read collection name: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
rec.Collection = string(nameBuf)
|
||||
|
||||
// Note that it's OK to apply changes via onStore / onDelete before the log
|
||||
// entries are written because all changes are idempotent. Even if the log
|
||||
// is replayed from some point in the past, the final state will be
|
||||
// consistent.
|
||||
|
||||
if rec.Store {
|
||||
rec.Data = GetDataBuf(dataLen)
|
||||
if _, err := conn.Read(rec.Data); err != nil {
|
||||
log.Printf("RecvWAL failed to read data: %v", err)
|
||||
return
|
||||
}
|
||||
kv.onStore(rec.Collection, rec.ID, rec.Data)
|
||||
w.StoreAsync(rec.Collection, rec.ID, rec.Data)
|
||||
} else {
|
||||
kv.onDelete(rec.Collection, rec.ID)
|
||||
w.DeleteAsync(rec.Collection, rec.ID)
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,116 @@
|
|||
package kvstore
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"log"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (kv *KV) SyncSend(conn net.Conn) {
|
||||
defer conn.Close()
|
||||
|
||||
var (
|
||||
seqNumBuf = make([]byte, 8)
|
||||
headerBuf = make([]byte, recHeaderSize)
|
||||
empty = make([]byte, recHeaderSize)
|
||||
err error
|
||||
)
|
||||
|
||||
// Read afterSeqNum from the conn.
|
||||
conn.SetReadDeadline(time.Now().Add(connTimeout))
|
||||
if _, err := conn.Read(seqNumBuf[:8]); err != nil {
|
||||
log.Printf("SyncSend failed to read afterSeqNum: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
afterSeqNum := binary.LittleEndian.Uint64(seqNumBuf[:8])
|
||||
|
||||
POLL:
|
||||
|
||||
for i := 0; i < 4; i++ {
|
||||
if kv.MaxSeqNum() > afterSeqNum {
|
||||
goto REPLAY
|
||||
}
|
||||
time.Sleep(pollInterval)
|
||||
}
|
||||
|
||||
goto HEARTBEAT
|
||||
|
||||
HEARTBEAT:
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(connTimeout))
|
||||
if _, err := conn.Write(empty); err != nil {
|
||||
log.Printf("SendWAL failed to send heartbeat: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
goto POLL
|
||||
|
||||
REPLAY:
|
||||
|
||||
err = kv.replay(afterSeqNum, func(rec record) error {
|
||||
conn.SetWriteDeadline(time.Now().Add(connTimeout))
|
||||
|
||||
afterSeqNum = rec.SeqNum
|
||||
encodeRecordHeader(rec, headerBuf)
|
||||
if _, err := conn.Write(headerBuf); err != nil {
|
||||
log.Printf("SendWAL failed to send header %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := conn.Write([]byte(rec.Collection)); err != nil {
|
||||
log.Printf("SendWAL failed to send collection name %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !rec.Store {
|
||||
return nil
|
||||
}
|
||||
|
||||
if _, err := conn.Write(rec.Data); err != nil {
|
||||
log.Printf("SendWAL failed to send data %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
goto POLL
|
||||
}
|
||||
|
||||
func (kv *KV) replay(afterSeqNum uint64, each func(rec record) error) error {
|
||||
rec := record{}
|
||||
rows, err := kv.logIterateStmt.Query(afterSeqNum)
|
||||
must(err)
|
||||
defer rows.Close()
|
||||
|
||||
rec.Data = GetDataBuf(0)
|
||||
defer RecycleDataBuf(rec.Data)
|
||||
|
||||
for rows.Next() {
|
||||
must(rows.Scan(
|
||||
&rec.SeqNum,
|
||||
&rec.Collection,
|
||||
&rec.ID,
|
||||
&rec.Store,
|
||||
&rec.Data))
|
||||
|
||||
if err = each(rec); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,27 @@
|
|||
package kvstore
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import "sync"
|
||||
|
||||
type modJob struct {
|
||||
Collection string
|
||||
ID uint64
|
||||
Store bool
|
||||
Data []byte
|
||||
Ready *sync.WaitGroup
|
||||
}
|
||||
|
||||
type record struct {
|
||||
SeqNum uint64
|
||||
Collection string
|
||||
ID uint64
|
||||
Store bool
|
||||
Data []byte
|
||||
}
|
|
@ -0,0 +1,23 @@
|
|||
package kvstore
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
)
|
||||
|
||||
func must(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func dbPath(dir string) string {
|
||||
return filepath.Join(dir, "db")
|
||||
}
|
|
@ -0,0 +1,29 @@
|
|||
package kvstore
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
mrand "math/rand"
|
||||
)
|
||||
|
||||
func _b(in string) []byte {
|
||||
return []byte(in)
|
||||
}
|
||||
|
||||
func randString() string {
|
||||
buf := make([]byte, 1+mrand.Intn(20))
|
||||
rand.Read(buf)
|
||||
return hex.EncodeToString(buf)
|
||||
}
|
||||
|
||||
func randBytes() []byte {
|
||||
return _b(randString())
|
||||
}
|
|
@ -0,0 +1,194 @@
|
|||
package kvstore
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type writer struct {
|
||||
db *sql.DB
|
||||
modQ chan modJob
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
func newWriter(db *sql.DB) *writer {
|
||||
return &writer{
|
||||
db: db,
|
||||
modQ: make(chan modJob, modQSize),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *writer) Start(maxSeqNum uint64) {
|
||||
w.wg.Add(1)
|
||||
go w.run(maxSeqNum)
|
||||
}
|
||||
|
||||
func (w *writer) Stop() {
|
||||
close(w.modQ)
|
||||
w.wg.Wait()
|
||||
}
|
||||
|
||||
// Takes ownership of the incoming data.
|
||||
func (w *writer) Store(collection string, id uint64, data []byte) {
|
||||
job := modJob{
|
||||
Collection: collection,
|
||||
ID: id,
|
||||
Store: true,
|
||||
Data: data,
|
||||
Ready: &sync.WaitGroup{},
|
||||
}
|
||||
job.Ready.Add(1)
|
||||
w.modQ <- job
|
||||
job.Ready.Wait()
|
||||
}
|
||||
|
||||
func (w *writer) Delete(collection string, id uint64) {
|
||||
job := modJob{
|
||||
Collection: collection,
|
||||
ID: id,
|
||||
Store: false,
|
||||
Ready: &sync.WaitGroup{},
|
||||
}
|
||||
job.Ready.Add(1)
|
||||
w.modQ <- job
|
||||
job.Ready.Wait()
|
||||
}
|
||||
|
||||
// Takes ownership of the incoming data.
|
||||
func (w *writer) StoreAsync(collection string, id uint64, data []byte) {
|
||||
w.modQ <- modJob{
|
||||
Collection: collection,
|
||||
ID: id,
|
||||
Store: true,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *writer) DeleteAsync(collection string, id uint64) {
|
||||
w.modQ <- modJob{
|
||||
Collection: collection,
|
||||
ID: id,
|
||||
Store: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *writer) run(maxSeqNum uint64) {
|
||||
defer w.wg.Done()
|
||||
|
||||
var (
|
||||
mod modJob
|
||||
ok bool
|
||||
tx *sql.Tx
|
||||
insertData *sql.Stmt
|
||||
insertKV *sql.Stmt
|
||||
deleteData *sql.Stmt
|
||||
deleteKV *sql.Stmt
|
||||
insertLog *sql.Stmt
|
||||
err error
|
||||
curSeqNum = maxSeqNum
|
||||
now int64
|
||||
wgs = make([]*sync.WaitGroup, 10)
|
||||
)
|
||||
|
||||
BEGIN:
|
||||
|
||||
insertData = nil
|
||||
deleteData = nil
|
||||
|
||||
wgs = wgs[:0]
|
||||
|
||||
mod, ok = <-w.modQ
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
tx, err = w.db.Begin()
|
||||
must(err)
|
||||
|
||||
now = time.Now().Unix()
|
||||
|
||||
insertLog, err = tx.Prepare(sqlInsertLog)
|
||||
must(err)
|
||||
|
||||
LOOP:
|
||||
|
||||
if mod.Ready != nil {
|
||||
wgs = append(wgs, mod.Ready)
|
||||
}
|
||||
|
||||
curSeqNum++
|
||||
if mod.Store {
|
||||
goto STORE
|
||||
} else {
|
||||
goto DELETE
|
||||
}
|
||||
|
||||
STORE:
|
||||
|
||||
if insertData == nil {
|
||||
insertData, err = tx.Prepare(sqlInsertData)
|
||||
must(err)
|
||||
insertKV, err = tx.Prepare(sqlInsertKV)
|
||||
must(err)
|
||||
}
|
||||
|
||||
_, err = insertData.Exec(curSeqNum, mod.Data)
|
||||
must(err)
|
||||
_, err = insertKV.Exec(mod.Collection, mod.ID, curSeqNum)
|
||||
must(err)
|
||||
_, err = insertLog.Exec(curSeqNum, now, mod.Collection, mod.ID, true)
|
||||
must(err)
|
||||
|
||||
RecycleDataBuf(mod.Data)
|
||||
|
||||
goto NEXT
|
||||
|
||||
DELETE:
|
||||
|
||||
if deleteData == nil {
|
||||
deleteData, err = tx.Prepare(sqlDeleteData)
|
||||
must(err)
|
||||
deleteKV, err = tx.Prepare(sqlDeleteKV)
|
||||
must(err)
|
||||
}
|
||||
|
||||
_, err = deleteData.Exec(mod.Collection, mod.ID)
|
||||
must(err)
|
||||
_, err = deleteKV.Exec(mod.Collection, mod.ID)
|
||||
must(err)
|
||||
_, err = insertLog.Exec(curSeqNum, now, mod.Collection, mod.ID, false)
|
||||
must(err)
|
||||
|
||||
goto NEXT
|
||||
|
||||
NEXT:
|
||||
|
||||
select {
|
||||
case mod, ok = <-w.modQ:
|
||||
if ok {
|
||||
goto LOOP
|
||||
}
|
||||
default:
|
||||
}
|
||||
|
||||
goto COMMIT
|
||||
|
||||
COMMIT:
|
||||
|
||||
must(tx.Commit())
|
||||
|
||||
for i := range wgs {
|
||||
wgs[i].Done()
|
||||
}
|
||||
|
||||
goto BEGIN
|
||||
}
|
|
@ -0,0 +1,32 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func testWithDB(t *testing.T, name string, inner func(t *testing.T, db *DB)) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
root, err := os.MkdirTemp("", "")
|
||||
must(err)
|
||||
defer os.RemoveAll(root)
|
||||
|
||||
db := OpenDB(root, true)
|
||||
defer db.Close()
|
||||
inner(t, db)
|
||||
})
|
||||
}
|
|
@ -0,0 +1,177 @@
|
|||
package mdb
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"git.crumpington.com/public/mdb/keyedmutex"
|
||||
)
|
||||
|
||||
type MapIndex[K comparable, T any] struct {
|
||||
c *Collection[T]
|
||||
_name string
|
||||
keyLock keyedmutex.KeyedMutex[K]
|
||||
mapLock sync.Mutex
|
||||
m map[K]*T
|
||||
getID func(*T) uint64
|
||||
getKey func(*T) K
|
||||
include func(*T) bool
|
||||
}
|
||||
|
||||
func NewMapIndex[K comparable, T any](
|
||||
c *Collection[T],
|
||||
name string,
|
||||
getKey func(*T) K,
|
||||
include func(*T) bool,
|
||||
) *MapIndex[K, T] {
|
||||
m := &MapIndex[K, T]{
|
||||
c: c,
|
||||
_name: name,
|
||||
keyLock: keyedmutex.New[K](),
|
||||
m: map[K]*T{},
|
||||
getID: c.getID,
|
||||
getKey: getKey,
|
||||
include: include,
|
||||
}
|
||||
|
||||
c.indices = append(c.indices, m)
|
||||
c.uniqueIndices = append(c.uniqueIndices, m)
|
||||
return m
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (m *MapIndex[K, T]) load(src map[uint64]*T) error {
|
||||
for _, item := range src {
|
||||
if m.include != nil && !m.include(item) {
|
||||
continue
|
||||
}
|
||||
|
||||
k := m.getKey(item)
|
||||
if _, ok := m.m[k]; ok {
|
||||
return ErrDuplicate
|
||||
}
|
||||
|
||||
m.m[k] = item
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MapIndex[K, T]) Get(k K) (t T, ok bool) {
|
||||
if item, ok := m.mapGet(k); ok {
|
||||
return *item, true
|
||||
}
|
||||
|
||||
return t, false
|
||||
}
|
||||
|
||||
func (m *MapIndex[K, T]) Update(k K, update func(T) (T, error)) error {
|
||||
wrapped := func(item T) (T, error) {
|
||||
new, err := update(item)
|
||||
if err != nil {
|
||||
return item, err
|
||||
}
|
||||
|
||||
return new, nil
|
||||
}
|
||||
|
||||
if item, ok := m.mapGet(k); ok {
|
||||
return m.c.Update(m.getID(item), wrapped)
|
||||
}
|
||||
|
||||
return ErrNotFound
|
||||
}
|
||||
|
||||
func (m *MapIndex[K, T]) Delete(k K) {
|
||||
if item, ok := m.mapGet(k); ok {
|
||||
m.c.Delete(m.getID(item))
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (m *MapIndex[K, T]) insert(item *T) {
|
||||
if m.include == nil || m.include(item) {
|
||||
m.mapSet(m.getKey(item), item)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MapIndex[K, T]) update(old, new *T) {
|
||||
if m.include == nil || m.include(old) {
|
||||
m.mapDelete(m.getKey(old))
|
||||
}
|
||||
if m.include == nil || m.include(new) {
|
||||
m.mapSet(m.getKey(new), new)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *MapIndex[K, T]) delete(item *T) {
|
||||
if m.include == nil || m.include(item) {
|
||||
m.mapDelete(m.getKey(item))
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (idx *MapIndex[K, T]) name() string {
|
||||
return idx._name
|
||||
}
|
||||
|
||||
func (idx *MapIndex[K, T]) lock(item *T) {
|
||||
if idx.include == nil || idx.include(item) {
|
||||
idx.keyLock.Lock(idx.getKey(item))
|
||||
}
|
||||
}
|
||||
|
||||
func (idx *MapIndex[K, T]) unlock(item *T) {
|
||||
if idx.include == nil || idx.include(item) {
|
||||
idx.keyLock.Unlock(idx.getKey(item))
|
||||
}
|
||||
}
|
||||
|
||||
// Should hold item lock when calling.
|
||||
func (idx *MapIndex[K, T]) insertConflict(new *T) bool {
|
||||
if idx.include != nil && !idx.include(new) {
|
||||
return false
|
||||
}
|
||||
_, ok := idx.Get(idx.getKey(new))
|
||||
return ok
|
||||
}
|
||||
|
||||
// Should hold item lock when calling.
|
||||
func (idx *MapIndex[K, T]) updateConflict(new *T) bool {
|
||||
if idx.include != nil && !idx.include(new) {
|
||||
return false
|
||||
}
|
||||
cur, ok := idx.mapGet(idx.getKey(new))
|
||||
return ok && idx.getID(cur) != idx.getID(new)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (idx *MapIndex[K, T]) mapSet(k K, t *T) {
|
||||
idx.mapLock.Lock()
|
||||
idx.m[k] = t
|
||||
idx.mapLock.Unlock()
|
||||
}
|
||||
|
||||
func (idx *MapIndex[K, T]) mapDelete(k K) {
|
||||
idx.mapLock.Lock()
|
||||
delete(idx.m, k)
|
||||
idx.mapLock.Unlock()
|
||||
}
|
||||
|
||||
func (idx *MapIndex[K, T]) mapGet(k K) (*T, bool) {
|
||||
idx.mapLock.Lock()
|
||||
t, ok := idx.m[k]
|
||||
idx.mapLock.Unlock()
|
||||
return t, ok
|
||||
}
|
|
@ -0,0 +1,35 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
func (m *MapIndex[K, T]) Equals(rhs *MapIndex[K, T]) error {
|
||||
return m.EqualsMap(rhs.m)
|
||||
}
|
||||
|
||||
func (m *MapIndex[K, T]) EqualsMap(data map[K]*T) error {
|
||||
if len(m.m) != len(data) {
|
||||
return fmt.Errorf("Expected %d items, but found %d.", len(data), len(m.m))
|
||||
}
|
||||
|
||||
for key, exp := range data {
|
||||
val, ok := m.Get(key)
|
||||
if !ok {
|
||||
return fmt.Errorf("No value for %v. Expected: %v", key, *exp)
|
||||
}
|
||||
if !reflect.DeepEqual(val, *exp) {
|
||||
return fmt.Errorf("Value mismatch %v: %v != %v", key, val, *exp)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,692 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestFullMapIndex(t *testing.T) {
|
||||
|
||||
// Test against the emailMap index.
|
||||
run := func(name string, inner func(t *testing.T, db *DB) map[string]*User) {
|
||||
testWithDB(t, name, func(t *testing.T, db *DB) {
|
||||
expected := inner(t, db)
|
||||
|
||||
if err := db.Users.emailMap.EqualsMap(expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.Close()
|
||||
db = OpenDB(db.root, true)
|
||||
|
||||
if err := db.Users.emailMap.EqualsMap(expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
run("insert", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
|
||||
for i := uint64(1); i < 10; i++ {
|
||||
user := &User{
|
||||
ID: db.Users.c.NextID(),
|
||||
Email: fmt.Sprintf("a.%d@c.com", i),
|
||||
Name: fmt.Sprintf("name.%d", i),
|
||||
ExtID: fmt.Sprintf("EXTID.%d", i),
|
||||
}
|
||||
|
||||
user2, err := db.Users.c.Insert(*user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(*user, user2) {
|
||||
t.Fatal(*user, user2)
|
||||
}
|
||||
users[user.Email] = user
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("delete", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
|
||||
for i := uint64(1); i < 10; i++ {
|
||||
user := &User{
|
||||
ID: db.Users.c.NextID(),
|
||||
Email: fmt.Sprintf("a.%d@c.com", i),
|
||||
Name: fmt.Sprintf("name.%d", i),
|
||||
ExtID: fmt.Sprintf("EXTID.%d", i),
|
||||
}
|
||||
|
||||
user2, err := db.Users.c.Insert(*user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(*user, user2) {
|
||||
t.Fatal(*user, user2)
|
||||
}
|
||||
users[user.Email] = user
|
||||
}
|
||||
|
||||
var id string
|
||||
for key := range users {
|
||||
id = key
|
||||
break
|
||||
}
|
||||
|
||||
delete(users, id)
|
||||
db.Users.emailMap.Delete(id)
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update non-indexed field", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
|
||||
for i := uint64(1); i < 10; i++ {
|
||||
user := &User{
|
||||
ID: db.Users.c.NextID(),
|
||||
Email: fmt.Sprintf("a.%d@c.com", i),
|
||||
Name: fmt.Sprintf("name.%d", i),
|
||||
ExtID: fmt.Sprintf("EXTID.%d", i),
|
||||
}
|
||||
|
||||
user2, err := db.Users.c.Insert(*user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(*user, user2) {
|
||||
t.Fatal(*user, user2)
|
||||
}
|
||||
users[user.Email] = user
|
||||
}
|
||||
|
||||
var id string
|
||||
for key := range users {
|
||||
id = key
|
||||
break
|
||||
}
|
||||
|
||||
err := db.Users.emailMap.Update(id, func(u User) (User, error) {
|
||||
u.Name = "UPDATED"
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
users[id].Name = "UPDATED"
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update indexed field", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
|
||||
for i := uint64(1); i < 10; i++ {
|
||||
user := &User{
|
||||
ID: db.Users.c.NextID(),
|
||||
Email: fmt.Sprintf("a.%d@c.com", i),
|
||||
Name: fmt.Sprintf("name.%d", i),
|
||||
ExtID: fmt.Sprintf("EXTID.%d", i),
|
||||
}
|
||||
|
||||
user2, err := db.Users.c.Insert(*user)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(*user, user2) {
|
||||
t.Fatal(*user, user2)
|
||||
}
|
||||
users[user.Email] = user
|
||||
}
|
||||
|
||||
var id uint64
|
||||
var email string
|
||||
for key := range users {
|
||||
email = key
|
||||
id = users[key].ID
|
||||
break
|
||||
}
|
||||
|
||||
err := db.Users.c.Update(id, func(u User) (User, error) {
|
||||
u.Email = "test@x.com"
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
user := users[email]
|
||||
user.Email = "test@x.com"
|
||||
delete(users, email)
|
||||
users[user.Email] = user
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update change id error", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
|
||||
for i := uint64(1); i < 10; i++ {
|
||||
user := &User{
|
||||
ID: db.Users.c.NextID(),
|
||||
Email: fmt.Sprintf("a.%d@c.com", i),
|
||||
Name: fmt.Sprintf("name.%d", i),
|
||||
ExtID: fmt.Sprintf("EXTID.%d", i),
|
||||
}
|
||||
if _, err := db.Users.c.Insert(*user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
users[user.Email] = user
|
||||
}
|
||||
|
||||
var email string
|
||||
for key := range users {
|
||||
email = key
|
||||
break
|
||||
}
|
||||
|
||||
err := db.Users.emailMap.Update(email, func(u User) (User, error) {
|
||||
u.ID++
|
||||
return u, nil
|
||||
})
|
||||
if err != ErrMismatchedIDs {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update function error", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
|
||||
for i := uint64(1); i < 10; i++ {
|
||||
user := &User{
|
||||
ID: db.Users.c.NextID(),
|
||||
Email: fmt.Sprintf("a.%d@c.com", i),
|
||||
Name: fmt.Sprintf("name.%d", i),
|
||||
ExtID: fmt.Sprintf("EXTID.%d", i),
|
||||
}
|
||||
if _, err := db.Users.c.Insert(*user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
users[user.Email] = user
|
||||
}
|
||||
|
||||
var email string
|
||||
for key := range users {
|
||||
email = key
|
||||
break
|
||||
}
|
||||
|
||||
myErr := errors.New("hello")
|
||||
|
||||
err := db.Users.emailMap.Update(email, func(u User) (User, error) {
|
||||
return u, myErr
|
||||
})
|
||||
if err != myErr {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update ErrAbortUpdate", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
|
||||
for i := uint64(1); i < 10; i++ {
|
||||
user := &User{
|
||||
ID: db.Users.c.NextID(),
|
||||
Email: fmt.Sprintf("a.%d@c.com", i),
|
||||
Name: fmt.Sprintf("name.%d", i),
|
||||
ExtID: fmt.Sprintf("EXTID.%d", i),
|
||||
}
|
||||
if _, err := db.Users.c.Insert(*user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
users[user.Email] = user
|
||||
}
|
||||
|
||||
var email string
|
||||
for key := range users {
|
||||
email = key
|
||||
break
|
||||
}
|
||||
|
||||
err := db.Users.emailMap.Update(email, func(u User) (User, error) {
|
||||
return u, ErrAbortUpdate
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update ErrNotFound", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
|
||||
for i := uint64(1); i < 10; i++ {
|
||||
user := &User{
|
||||
ID: db.Users.c.NextID(),
|
||||
Email: fmt.Sprintf("a.%d@c.com", i),
|
||||
Name: fmt.Sprintf("name.%d", i),
|
||||
ExtID: fmt.Sprintf("EXTID.%d", i),
|
||||
}
|
||||
if _, err := db.Users.c.Insert(*user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
users[user.Email] = user
|
||||
}
|
||||
|
||||
var email string
|
||||
for key := range users {
|
||||
email = key
|
||||
break
|
||||
}
|
||||
|
||||
err := db.Users.emailMap.Update(email+"x", func(u User) (User, error) {
|
||||
return u, nil
|
||||
})
|
||||
if err != ErrNotFound {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("insert conflict", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
user := &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "a", ExtID: ""}
|
||||
|
||||
if _, err := db.Users.c.Insert(*user); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
users[user.Email] = user
|
||||
|
||||
user2 := User{
|
||||
ID: db.Users.c.NextID(),
|
||||
Email: "a@b.com",
|
||||
Name: "someone else",
|
||||
ExtID: "123",
|
||||
}
|
||||
|
||||
_, err := db.Users.c.Insert(user2)
|
||||
if !errors.Is(err, ErrDuplicate) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update conflict", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{}
|
||||
|
||||
user1 := &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "a"}
|
||||
user2 := &User{ID: db.Users.c.NextID(), Email: "x@y.com", Name: "x"}
|
||||
|
||||
users[user1.Email] = user1
|
||||
users[user2.Email] = user2
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := db.Users.c.Update(user2.ID, func(u User) (User, error) {
|
||||
u.Email = "a@b.com"
|
||||
return u, nil
|
||||
})
|
||||
if !errors.Is(err, ErrDuplicate) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Users.emailMap.Update("x@y.com", func(u User) (User, error) {
|
||||
u.Email = "a@b.com"
|
||||
return u, nil
|
||||
})
|
||||
if !errors.Is(err, ErrDuplicate) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
}
|
||||
|
||||
func TestPartialMapIndex(t *testing.T) {
|
||||
// Test against the extID map index.
|
||||
run := func(name string, inner func(t *testing.T, db *DB) map[string]*User) {
|
||||
testWithDB(t, name, func(t *testing.T, db *DB) {
|
||||
expected := inner(t, db)
|
||||
|
||||
if err := db.Users.extIDMap.EqualsMap(expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db.Close()
|
||||
db = OpenDB(db.root, true)
|
||||
|
||||
if err := db.Users.extIDMap.EqualsMap(expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
run("insert", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{
|
||||
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "x", ExtID: "x"},
|
||||
}
|
||||
user1 := &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "a"}
|
||||
|
||||
if _, err := db.Users.c.Insert(*users["x"]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := db.Users.c.Insert(*user1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("insert with conflict", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{
|
||||
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "x", ExtID: "x"},
|
||||
}
|
||||
user1 := &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "y", ExtID: "x"}
|
||||
|
||||
if _, err := db.Users.c.Insert(*users["x"]); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := db.Users.c.Insert(*user1); !errors.Is(err, ErrDuplicate) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("insert and delete in index", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{
|
||||
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
|
||||
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
|
||||
"z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc", ExtID: "z"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete from index and from collection.
|
||||
db.Users.extIDMap.Delete("x")
|
||||
db.Users.c.Delete(users["z"].ID)
|
||||
|
||||
delete(users, "x")
|
||||
delete(users, "z")
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("insert and delete outside index", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{
|
||||
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
|
||||
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
|
||||
"z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete from index and from collection.
|
||||
db.Users.extIDMap.Delete("x")
|
||||
db.Users.c.Delete(users["z"].ID)
|
||||
|
||||
delete(users, "x")
|
||||
delete(users, "z")
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update outside index", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{
|
||||
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
|
||||
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
|
||||
"z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := db.Users.c.Update(users["z"].ID, func(u User) (User, error) {
|
||||
u.Name = "Whatever"
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
delete(users, "z") // No ExtID => not in index.
|
||||
return users
|
||||
})
|
||||
|
||||
run("update into index", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{
|
||||
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
|
||||
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
|
||||
"z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := db.Users.c.Update(users["z"].ID, func(u User) (User, error) {
|
||||
u.ExtID = "z"
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
users["z"].ExtID = "z"
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update out of index", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{
|
||||
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
|
||||
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
|
||||
"z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := db.Users.extIDMap.Update("y", func(u User) (User, error) {
|
||||
u.ExtID = ""
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = db.Users.c.Update(users["x"].ID, func(u User) (User, error) {
|
||||
u.ExtID = ""
|
||||
return u, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
delete(users, "x")
|
||||
delete(users, "z")
|
||||
delete(users, "y")
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update function error", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{
|
||||
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
|
||||
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
myErr := errors.New("hello")
|
||||
|
||||
err := db.Users.extIDMap.Update("y", func(u User) (User, error) {
|
||||
u.Email = "blah"
|
||||
return u, myErr
|
||||
})
|
||||
if err != myErr {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update ErrAbortUpdate", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{
|
||||
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
|
||||
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := db.Users.extIDMap.Update("y", func(u User) (User, error) {
|
||||
u.Email = "blah"
|
||||
return u, ErrAbortUpdate
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update ErrNotFound", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{
|
||||
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
|
||||
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := db.Users.extIDMap.Update("z", func(u User) (User, error) {
|
||||
u.Name = "blah"
|
||||
return u, nil
|
||||
})
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update conflict in to in", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{
|
||||
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
|
||||
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
|
||||
"z": {ID: db.Users.c.NextID(), Email: "z@z.com", Name: "zz", ExtID: "z"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := db.Users.extIDMap.Update("z", func(u User) (User, error) {
|
||||
u.ExtID = "x"
|
||||
return u, nil
|
||||
})
|
||||
if !errors.Is(err, ErrDuplicate) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return users
|
||||
})
|
||||
|
||||
run("update conflict out to in", func(t *testing.T, db *DB) map[string]*User {
|
||||
users := map[string]*User{
|
||||
"x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"},
|
||||
"y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"},
|
||||
"z": {ID: db.Users.c.NextID(), Email: "z@z.com", Name: "zz"},
|
||||
}
|
||||
|
||||
for _, u := range users {
|
||||
if _, err := db.Users.c.Insert(*u); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
err := db.Users.c.Update(users["z"].ID, func(u User) (User, error) {
|
||||
u.ExtID = "x"
|
||||
return u, nil
|
||||
})
|
||||
if !errors.Is(err, ErrDuplicate) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
delete(users, "z")
|
||||
|
||||
return users
|
||||
})
|
||||
}
|
||||
|
||||
func TestMapIndex_load_ErrDuplicate(t *testing.T) {
|
||||
testWithDB(t, "", func(t *testing.T, db *DB) {
|
||||
idx := NewMapIndex(
|
||||
db.Users.c,
|
||||
"broken",
|
||||
func(u *User) string { return u.Name },
|
||||
nil)
|
||||
|
||||
users := map[uint64]*User{
|
||||
1: {ID: 1, Email: "x@y.com", Name: "aa", ExtID: "x"},
|
||||
2: {ID: 2, Email: "b@c.com", Name: "aa", ExtID: "y"},
|
||||
}
|
||||
|
||||
if err := idx.load(users); err != ErrDuplicate {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,186 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"os"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.crumpington.com/public/mdb/testconn"
|
||||
)
|
||||
|
||||
func TestShipping(t *testing.T) {
|
||||
run := func(name string, inner func(t *testing.T, db1 *DB, db2 *DB, network *testconn.Network)) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
root1, err := os.MkdirTemp("", "")
|
||||
must(err)
|
||||
defer os.RemoveAll(root1)
|
||||
root2, err := os.MkdirTemp("", "")
|
||||
must(err)
|
||||
defer os.RemoveAll(root2)
|
||||
|
||||
db1 := OpenDB(root1, true)
|
||||
defer db1.Close()
|
||||
db2 := OpenDB(root2, false)
|
||||
defer db2.Close()
|
||||
|
||||
inner(t, db1, db2, testconn.NewNetwork())
|
||||
})
|
||||
}
|
||||
|
||||
run("simple", func(t *testing.T, db, db2 *DB, network *testconn.Network) {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
|
||||
// Send in background.
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := network.Accept()
|
||||
db.SyncSend(conn)
|
||||
}()
|
||||
|
||||
// Recv in background.
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := network.Dial()
|
||||
db2.SyncRecv(conn)
|
||||
}()
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
db.RandAction()
|
||||
}
|
||||
|
||||
db.WaitForSync(db2)
|
||||
network.CloseClient()
|
||||
wg.Wait()
|
||||
|
||||
if err := db.Equals(db2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
run("simple multiple writers", func(t *testing.T, db, db2 *DB, network *testconn.Network) {
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
// Send in background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := network.Accept()
|
||||
db.SyncSend(conn)
|
||||
}()
|
||||
|
||||
// Recv in background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := network.Dial()
|
||||
db2.SyncRecv(conn)
|
||||
}()
|
||||
|
||||
updateWG := sync.WaitGroup{}
|
||||
updateWG.Add(64)
|
||||
for i := 0; i < 64; i++ {
|
||||
go func() {
|
||||
defer updateWG.Done()
|
||||
for j := 0; j < 1024; j++ {
|
||||
db.RandAction()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
updateWG.Wait()
|
||||
|
||||
db.WaitForSync(db2)
|
||||
network.CloseClient()
|
||||
wg.Wait()
|
||||
|
||||
if err := db.Equals(db2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
|
||||
run("unstable network", func(t *testing.T, db, db2 *DB, network *testconn.Network) {
|
||||
sleepTimeout := time.Millisecond
|
||||
|
||||
updateWG := sync.WaitGroup{}
|
||||
updateWG.Add(64)
|
||||
for i := 0; i < 64; i++ {
|
||||
go func() {
|
||||
defer updateWG.Done()
|
||||
for j := 0; j < 4096; j++ {
|
||||
time.Sleep(sleepTimeout)
|
||||
db.RandAction()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
updating := &atomicBool{}
|
||||
updating.Set(true)
|
||||
|
||||
go func() {
|
||||
updateWG.Wait()
|
||||
updating.Set(false)
|
||||
}()
|
||||
|
||||
// Recv in background.
|
||||
recving := &atomicBool{}
|
||||
recving.Set(true)
|
||||
go func() {
|
||||
for {
|
||||
// Stop when no longer updating and WAL files match.
|
||||
if !updating.Get() {
|
||||
if db.MaxSeqNum() == db2.MaxSeqNum() {
|
||||
recving.Set(false)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if conn := network.Dial(); conn != nil {
|
||||
db2.SyncRecv(conn)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Send in background.
|
||||
sending := &atomicBool{}
|
||||
sending.Set(true)
|
||||
go func() {
|
||||
for {
|
||||
// Stop when no longer updating and WAL files match.
|
||||
if !updating.Get() {
|
||||
if db.MaxSeqNum() == db2.MaxSeqNum() {
|
||||
sending.Set(false)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if conn := network.Accept(); conn != nil {
|
||||
db.SyncSend(conn)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Interrupt network periodically as long as sending or receiving.
|
||||
for sending.Get() || recving.Get() {
|
||||
time.Sleep(time.Duration(rand.Intn(10 * int(sleepTimeout))))
|
||||
if rand.Float64() < 0.5 {
|
||||
network.CloseClient()
|
||||
} else {
|
||||
network.CloseServer()
|
||||
}
|
||||
}
|
||||
|
||||
if err := db.Equals(db2); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
|
@ -0,0 +1,87 @@
|
|||
package testconn
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Network struct {
|
||||
lock sync.Mutex
|
||||
// Current connections.
|
||||
cConn net.Conn
|
||||
sConn net.Conn
|
||||
|
||||
acceptQ chan net.Conn
|
||||
}
|
||||
|
||||
func NewNetwork() *Network {
|
||||
return &Network{
|
||||
acceptQ: make(chan net.Conn, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Network) Dial() net.Conn {
|
||||
cc, sc := net.Pipe()
|
||||
func() {
|
||||
n.lock.Lock()
|
||||
defer n.lock.Unlock()
|
||||
if n.cConn != nil {
|
||||
n.cConn.Close()
|
||||
n.cConn = nil
|
||||
}
|
||||
select {
|
||||
case n.acceptQ <- sc:
|
||||
n.cConn = cc
|
||||
default:
|
||||
cc = nil
|
||||
}
|
||||
}()
|
||||
return cc
|
||||
}
|
||||
|
||||
func (n *Network) Accept() net.Conn {
|
||||
var sc net.Conn
|
||||
select {
|
||||
case sc = <-n.acceptQ:
|
||||
case <-time.After(time.Second):
|
||||
return nil
|
||||
}
|
||||
|
||||
func() {
|
||||
n.lock.Lock()
|
||||
defer n.lock.Unlock()
|
||||
if n.sConn != nil {
|
||||
n.sConn.Close()
|
||||
n.sConn = nil
|
||||
}
|
||||
n.sConn = sc
|
||||
}()
|
||||
return sc
|
||||
}
|
||||
|
||||
func (n *Network) CloseClient() {
|
||||
n.lock.Lock()
|
||||
defer n.lock.Unlock()
|
||||
if n.cConn != nil {
|
||||
n.cConn.Close()
|
||||
n.cConn = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Network) CloseServer() {
|
||||
n.lock.Lock()
|
||||
defer n.lock.Unlock()
|
||||
if n.sConn != nil {
|
||||
n.sConn.Close()
|
||||
n.sConn = nil
|
||||
}
|
||||
}
|
|
@ -0,0 +1,272 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/mail"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Validate errors.
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
var ErrInvalidName = errors.New("invalid name")
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// User Collection
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type User struct {
|
||||
ID uint64
|
||||
Email string
|
||||
Name string
|
||||
ExtID string
|
||||
}
|
||||
|
||||
type Users struct {
|
||||
c *Collection[User]
|
||||
emailMap *MapIndex[string, User] // Full map index.
|
||||
emailBTree *BTreeIndex[User] // Full btree index.
|
||||
nameBTree *BTreeIndex[User] // Full btree with duplicates.
|
||||
extIDMap *MapIndex[string, User] // Partial map index.
|
||||
extIDBTree *BTreeIndex[User] // Partial btree index.
|
||||
}
|
||||
|
||||
func userGetID(u *User) uint64 { return u.ID }
|
||||
|
||||
func userSanitize(u *User) {
|
||||
u.Name = strings.TrimSpace(u.Name)
|
||||
e, err := mail.ParseAddress(strings.ToLower(strings.TrimSpace(u.Email)))
|
||||
if err == nil {
|
||||
u.Email = e.Address
|
||||
}
|
||||
}
|
||||
|
||||
func userValidate(u *User) error {
|
||||
if len(u.Name) == 0 {
|
||||
return ErrInvalidName
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Account Collection
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type Account struct {
|
||||
ID uint64
|
||||
Name string
|
||||
}
|
||||
|
||||
type Accounts struct {
|
||||
c *Collection[Account]
|
||||
nameMap *MapIndex[string, Account]
|
||||
}
|
||||
|
||||
func accountGetID(a *Account) uint64 { return a.ID }
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Database
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type DB struct {
|
||||
*Database
|
||||
root string
|
||||
Users Users
|
||||
Accounts Accounts
|
||||
}
|
||||
|
||||
func OpenDB(root string, primary bool) *DB {
|
||||
db := &DB{root: root}
|
||||
if primary {
|
||||
db.Database = NewPrimary(root)
|
||||
} else {
|
||||
db.Database = NewSecondary(root)
|
||||
}
|
||||
|
||||
db.Users = Users{}
|
||||
db.Users.c = NewCollection(db.Database, "users", userGetID)
|
||||
db.Users.c.SetSanitize(userSanitize)
|
||||
db.Users.c.SetValidate(userValidate)
|
||||
|
||||
db.Users.emailMap = NewMapIndex(
|
||||
db.Users.c,
|
||||
"email",
|
||||
func(u *User) string { return u.Email },
|
||||
nil)
|
||||
|
||||
db.Users.emailBTree = NewBTreeIndex(
|
||||
db.Users.c,
|
||||
func(lhs, rhs *User) bool { return lhs.Email < rhs.Email },
|
||||
nil)
|
||||
|
||||
db.Users.nameBTree = NewBTreeIndex(
|
||||
db.Users.c,
|
||||
func(lhs, rhs *User) bool {
|
||||
if lhs.Name != rhs.Name {
|
||||
return lhs.Name < rhs.Name
|
||||
}
|
||||
return lhs.ID < rhs.ID
|
||||
},
|
||||
nil)
|
||||
|
||||
db.Users.extIDMap = NewMapIndex(
|
||||
db.Users.c,
|
||||
"extID",
|
||||
func(u *User) string { return u.ExtID },
|
||||
func(u *User) bool { return u.ExtID != "" })
|
||||
|
||||
db.Users.extIDBTree = NewBTreeIndex(
|
||||
db.Users.c,
|
||||
func(lhs, rhs *User) bool { return lhs.ExtID < rhs.ExtID },
|
||||
func(u *User) bool { return u.ExtID != "" })
|
||||
|
||||
db.Accounts = Accounts{}
|
||||
db.Accounts.c = NewCollection(db.Database, "accounts", accountGetID)
|
||||
|
||||
db.Accounts.nameMap = NewMapIndex(
|
||||
db.Accounts.c,
|
||||
"name",
|
||||
func(a *Account) string { return a.Name },
|
||||
nil)
|
||||
|
||||
db.Start()
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DB) Equals(rhs *DB) error {
|
||||
db.WaitForSync(rhs)
|
||||
|
||||
// Users: itemMap.
|
||||
if err := db.Users.c.items.Equals(rhs.Users.c.items); err != nil {
|
||||
return fmt.Errorf("%w: Users.c.items not equal", err)
|
||||
}
|
||||
|
||||
// Users: emailMap
|
||||
if err := db.Users.emailMap.Equals(rhs.Users.emailMap); err != nil {
|
||||
return fmt.Errorf("%w: Users.emailMap not equal", err)
|
||||
}
|
||||
|
||||
// Users: emailBTree
|
||||
if err := db.Users.emailBTree.Equals(rhs.Users.emailBTree); err != nil {
|
||||
return fmt.Errorf("%w: Users.emailBTree not equal", err)
|
||||
}
|
||||
|
||||
// Users: nameBTree
|
||||
if err := db.Users.nameBTree.Equals(rhs.Users.nameBTree); err != nil {
|
||||
return fmt.Errorf("%w: Users.nameBTree not equal", err)
|
||||
}
|
||||
|
||||
// Users: extIDMap
|
||||
if err := db.Users.extIDMap.Equals(rhs.Users.extIDMap); err != nil {
|
||||
return fmt.Errorf("%w: Users.extIDMap not equal", err)
|
||||
}
|
||||
|
||||
// Users: extIDBTree
|
||||
if err := db.Users.extIDBTree.Equals(rhs.Users.extIDBTree); err != nil {
|
||||
return fmt.Errorf("%w: Users.extIDBTree not equal", err)
|
||||
}
|
||||
|
||||
// Accounts: itemMap
|
||||
if err := db.Accounts.c.items.Equals(rhs.Accounts.c.items); err != nil {
|
||||
return fmt.Errorf("%w: Accounts.c.items not equal", err)
|
||||
}
|
||||
|
||||
// Accounts: nameMap
|
||||
if err := db.Accounts.nameMap.Equals(rhs.Accounts.nameMap); err != nil {
|
||||
return fmt.Errorf("%w: Accounts.nameMap not equal", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Wait for two databases to become synchronized.
|
||||
func (db *DB) WaitForSync(rhs *DB) {
|
||||
for {
|
||||
if db.MaxSeqNum() == rhs.MaxSeqNum() {
|
||||
return
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
randIDs = []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 2, 13, 14, 15, 16}
|
||||
)
|
||||
|
||||
func (db *DB) RandAction() {
|
||||
if rand.Float32() < 0.3 {
|
||||
db.randActionAccount()
|
||||
} else {
|
||||
db.randActionUser()
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) randActionAccount() {
|
||||
id := randIDs[rand.Intn(len(randIDs))]
|
||||
f := rand.Float32()
|
||||
|
||||
_, exists := db.Accounts.c.Get(id)
|
||||
if !exists {
|
||||
db.Accounts.c.Insert(Account{
|
||||
ID: id,
|
||||
Name: randString(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if f < 0.05 {
|
||||
db.Accounts.c.Delete(id)
|
||||
return
|
||||
}
|
||||
db.Accounts.c.Update(id, func(a Account) (Account, error) {
|
||||
a.Name = randString()
|
||||
return a, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (db *DB) randActionUser() {
|
||||
id := randIDs[rand.Intn(len(randIDs))]
|
||||
f := rand.Float32()
|
||||
|
||||
_, exists := db.Users.c.Get(id)
|
||||
if !exists {
|
||||
user := User{
|
||||
ID: id,
|
||||
Email: randStringShort() + "@domain.com",
|
||||
Name: randString(),
|
||||
}
|
||||
if f < 0.1 {
|
||||
user.ExtID = randString()
|
||||
}
|
||||
db.Users.c.Insert(user)
|
||||
return
|
||||
}
|
||||
|
||||
if f < 0.05 {
|
||||
db.Users.c.Delete(id)
|
||||
return
|
||||
}
|
||||
|
||||
db.Users.c.Update(id, func(a User) (User, error) {
|
||||
a.Name = randString()
|
||||
if f < 0.1 {
|
||||
a.ExtID = randString()
|
||||
} else {
|
||||
a.ExtID = ""
|
||||
}
|
||||
a.Email = randStringShort() + "@domain.com"
|
||||
return a, nil
|
||||
})
|
||||
}
|
|
@ -0,0 +1,36 @@
|
|||
package mdb
|
||||
|
||||
/*
|
||||
Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
type WALStatus struct {
|
||||
MaxSeqNumKV uint64
|
||||
MaxSeqNumWAL uint64
|
||||
}
|
||||
|
||||
type itemIndex[T any] interface {
|
||||
load(m map[uint64]*T) error
|
||||
insert(*T)
|
||||
update(old, new *T) // Old and new MUST have the same ID.
|
||||
delete(*T)
|
||||
}
|
||||
|
||||
type itemUniqueIndex[T any] interface {
|
||||
load(m map[uint64]*T) error
|
||||
name() string
|
||||
lock(*T)
|
||||
unlock(*T)
|
||||
insertConflict(*T) bool
|
||||
updateConflict(*T) bool
|
||||
}
|
||||
|
||||
type dbCollection interface {
|
||||
loadData()
|
||||
onStore(uint64, []byte) // For WAL following.
|
||||
onDelete(uint64) // For WAL following.
|
||||
}
|
|
@ -0,0 +1,14 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
func must(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
package mdb
|
||||
|
||||
/*Copyright (c) 2022, John David Lee
|
||||
All rights reserved.
|
||||
|
||||
This source code is licensed under the BSD-style license found in the
|
||||
LICENSE file in the root directory of this source tree.
|
||||
*/
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
mrand "math/rand"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
func randStringShort() string {
|
||||
s := randString()
|
||||
if len(s) > 2 {
|
||||
return s[:2]
|
||||
}
|
||||
return s
|
||||
}
|
||||
|
||||
func randString() string {
|
||||
buf := make([]byte, 1+mrand.Intn(10))
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return hex.EncodeToString(buf)
|
||||
}
|
||||
|
||||
type atomicBool struct {
|
||||
i int64
|
||||
}
|
||||
|
||||
func (a *atomicBool) Get() bool {
|
||||
return atomic.LoadInt64(&a.i) == 1
|
||||
}
|
||||
|
||||
func (a *atomicBool) Set(b bool) {
|
||||
if b {
|
||||
atomic.StoreInt64(&a.i, 1)
|
||||
} else {
|
||||
atomic.StoreInt64(&a.i, 0)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue