diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..cdd9b45 --- /dev/null +++ b/LICENSE @@ -0,0 +1,26 @@ +Copyright 2022 John David Lee + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors +may be used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/README.md b/README.md index 4eb9d3d..73c4b8d 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,3 @@ # mdb +An in-process, in-memory database for Go. diff --git a/btreeindex.go b/btreeindex.go new file mode 100644 index 0000000..eb181fc --- /dev/null +++ b/btreeindex.go @@ -0,0 +1,165 @@ +package mdb + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. 
+*/ + +import ( + "sync" + "sync/atomic" + + "github.com/google/btree" +) + +type BTreeIndex[T any] struct { + c *Collection[T] + modLock sync.Mutex + bt atomic.Value // *btree.BTreeG[*T] + getID func(*T) uint64 + less func(*T, *T) bool + include func(*T) bool +} + +func NewBTreeIndex[T any]( + c *Collection[T], + less func(*T, *T) bool, + include func(*T) bool, +) *BTreeIndex[T] { + + t := &BTreeIndex[T]{ + c: c, + getID: c.getID, + less: less, + include: include, + } + + btree := btree.NewG(32, less) + t.bt.Store(btree) + c.indices = append(c.indices, t) + + return t +} + +func (t *BTreeIndex[T]) load(m map[uint64]*T) error { + btree := btree.NewG(64, t.less) + t.bt.Store(btree) + for _, item := range m { + if t.include == nil || t.include(item) { + if x, _ := btree.ReplaceOrInsert(item); x != nil { + return ErrDuplicate + } + } + } + return nil +} + +// ---------------------------------------------------------------------------- + +func (t *BTreeIndex[T]) Ascend() *BTreeIterator[T] { + iter := newBTreeIterator[T]() + go func() { + t.btree().Ascend(iter.each) + iter.done() + }() + return iter +} + +func (t *BTreeIndex[T]) AscendAfter(pivot T) *BTreeIterator[T] { + iter := newBTreeIterator[T]() + go func() { + t.btree().AscendGreaterOrEqual(&pivot, iter.each) + iter.done() + }() + return iter +} + +func (t *BTreeIndex[T]) Descend() *BTreeIterator[T] { + iter := newBTreeIterator[T]() + go func() { + t.btree().Descend(iter.each) + iter.done() + }() + return iter +} + +func (t *BTreeIndex[T]) DescendAfter(pivot T) *BTreeIterator[T] { + iter := newBTreeIterator[T]() + go func() { + t.btree().DescendLessOrEqual(&pivot, iter.each) + iter.done() + }() + return iter +} + +func (t *BTreeIndex[T]) Get(item T) (T, bool) { + ptr, ok := t.btree().Get(&item) + if !ok { + return item, false + } + return *ptr, true +} + +func (t *BTreeIndex[T]) Min() (item T, ok bool) { + if ptr, ok := t.btree().Min(); ok { + return *ptr, ok + } + return item, false +} + +func (t *BTreeIndex[T]) Max() (item T, ok bool) { + if ptr, ok := t.btree().Max(); ok { + return *ptr, ok + } + return item, false +} + +func (t *BTreeIndex[T]) Len() int { + return t.btree().Len() +} + +// ---------------------------------------------------------------------------- + +func (t *BTreeIndex[T]) insert(item *T) { + if t.include == nil || t.include(item) { + t.modify(func(bt *btree.BTreeG[*T]) { + bt.ReplaceOrInsert(item) + }) + } +} + +func (t *BTreeIndex[T]) update(old, new *T) { + t.modify(func(bt *btree.BTreeG[*T]) { + if t.include == nil || t.include(old) { + bt.Delete(old) + } + if t.include == nil || t.include(new) { + bt.ReplaceOrInsert(new) + } + }) +} + +func (t *BTreeIndex[T]) delete(item *T) { + if t.include == nil || t.include(item) { + t.modify(func(bt *btree.BTreeG[*T]) { + bt.Delete(item) + }) + } +} + +// ---------------------------------------------------------------------------- + +func (t *BTreeIndex[T]) btree() *btree.BTreeG[*T] { + return t.bt.Load().(*btree.BTreeG[*T]) +} + +func (t *BTreeIndex[T]) modify(mod func(clone *btree.BTreeG[*T])) { + t.modLock.Lock() + defer t.modLock.Unlock() + clone := t.btree().Clone() + mod(clone) + t.bt.Store(clone) +} diff --git a/btreeindex_ex_test.go b/btreeindex_ex_test.go new file mode 100644 index 0000000..04d2704 --- /dev/null +++ b/btreeindex_ex_test.go @@ -0,0 +1,144 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. 
+*/ + +import ( + "fmt" + "reflect" +) + +func (bt *BTreeIndex[T]) Equals(rhs *BTreeIndex[T]) error { + if bt.Len() != rhs.Len() { + return fmt.Errorf("Expected %d items, but found %d.", bt.Len(), rhs.Len()) + } + + it1 := bt.Ascend() + defer it1.Close() + + it2 := rhs.Ascend() + defer it2.Close() + + for it1.Next() { + it2.Next() + + v1 := it1.Value() + v2 := it2.Value() + + if !reflect.DeepEqual(v1, v2) { + return fmt.Errorf("Value mismatch: %v != %v", v1, v2) + } + } + + return nil +} + +func (bt *BTreeIndex[T]) EqualsList(data []*T) error { + if bt.Len() != len(data) { + return fmt.Errorf("Expected %d items, but found %d.", bt.Len(), len(data)) + } + + if len(data) == 0 { + return nil + } + + // Ascend fully. + it := bt.Ascend() + for _, v1 := range data { + it.Next() + v2 := it.Value() + + if !reflect.DeepEqual(*v1, v2) { + return fmt.Errorf("Value mismatch: %v != %v", *v1, v2) + } + } + if it.Next() { + return fmt.Errorf("Next returned true after full ascend.") + } + it.Close() + + // Descend fully. + it = bt.Descend() + dataList := data + for len(dataList) > 0 { + v1 := dataList[len(dataList)-1] + dataList = dataList[:len(dataList)-1] + it.Next() + v2 := it.Value() + + if !reflect.DeepEqual(*v1, v2) { + return fmt.Errorf("Value mismatch: %v != %v", *v1, v2) + } + } + if it.Next() { + return fmt.Errorf("Next returned true after full descend.") + } + it.Close() + + // AscendAfter + dataList = data + for len(dataList) > 1 { + dataList = dataList[1:] + it = bt.AscendAfter(*dataList[0]) + + for _, v1 := range dataList { + it.Next() + v2 := it.Value() + if !reflect.DeepEqual(*v1, v2) { + return fmt.Errorf("Value mismatch: %v != %v", *v1, v2) + } + } + if it.Next() { + return fmt.Errorf("Next returned true after partial ascend.") + } + it.Close() + } + + // DescendAfter + dataList = data + for len(dataList) > 1 { + dataList = dataList[:len(dataList)-1] + it = bt.DescendAfter(*dataList[len(dataList)-1]) + + for i := len(dataList) - 1; i >= 0; i-- { + v1 := dataList[i] + it.Next() + v2 := it.Value() + if !reflect.DeepEqual(*v1, v2) { + return fmt.Errorf("Value mismatch: %v != %v", *v1, v2) + } + } + if it.Next() { + return fmt.Errorf("Next returned true after partial descend: %#v", it.Value()) + } + it.Close() + } + + // Using Get. + for _, v1 := range data { + v2, ok := bt.Get(*v1) + if !ok || !reflect.DeepEqual(*v1, v2) { + return fmt.Errorf("Value mismatch: %v != %v", *v1, v2) + } + } + + // Min. + v1 := data[0] + v2, ok := bt.Min() + if !ok || !reflect.DeepEqual(*v1, v2) { + return fmt.Errorf("Value mismatch: %v != %v", *v1, v2) + } + + // Max. + v1 = data[len(data)-1] + v2, ok = bt.Max() + if !ok || !reflect.DeepEqual(*v1, v2) { + return fmt.Errorf("Value mismatch: %v != %v", *v1, v2) + } + + return nil +} diff --git a/btreeindex_test.go b/btreeindex_test.go new file mode 100644 index 0000000..45c9038 --- /dev/null +++ b/btreeindex_test.go @@ -0,0 +1,318 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "errors" + "reflect" + "testing" +) + +func TestFullBTreeIndex(t *testing.T) { + + // Test against the email index. 
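+	// The run helper below executes each case, compares the email index against
+	// the expected slice, then closes and reopens the database and compares
+	// again to make sure the index is rebuilt correctly from the persisted data.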
+ run := func(name string, inner func(t *testing.T, db *DB) []*User) { + testWithDB(t, name, func(t *testing.T, db *DB) { + expected := inner(t, db) + + if err := db.Users.emailBTree.EqualsList(expected); err != nil { + t.Fatal(err) + } + + db.Close() + db = OpenDB(db.root, true) + + if err := db.Users.emailBTree.EqualsList(expected); err != nil { + t.Fatal(err) + } + }) + } + + run("insert", func(t *testing.T, db *DB) (users []*User) { + users = append(users, + &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"}, + &User{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"}) + + for _, u := range users { + u2, err := db.Users.c.Insert(*u) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(u2, *u) { + t.Fatal(u2, *u) + } + } + + return users + }) + + run("insert duplicate", func(t *testing.T, db *DB) (users []*User) { + users = append(users, + &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"}, + &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "ccc"}) + + u1, err := db.Users.c.Insert(*users[0]) + if err != nil { + t.Fatal(err) + } + u2, err := db.Users.c.Insert(*users[1]) + if !errors.Is(err, ErrDuplicate) { + t.Fatal(u1, u2, err) + } + + return users[:1] + }) + + run("update", func(t *testing.T, db *DB) (users []*User) { + users = append(users, + &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"}, + &User{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"}, + &User{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"}) + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + err := db.Users.c.Update(users[2].ID, func(u User) (User, error) { + u.Email = "g@h.com" + return u, nil + }) + if err != nil { + t.Fatal(err) + } + users[2].Email = "g@h.com" + + return users + }) + + run("delete", func(t *testing.T, db *DB) (users []*User) { + users = append(users, + &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"}, + &User{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"}, + &User{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"}) + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + db.Users.c.Delete(users[0].ID) + users = users[1:] + + return users + }) + + run("get not found", func(t *testing.T, db *DB) (users []*User) { + users = append(users, + &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"}, + &User{ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"}, + &User{ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"}) + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + if u, ok := db.Users.emailBTree.Get(User{Email: "g@h.com"}); ok { + t.Fatal(u, ok) + } + + return users + }) + + run("min/max empty", func(t *testing.T, db *DB) (users []*User) { + + if u, ok := db.Users.emailBTree.Min(); ok { + t.Fatal(u, ok) + } + if u, ok := db.Users.emailBTree.Max(); ok { + t.Fatal(u, ok) + } + + return users + }) +} + +func TestPartialBTreeIndex(t *testing.T) { + + // Test against the extID btree index. 
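+	// The extID index is partial: as the cases below show, only users with a
+	// non-empty ExtID appear in it, so each case returns just the users that are
+	// expected to be indexed.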
+ run := func(name string, inner func(t *testing.T, db *DB) []*User) { + testWithDB(t, name, func(t *testing.T, db *DB) { + expected := inner(t, db) + + if err := db.Users.extIDBTree.EqualsList(expected); err != nil { + t.Fatal(err) + } + + db.Close() + db = OpenDB(db.root, true) + + if err := db.Users.extIDBTree.EqualsList(expected); err != nil { + t.Fatal(err) + } + }) + } + + run("insert out", func(t *testing.T, db *DB) []*User { + users := []*User{ + {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "xxx"}, + {ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ggg"}, + {ID: db.Users.c.NextID(), Email: "e@f.com", Name: "aaa"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + return []*User{} + }) + + run("insert in", func(t *testing.T, db *DB) []*User { + users := []*User{ + {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "xxx"}, + {ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ggg", ExtID: "x"}, + {ID: db.Users.c.NextID(), Email: "e@f.com", Name: "aaa"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + return []*User{users[1]} + }) + + run("update out to out", func(t *testing.T, db *DB) []*User { + users := []*User{ + {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"}, + {ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc", ExtID: "A"}, + {ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"}, + {ID: db.Users.c.NextID(), Email: "g@h.com", Name: "ggg", ExtID: "B"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + err := db.Users.c.Update(users[0].ID, func(u User) (User, error) { + u.Name = "axa" + users[0].Name = "axa" + return u, nil + }) + if err != nil { + t.Fatal(err) + } + + return []*User{users[1], users[3]} + }) + + run("update in to in", func(t *testing.T, db *DB) []*User { + users := []*User{ + {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"}, + {ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc", ExtID: "A"}, + {ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"}, + {ID: db.Users.c.NextID(), Email: "g@h.com", Name: "ggg", ExtID: "B"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + err := db.Users.c.Update(users[1].ID, func(u User) (User, error) { + u.ExtID = "C" + users[1].ExtID = "C" + return u, nil + }) + if err != nil { + t.Fatal(err) + } + + return []*User{users[3], users[1]} + }) + + run("update out to in", func(t *testing.T, db *DB) []*User { + users := []*User{ + {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"}, + {ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc", ExtID: "A"}, + {ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"}, + {ID: db.Users.c.NextID(), Email: "g@h.com", Name: "ggg", ExtID: "B"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + err := db.Users.c.Update(users[2].ID, func(u User) (User, error) { + u.ExtID = "C" + users[2].ExtID = "C" + return u, nil + }) + if err != nil { + t.Fatal(err) + } + + return []*User{users[1], users[3], users[2]} + }) + + run("update in to out", func(t *testing.T, db *DB) []*User { + users := []*User{ + {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"}, + {ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc", ExtID: "A"}, + {ID: db.Users.c.NextID(), Email: "e@f.com", Name: "eee"}, + {ID: db.Users.c.NextID(), Email: "g@h.com", Name: "ggg", ExtID: "B"}, + } 
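+		// users[1] currently has ExtID "A" and sits in the partial index; clearing
+		// its ExtID below must remove it, leaving only users[3] in the expected list.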
+ + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + err := db.Users.c.Update(users[1].ID, func(u User) (User, error) { + u.ExtID = "" + users[1].ExtID = "" + return u, nil + }) + if err != nil { + t.Fatal(err) + } + + return []*User{users[3]} + }) +} + +func TestBTreeIndex_load_ErrDuplicate(t *testing.T) { + testWithDB(t, "", func(t *testing.T, db *DB) { + idx := NewBTreeIndex( + db.Users.c, + func(lhs, rhs *User) bool { return lhs.ExtID < rhs.ExtID }, + nil) + + users := map[uint64]*User{ + 1: {ID: 1, Email: "x@y.com", Name: "xx", ExtID: "x"}, + 2: {ID: 2, Email: "a@b.com", Name: "aa", ExtID: "x"}, + } + + if err := idx.load(users); err != ErrDuplicate { + t.Fatal(err) + } + }) +} diff --git a/btreeiterator.go b/btreeiterator.go new file mode 100644 index 0000000..70d0e72 --- /dev/null +++ b/btreeiterator.go @@ -0,0 +1,51 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +type BTreeIterator[T any] struct { + out chan *T + close chan struct{} + current *T +} + +func newBTreeIterator[T any]() *BTreeIterator[T] { + return &BTreeIterator[T]{ + out: make(chan *T), + close: make(chan struct{}, 1), + } +} + +func (iter *BTreeIterator[T]) each(i *T) bool { + select { + case iter.out <- i: + return true + case <-iter.close: + return false + } +} + +func (iter *BTreeIterator[T]) done() { + close(iter.out) +} + +func (iter *BTreeIterator[T]) Next() bool { + val, ok := <-iter.out + if ok { + iter.current = val + return true + } + return false +} + +func (iter *BTreeIterator[T]) Close() { + iter.close <- struct{}{} +} + +func (iter *BTreeIterator[T]) Value() T { + return *iter.current +} diff --git a/btreeiterator_test.go b/btreeiterator_test.go new file mode 100644 index 0000000..7cc7a00 --- /dev/null +++ b/btreeiterator_test.go @@ -0,0 +1,49 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "reflect" + "testing" +) + +func TestBTreeIterator(t *testing.T) { + testWithDB(t, "min/max matches iterator", func(t *testing.T, db *DB) { + users := []User{ + {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "aaa"}, + {ID: db.Users.c.NextID(), Email: "e@f.com", Name: "bbb"}, + {ID: db.Users.c.NextID(), Email: "c@d.com", Name: "ccc"}, + } + for _, u := range users { + if _, err := db.Users.c.Insert(u); err != nil { + t.Fatal(err) + } + } + + it := db.Users.nameBTree.Ascend() + if !it.Next() { + t.Fatal("No next") + } + u := it.Value() + if !reflect.DeepEqual(u, users[0]) { + t.Fatal(u, users[0]) + } + it.Close() + + it = db.Users.nameBTree.Descend() + if !it.Next() { + t.Fatal("No next") + } + u = it.Value() + if !reflect.DeepEqual(u, users[2]) { + t.Fatal(u, users[2]) + } + it.Close() + + }) +} diff --git a/codec.go b/codec.go new file mode 100644 index 0000000..42d0ace --- /dev/null +++ b/codec.go @@ -0,0 +1,27 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. 
+*/ + +import ( + "bytes" + "encoding/json" + + "git.crumpington.com/public/mdb/kvstore" +) + +func decode[T any](data []byte) *T { + item := new(T) + must(json.Unmarshal(data, item)) + return item +} + +func encode(item any) []byte { + w := bytes.NewBuffer(kvstore.GetDataBuf(0)[:0]) + must(json.NewEncoder(w).Encode(item)) + return w.Bytes() +} diff --git a/collection.go b/collection.go new file mode 100644 index 0000000..f56317f --- /dev/null +++ b/collection.go @@ -0,0 +1,212 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "fmt" + + "git.crumpington.com/public/mdb/keyedmutex" +) + +type Collection[T any] struct { + primary bool + db *Database + name string + idLock keyedmutex.KeyedMutex[uint64] + items *itemMap[T] + indices []itemIndex[T] + uniqueIndices []itemUniqueIndex[T] + getID func(*T) uint64 + sanitize func(*T) + validate func(*T) error +} + +func NewCollection[T any](db *Database, name string, getID func(*T) uint64) *Collection[T] { + items := newItemMap(db.kv, name, getID) + c := &Collection[T]{ + primary: db.kv.Primary(), + db: db, + name: name, + idLock: keyedmutex.New[uint64](), + items: items, + indices: []itemIndex[T]{items}, + uniqueIndices: []itemUniqueIndex[T]{}, + getID: items.getID, + sanitize: func(*T) {}, + validate: func(*T) error { return nil }, + } + + db.collections[name] = c + return c +} + +func (c *Collection[T]) SetSanitize(sanitize func(*T)) { + c.sanitize = sanitize +} + +func (c *Collection[T]) SetValidate(validate func(*T) error) { + c.validate = validate +} + +// ---------------------------------------------------------------------------- + +func (c *Collection[T]) NextID() uint64 { + return c.items.nextID() +} + +// ---------------------------------------------------------------------------- + +func (c *Collection[T]) Insert(item T) (T, error) { + if !c.primary { + return item, ErrReadOnly + } + + c.sanitize(&item) + + if err := c.validate(&item); err != nil { + return item, err + } + + id := c.getID(&item) + c.idLock.Lock(id) + defer c.idLock.Unlock(id) + + if _, ok := c.items.Get(id); ok { + return item, fmt.Errorf("%w: ID", ErrDuplicate) + } + + // Acquire locks and check for insert conflicts. + for _, idx := range c.uniqueIndices { + idx.lock(&item) + defer idx.unlock(&item) + + if idx.insertConflict(&item) { + return item, fmt.Errorf("%w: %s", ErrDuplicate, idx.name()) + } + } + + for _, idx := range c.indices { + idx.insert(&item) + } + + return item, nil +} + +func (c *Collection[T]) Update(id uint64, update func(T) (T, error)) error { + if !c.primary { + return ErrReadOnly + } + + c.idLock.Lock(id) + defer c.idLock.Unlock(id) + + old, ok := c.items.Get(id) + if !ok { + return ErrNotFound + } + + newItem, err := update(*old) + if err != nil { + if err == ErrAbortUpdate { + return nil + } + return err + } + + new := &newItem + + if c.getID(new) != id { + return ErrMismatchedIDs + } + + c.sanitize(new) + + if err := c.validate(new); err != nil { + return err + } + + // Acquire locks and check for update conflicts. 
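+	// Every unique index is locked on the new value and checked for a conflicting
+	// entry before any index is modified; the deferred unlocks release the locks
+	// when Update returns.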
+ for _, idx := range c.uniqueIndices { + idx.lock(new) + defer idx.unlock(new) + + if idx.updateConflict(new) { + return fmt.Errorf("%w: %s", ErrDuplicate, idx.name()) + } + } + + for _, idx := range c.indices { + idx.update(old, new) + } + + return nil +} + +func (c Collection[T]) Delete(id uint64) { + if !c.primary { + panic(ErrReadOnly) + } + + c.idLock.Lock(id) + defer c.idLock.Unlock(id) + + item, ok := c.items.Get(id) + if !ok { + return + } + + // Acquire locks and check for insert conflicts. + for _, idx := range c.uniqueIndices { + idx.lock(item) + defer idx.unlock(item) + } + + for _, idx := range c.indices { + idx.delete(item) + } +} + +func (c Collection[T]) Get(id uint64) (t T, ok bool) { + ptr, ok := c.items.Get(id) + if !ok { + return t, false + } + return *ptr, true +} + +// ---------------------------------------------------------------------------- + +func (c *Collection[T]) loadData() { + for _, idx := range c.indices { + must(idx.load(c.items.m)) + } +} + +func (c *Collection[T]) onStore(id uint64, data []byte) { + item := decode[T](data) + old, ok := c.items.Get(id) + if !ok { + // Insert. + for _, idx := range c.indices { + idx.insert(item) + } + } else { + // Otherwise update. + for _, idx := range c.indices { + idx.update(old, item) + } + } +} + +func (c *Collection[T]) onDelete(id uint64) { + if item, ok := c.items.Get(id); ok { + for _, idx := range c.indices { + idx.delete(item) + } + } +} diff --git a/collection_test.go b/collection_test.go new file mode 100644 index 0000000..ac0c013 --- /dev/null +++ b/collection_test.go @@ -0,0 +1,153 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. 
+*/ + +import ( + "errors" + "reflect" + "testing" +) + +func TestCollection(t *testing.T) { + testWithDB(t, "insert on secondary", func(t *testing.T, db *DB) { + c := db.Users.c + c.primary = false + if _, err := c.Insert(User{}); !errors.Is(err, ErrReadOnly) { + t.Fatal(err) + } + }) + + testWithDB(t, "insert validation error", func(t *testing.T, db *DB) { + c := db.Users.c + user := User{ + ID: c.NextID(), + Email: "a@b.com", + } + if _, err := c.Insert(user); !errors.Is(err, ErrInvalidName) { + t.Fatal(err) + } + }) + + testWithDB(t, "insert duplicate", func(t *testing.T, db *DB) { + c := db.Users.c + user := User{ + ID: c.NextID(), + Name: "adsf", + Email: "a@b.com", + } + if _, err := c.Insert(user); err != nil { + t.Fatal(err) + } + if _, err := c.Insert(user); !errors.Is(err, ErrDuplicate) { + t.Fatal(err) + } + }) + + testWithDB(t, "update on secondary", func(t *testing.T, db *DB) { + c := db.Users.c + user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"} + if _, err := c.Insert(user); err != nil { + t.Fatal(err) + } + + c.primary = false + + err := c.Update(user.ID, func(u User) (User, error) { + u.Name = "xxx" + return u, nil + }) + if !errors.Is(err, ErrReadOnly) { + t.Fatal(err) + } + }) + + testWithDB(t, "update not found", func(t *testing.T, db *DB) { + c := db.Users.c + user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"} + if _, err := c.Insert(user); err != nil { + t.Fatal(err) + } + + err := c.Update(user.ID+1, func(u User) (User, error) { + u.Name = "xxx" + return u, nil + }) + if !errors.Is(err, ErrNotFound) { + t.Fatal(err) + } + }) + + testWithDB(t, "update failed validation", func(t *testing.T, db *DB) { + c := db.Users.c + user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"} + if _, err := c.Insert(user); err != nil { + t.Fatal(err) + } + + err := c.Update(user.ID, func(u User) (User, error) { + u.Name = "" + return u, nil + }) + if !errors.Is(err, ErrInvalidName) { + t.Fatal(err) + } + }) + + testWithDB(t, "delete on secondary", func(t *testing.T, db *DB) { + defer func() { + if err := recover(); err == nil { + t.Fatal("No panic") + } + }() + + c := db.Users.c + user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"} + if _, err := c.Insert(user); err != nil { + t.Fatal(err) + } + + c.primary = false + + c.Delete(1) + }) + + testWithDB(t, "delete not found", func(t *testing.T, db *DB) { + c := db.Users.c + user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"} + if _, err := c.Insert(user); err != nil { + t.Fatal(err) + } + c.Delete(user.ID + 1) // Does nothing. + }) + + testWithDB(t, "get", func(t *testing.T, db *DB) { + c := db.Users.c + user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"} + if _, err := c.Insert(user); err != nil { + t.Fatal(err) + } + + u2, ok := c.Get(user.ID) + if !ok || !reflect.DeepEqual(user, u2) { + t.Fatal(ok, u2, user) + } + }) + + testWithDB(t, "get not found", func(t *testing.T, db *DB) { + c := db.Users.c + user := User{ID: c.NextID(), Name: "adsf", Email: "a@b.com"} + if _, err := c.Insert(user); err != nil { + t.Fatal(err) + } + + u2, ok := c.Get(user.ID - 1) + if ok { + t.Fatal(ok, u2) + } + }) +} diff --git a/database.go b/database.go new file mode 100644 index 0000000..7e81c38 --- /dev/null +++ b/database.go @@ -0,0 +1,123 @@ +package mdb + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. 
+*/ + +import ( + "net" + "os" + "path/filepath" + "sync" + + "git.crumpington.com/public/mdb/kvstore" + "golang.org/x/sys/unix" +) + +type LogInfo = kvstore.LogInfo + +type Database struct { + root string + lock *os.File + kv *kvstore.KV + + collections map[string]dbCollection +} + +func NewPrimary(root string) *Database { + return newDB(root, true) +} + +func NewSecondary(root string) *Database { + return newDB(root, false) +} + +func newDB(root string, primary bool) *Database { + must(os.MkdirAll(root, 0700)) + + lockPath := filepath.Join(root, "lock") + + // Acquire the lock. + lock, err := os.OpenFile(lockPath, os.O_RDWR|os.O_CREATE, 0600) + must(err) + must(unix.Flock(int(lock.Fd()), unix.LOCK_EX)) + + db := &Database{ + root: root, + collections: map[string]dbCollection{}, + lock: lock, + } + if primary { + db.kv = kvstore.NewPrimary(root) + } else { + db.kv = kvstore.NewSecondary(root, db.onStore, db.onDelete) + } + return db +} + +func (db *Database) Start() { + wg := sync.WaitGroup{} + for _, c := range db.collections { + wg.Add(1) + go func(c dbCollection) { + defer wg.Done() + c.loadData() + }(c) + } + wg.Wait() +} + +func (db *Database) MaxSeqNum() uint64 { + return db.kv.MaxSeqNum() +} + +func (db *Database) LogInfo() LogInfo { + return db.kv.LogInfo() +} + +func (db *Database) Close() { + if db.kv != nil { + db.kv.Close() + db.kv = nil + } + if db.lock != nil { + db.lock.Close() + db.lock = nil + } +} + +// ---------------------------------------------------------------------------- + +func (db *Database) onStore(collection string, id uint64, data []byte) { + if c, ok := db.collections[collection]; ok { + c.onStore(id, data) + } +} + +func (db *Database) onDelete(collection string, id uint64) { + if c, ok := db.collections[collection]; ok { + c.onDelete(id) + } +} + +// ---------------------------------------------------------------------------- + +func (db *Database) SyncSend(conn net.Conn) { + db.kv.SyncSend(conn) +} + +func (db *Database) SyncRecv(conn net.Conn) { + db.kv.SyncRecv(conn) +} + +// ---------------------------------------------------------------------------- + +// CleanBefore deletes log entries that are more than the given number of +// seconds old. +func (db *Database) CleanBefore(seconds int64) { + db.kv.CleanBefore(seconds) +} diff --git a/database_ex_test.go b/database_ex_test.go new file mode 100644 index 0000000..e9a28ba --- /dev/null +++ b/database_ex_test.go @@ -0,0 +1 @@ +package mdb diff --git a/database_test.go b/database_test.go new file mode 100644 index 0000000..86522bd --- /dev/null +++ b/database_test.go @@ -0,0 +1,45 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. 
+*/ + +import ( + "sync" + "testing" + "time" +) + +func TestDatabase(t *testing.T) { + testWithDB(t, "multiple writers", func(t *testing.T, db *DB) { + wg := sync.WaitGroup{} + N := 64 + wg.Add(64) + for i := 0; i < N; i++ { + go func() { + defer wg.Done() + for j := 0; j < 1024; j++ { + db.RandAction() + } + }() + } + wg.Wait() + }) + + testWithDB(t, "clean before", func(t *testing.T, db *DB) { + db.RandAction() + db.RandAction() + db.RandAction() + + time.Sleep(2 * time.Second) + db.CleanBefore(0) + + if db.MaxSeqNum() == 0 { + t.Fatal(db.MaxSeqNum()) + } + }) + +} diff --git a/dep-graph.sh b/dep-graph.sh new file mode 100755 index 0000000..6d088f3 --- /dev/null +++ b/dep-graph.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +godepgraph -s . > .deps.dot && xdot .deps.dot + +rm .deps.dot diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..920023b --- /dev/null +++ b/errors.go @@ -0,0 +1,21 @@ +package mdb + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import "errors" + +var ( + ErrMismatchedIDs = errors.New("MismatchedIDs") + ErrDuplicate = errors.New("Duplicate") + ErrNotFound = errors.New("NotFound") + ErrReadOnly = errors.New("ReadOnly") + + // Return in update function to abort changes. + ErrAbortUpdate = errors.New("AbortUpdate") +) diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..2d3f4d5 --- /dev/null +++ b/go.mod @@ -0,0 +1,9 @@ +module git.crumpington.com/public/mdb + +go 1.18 + +require ( + github.com/google/btree v1.1.2 + github.com/mattn/go-sqlite3 v1.14.14 + golang.org/x/sys v0.0.0-20220731174439-a90be440212d +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..0862629 --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= +github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +github.com/mattn/go-sqlite3 v1.14.14 h1:qZgc/Rwetq+MtyE18WhzjokPD93dNqLGNT3QJuLvBGw= +github.com/mattn/go-sqlite3 v1.14.14/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d h1:Sv5ogFZatcgIMMtBSTTAgMYsicp25MXBubjXNDKwm80= +golang.org/x/sys v0.0.0-20220731174439-a90be440212d/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/itemmap.go b/itemmap.go new file mode 100644 index 0000000..863680a --- /dev/null +++ b/itemmap.go @@ -0,0 +1,114 @@ +package mdb + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "math/rand" + + "sync" + "sync/atomic" + + "git.crumpington.com/public/mdb/keyedmutex" + "git.crumpington.com/public/mdb/kvstore" +) + +// Implements ItemMap and ItemUniqueIndex interfaces. 
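+// itemMap holds the canonical in-memory map of items keyed by ID. On the
+// primary, insert/update/delete also write through to the underlying kvstore;
+// nextID hands out IDs by advancing maxID with a random step of 1–256.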
+type itemMap[T any] struct { + primary bool + kv *kvstore.KV + collection string + idLock keyedmutex.KeyedMutex[uint64] + mapLock sync.Mutex + m map[uint64]*T + getID func(*T) uint64 + maxID uint64 +} + +func newItemMap[T any](kv *kvstore.KV, collection string, getID func(*T) uint64) *itemMap[T] { + m := &itemMap[T]{ + primary: kv.Primary(), + kv: kv, + collection: collection, + idLock: keyedmutex.New[uint64](), + m: map[uint64]*T{}, + getID: getID, + } + + kv.Iterate(collection, func(id uint64, data []byte) { + item := decode[T](data) + if id > m.maxID { + m.maxID = id + } + m.m[id] = item + }) + + return m +} + +func (m *itemMap[T]) load(src map[uint64]*T) error { + // No-op: The itemmap is the source for loading all other indices. + return nil +} + +// ---------------------------------------------------------------------------- + +func (m *itemMap[T]) Get(id uint64) (*T, bool) { + return m.mapGet(id) +} + +// Should hold item lock when calling. +func (idx *itemMap[T]) insert(item *T) { + id := idx.getID(item) + if idx.primary { + idx.kv.Store(idx.collection, id, encode(item)) + } + idx.mapSet(id, item) +} + +// Should hold item lock when calling. old and new MUST have the same ID. +func (idx *itemMap[T]) update(old, new *T) { + idx.insert(new) +} + +// Should hold item lock when calling. +func (idx *itemMap[T]) delete(item *T) { + id := idx.getID(item) + if idx.primary { + idx.kv.Delete(idx.collection, id) + } + idx.mapDelete(id) +} + +// ---------------------------------------------------------------------------- + +func (idx *itemMap[T]) mapSet(id uint64, item *T) { + idx.mapLock.Lock() + idx.m[id] = item + idx.mapLock.Unlock() +} + +func (idx *itemMap[T]) mapDelete(id uint64) { + idx.mapLock.Lock() + delete(idx.m, id) + idx.mapLock.Unlock() +} + +func (idx *itemMap[T]) mapGet(id uint64) (*T, bool) { + idx.mapLock.Lock() + item, ok := idx.m[id] + idx.mapLock.Unlock() + return item, ok +} + +// ---------------------------------------------------------------------------- + +func (idx *itemMap[T]) nextID() uint64 { + n := 1 + rand.Int63n(256) + return atomic.AddUint64(&idx.maxID, uint64(n)) +} diff --git a/itemmap_ex_test.go b/itemmap_ex_test.go new file mode 100644 index 0000000..dbea7c2 --- /dev/null +++ b/itemmap_ex_test.go @@ -0,0 +1,61 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "fmt" + "reflect" +) + +func (m *itemMap[T]) Equals(rhs *itemMap[T]) error { + return m.EqualsMap(rhs.m) +} + +func (m *itemMap[T]) EqualsMap(data map[uint64]*T) error { + if len(data) != len(m.m) { + return fmt.Errorf("Expected %d items, but found %d.", len(data), len(m.m)) + } + + for key, exp := range data { + val, ok := m.m[key] + if !ok { + return fmt.Errorf("No value for %d. 
Expected: %v", key, *exp) + } + if !reflect.DeepEqual(*val, *exp) { + return fmt.Errorf("Value mismatch %d: %v != %v", key, *val, *exp) + } + } + return nil +} + +func (m *itemMap[T]) EqualsKV() (err error) { + count := 0 + m.kv.Iterate(m.collection, func(id uint64, data []byte) { + count++ + if err != nil { + return + } + item := decode[T](data) + val, ok := m.m[id] + if !ok { + err = fmt.Errorf("Item %d not found in memory: %v", id, *item) + return + } + + if !reflect.DeepEqual(*item, *val) { + err = fmt.Errorf("Items not equal %d: %v != %v", id, *item, *val) + return + } + }) + + if err == nil && count != len(m.m) { + err = fmt.Errorf("%d items on disk, but %d in memory", count, len(m.m)) + } + + return err +} diff --git a/itemmap_test.go b/itemmap_test.go new file mode 100644 index 0000000..d3d70be --- /dev/null +++ b/itemmap_test.go @@ -0,0 +1,134 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "fmt" + "reflect" + "testing" +) + +func TestItemMap(t *testing.T) { + + // expected is a map of users. + run := func(name string, inner func(t *testing.T, db *DB) (expected map[uint64]*User)) { + testWithDB(t, name, func(t *testing.T, db *DB) { + expected := inner(t, db) + + if err := db.Users.c.items.EqualsMap(expected); err != nil { + t.Fatal(err) + } + + if err := db.Users.c.items.EqualsKV(); err != nil { + t.Fatal(err) + } + + db.Close() + db = OpenDB(db.root, true) + + if err := db.Users.c.items.EqualsMap(expected); err != nil { + t.Fatal(err) + } + + if err := db.Users.c.items.EqualsKV(); err != nil { + t.Fatal(err) + } + }) + } + + run("simple", func(t *testing.T, db *DB) (expected map[uint64]*User) { + users := map[uint64]*User{} + c := db.Users.c + for i := uint64(1); i < 10; i++ { + id := c.NextID() + users[id] = &User{ + ID: id, + Email: fmt.Sprintf("a.%d@c.com", i), + Name: fmt.Sprintf("name.%d", i), + ExtID: fmt.Sprintf("EXTID.%d", i), + } + user, err := c.Insert(*users[id]) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(user, *users[id]) { + t.Fatal(user, *users[id]) + } + } + + return users + }) + + run("insert and delete", func(t *testing.T, db *DB) (expected map[uint64]*User) { + users := map[uint64]*User{} + c := db.Users.c + for x := uint64(1); x < 10; x++ { + id := c.NextID() + users[id] = &User{ + ID: id, + Email: fmt.Sprintf("a.%d@c.com", x), + Name: fmt.Sprintf("name.%d", x), + ExtID: fmt.Sprintf("EXTID.%d", x), + } + user, err := c.Insert(*users[id]) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(user, *users[id]) { + t.Fatal(user, *users[id]) + } + } + + var id uint64 + for key := range users { + id = key + } + + delete(users, id) + c.Delete(id) + + return users + }) + + run("update", func(t *testing.T, db *DB) (expected map[uint64]*User) { + users := map[uint64]*User{} + c := db.Users.c + for x := uint64(1); x < 10; x++ { + id := c.NextID() + users[id] = &User{ + ID: id, + Email: fmt.Sprintf("a.%d@c.com", x), + Name: fmt.Sprintf("name.%d", x), + ExtID: fmt.Sprintf("EXTID.%d", x), + } + user, err := c.Insert(*users[id]) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(user, *users[id]) { + t.Fatal(user, *users[id]) + } + } + + var id uint64 + for key := range users { + id = key + } + + err := c.Update(id, func(u User) (User, error) { + u.Name = "Hello" + return u, nil + }) + if err != nil { + t.Fatal(err) + } + users[id].Name = "Hello" + + return users + 
}) +} diff --git a/keyedmutex/mutex.go b/keyedmutex/mutex.go new file mode 100644 index 0000000..0ba9b69 --- /dev/null +++ b/keyedmutex/mutex.go @@ -0,0 +1,75 @@ +package keyedmutex + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "container/list" + "sync" +) + +type KeyedMutex[K comparable] struct { + mu *sync.Mutex + waitList map[K]*list.List +} + +func New[K comparable]() KeyedMutex[K] { + return KeyedMutex[K]{ + mu: new(sync.Mutex), + waitList: map[K]*list.List{}, + } +} + +func (m KeyedMutex[K]) Lock(key K) { + if ch := m.lock(key); ch != nil { + <-ch + } +} + +func (m KeyedMutex[K]) lock(key K) chan struct{} { + m.mu.Lock() + defer m.mu.Unlock() + + if waitList, ok := m.waitList[key]; ok { + ch := make(chan struct{}) + waitList.PushBack(ch) + return ch + } + + m.waitList[key] = list.New() + return nil +} + +func (m KeyedMutex[K]) TryLock(key K) bool { + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.waitList[key]; ok { + return false + } + + m.waitList[key] = list.New() + return true +} + +func (m KeyedMutex[K]) Unlock(key K) { + m.mu.Lock() + defer m.mu.Unlock() + + waitList, ok := m.waitList[key] + if !ok { + panic("unlock of unlocked mutex") + } + + if waitList.Len() == 0 { + delete(m.waitList, key) + } else { + ch := waitList.Remove(waitList.Front()).(chan struct{}) + ch <- struct{}{} + } +} diff --git a/kvstore/db-sql.go b/kvstore/db-sql.go new file mode 100644 index 0000000..8c64496 --- /dev/null +++ b/kvstore/db-sql.go @@ -0,0 +1,116 @@ +package kvstore + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +const sqlSchema = ` +BEGIN IMMEDIATE; + +CREATE TABLE IF NOT EXISTS data( + SeqNum INTEGER NOT NULL PRIMARY KEY, + Deleted INTEGER NOT NULL DEFAULT 0, + Data BLOB NOT NULL +) WITHOUT ROWID; + +CREATE INDEX IF NOT EXISTS data_deleted_index ON data(Deleted,SeqNum); + +CREATE TABLE IF NOT EXISTS log( + SeqNum INTEGER NOT NULL PRIMARY KEY, + CreatedAt INTEGER NOT NULL, + Collection TEXT NOT NULL, + ID INTEGER NOT NULL, + Store INTEGER NOT NULL +) WITHOUT ROWID; + +CREATE INDEX IF NOT EXISTS log_created_at_index ON log(CreatedAt); + +CREATE TABLE IF NOT EXISTS kv( + Collection TEXT NOT NULL, + ID INTEGER NOT NULL, + SeqNum INTEGER NOT NULL, + PRIMARY KEY (Collection, ID) +) WITHOUT ROWID; + +CREATE VIEW IF NOT EXISTS kvdata AS +SELECT + kv.Collection, + kv.ID, + data.Data +FROM kv +JOIN data ON kv.SeqNum=data.SeqNum; + +CREATE VIEW IF NOT EXISTS logdata AS +SELECT + log.SeqNum, + log.Collection, + log.ID, + log.Store, + data.data +FROM log +LEFT JOIN data on log.SeqNum=data.SeqNum; + +CREATE TRIGGER IF NOT EXISTS deletedata AFTER UPDATE OF SeqNum ON kv +FOR EACH ROW +WHEN OLD.SeqNum != NEW.SeqNum +BEGIN + UPDATE data SET Deleted=1 WHERE SeqNum=OLD.SeqNum; +END; + +COMMIT;` + +// ---------------------------------------------------------------------------- + +const sqlInsertData = `INSERT INTO data(SeqNum,Data) VALUES(?,?)` + +const sqlInsertKV = `INSERT INTO kv(Collection,ID,SeqNum) VALUES (?,?,?) +ON CONFLICT(Collection,ID) DO UPDATE SET SeqNum=excluded.SeqNum +WHERE ID=excluded.ID` + +// ---------------------------------------------------------------------------- + +const sqlDeleteKV = `DELETE FROM kv WHERE Collection=? 
AND ID=?` + +const sqlDeleteData = `UPDATE data SET Deleted=1 +WHERE SeqNum=( + SELECT SeqNum FROM kv WHERE Collection=? AND ID=?)` + +// ---------------------------------------------------------------------------- + +const sqlInsertLog = `INSERT INTO log(SeqNum,CreatedAt,Collection,ID,Store) + VALUES(?,?,?,?,?)` + +// ---------------------------------------------------------------------------- + +const sqlKVIterate = `SELECT ID,Data FROM kvdata WHERE Collection=?` + +const sqlLogIterate = ` +SELECT SeqNum,Collection,ID,Store,Data +FROM logdata +WHERE SeqNum > ? +ORDER BY SeqNum ASC` + +const sqlMaxSeqNumGet = `SELECT COALESCE(MAX(SeqNum),0) FROM log` + +const sqlLogInfoGet = `SELECT + COALESCE(SeqNum, 0), + COALESCE(CreatedAt, 0) +FROM log +ORDER BY SeqNum DESC LIMIT 1` + +const sqlCleanQuery = ` +DELETE FROM + log +WHERE + CreatedAt < ? AND + SeqNum < (SELECT MAX(SeqNum) FROM log); + +DELETE FROM + data +WHERE + Deleted != 0 AND + SeqNum < (SELECT MIN(SeqNum) FROM log);` diff --git a/kvstore/dep-graph.sh b/kvstore/dep-graph.sh new file mode 100755 index 0000000..6d088f3 --- /dev/null +++ b/kvstore/dep-graph.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +godepgraph -s . > .deps.dot && xdot .deps.dot + +rm .deps.dot diff --git a/kvstore/globals.go b/kvstore/globals.go new file mode 100644 index 0000000..0373554 --- /dev/null +++ b/kvstore/globals.go @@ -0,0 +1,40 @@ +package kvstore + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "sync" + "time" +) + +var ( + connTimeout = 16 * time.Second + pollInterval = 500 * time.Millisecond + modQSize = 1024 + poolBufSize = 8192 + bufferPool = sync.Pool{ + New: func() any { + return make([]byte, poolBufSize) + }, + } +) + +func GetDataBuf(size int) []byte { + if size > poolBufSize { + return make([]byte, size) + } + return bufferPool.Get().([]byte)[:size] +} + +func RecycleDataBuf(b []byte) { + if cap(b) != poolBufSize { + return + } + bufferPool.Put(b) +} diff --git a/kvstore/main_test.go b/kvstore/main_test.go new file mode 100644 index 0000000..0948c96 --- /dev/null +++ b/kvstore/main_test.go @@ -0,0 +1,21 @@ +package kvstore + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "math/rand" + "os" + "testing" + "time" +) + +func TestMain(m *testing.M) { + rand.Seed(time.Now().UnixNano()) + os.Exit(m.Run()) +} diff --git a/kvstore/shipping.go b/kvstore/shipping.go new file mode 100644 index 0000000..a295ae4 --- /dev/null +++ b/kvstore/shipping.go @@ -0,0 +1,53 @@ +package kvstore + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. 
+*/ + +import "encoding/binary" + +const recHeaderSize = 22 + +func encodeRecordHeader(rec record, buf []byte) { + // SeqNum (8) + // ID (8) + // DataLen (4) + // Store (1) + // CollectionLen (1) + + binary.LittleEndian.PutUint64(buf[:8], rec.SeqNum) + buf = buf[8:] + binary.LittleEndian.PutUint64(buf[:8], rec.ID) + buf = buf[8:] + + if rec.Store { + binary.LittleEndian.PutUint32(buf[:4], uint32(len(rec.Data))) + buf[4] = 1 + } else { + binary.LittleEndian.PutUint32(buf[:4], 0) + buf[4] = 0 + } + buf = buf[5:] + + buf[0] = byte(len(rec.Collection)) +} + +func decodeRecHeader(header []byte) (rec record, colLen, dataLen int) { + buf := header + + rec.SeqNum = binary.LittleEndian.Uint64(buf[:8]) + buf = buf[8:] + rec.ID = binary.LittleEndian.Uint64(buf[:8]) + buf = buf[8:] + dataLen = int(binary.LittleEndian.Uint32(buf[:4])) + buf = buf[4:] + rec.Store = buf[0] != 0 + buf = buf[1:] + colLen = int(buf[0]) + + return +} diff --git a/kvstore/shipping_test.go b/kvstore/shipping_test.go new file mode 100644 index 0000000..0e5aa87 --- /dev/null +++ b/kvstore/shipping_test.go @@ -0,0 +1,254 @@ +package kvstore + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "math/rand" + "os" + "sync" + "testing" + "time" + + "git.crumpington.com/public/mdb/testconn" +) + +// ---------------------------------------------------------------------------- + +// Stores info from secondary callbacks. +type callbacks struct { + lock sync.Mutex + m map[string]map[uint64]string +} + +func (sc *callbacks) onStore(c string, id uint64, data []byte) { + sc.lock.Lock() + defer sc.lock.Unlock() + if _, ok := sc.m[c]; !ok { + sc.m[c] = map[uint64]string{} + } + sc.m[c][id] = string(data) +} + +func (sc *callbacks) onDelete(c string, id uint64) { + sc.lock.Lock() + defer sc.lock.Unlock() + if _, ok := sc.m[c]; !ok { + return + } + delete(sc.m[c], id) +} + +// ---------------------------------------------------------------------------- + +func TestShipping(t *testing.T) { + run := func(name string, inner func( + t *testing.T, + pDir string, + sDir string, + primary *KV, + secondary *KV, + cbs *callbacks, + nw *testconn.Network, + )) { + t.Run(name, func(t *testing.T) { + pDir, _ := os.MkdirTemp("", "") + defer os.RemoveAll(pDir) + sDir, _ := os.MkdirTemp("", "") + defer os.RemoveAll(sDir) + + nw := testconn.NewNetwork() + defer func() { + nw.CloseServer() + nw.CloseClient() + }() + + cbs := &callbacks{ + m: map[string]map[uint64]string{}, + } + + prim := NewPrimary(pDir) + defer prim.Close() + sec := NewSecondary(sDir, cbs.onStore, cbs.onDelete) + defer sec.Close() + + inner(t, pDir, sDir, prim, sec, cbs, nw) + }) + } + + run("simple", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) { + M := 10 + N := 1000 + + wg := sync.WaitGroup{} + + // Store M*N values in the background. + for i := 0; i < M; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < N; i++ { + time.Sleep(time.Millisecond) + prim.randAction() + } + }() + } + + // Send in the background. + wg.Add(1) + go func() { + defer wg.Done() + conn := nw.Accept() + prim.SyncSend(conn) + }() + + // Recv in the background. 
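+		// The secondary dials through the in-memory test network and applies the
+		// primary's log via SyncRecv; the test then waits until the secondary's
+		// MaxSeqNum reaches the M*N operations written above.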
+ wg.Add(1) + go func() { + defer wg.Done() + conn := nw.Dial() + sec.SyncRecv(conn) + }() + + sec.waitForSeqNum(uint64(M * N)) + + nw.CloseServer() + nw.CloseClient() + wg.Wait() + + if err := prim.equalsKV("a", sec); err != nil { + t.Fatal(err) + } + if err := prim.equalsKV("b", sec); err != nil { + t.Fatal(err) + } + if err := prim.equalsKV("c", sec); err != nil { + t.Fatal(err) + } + }) + + run("simple concurrent", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) { + M := 64 + N := 128 + + wg := sync.WaitGroup{} + + // Store M*N values in the background. + for i := 0; i < M; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < N; i++ { + time.Sleep(time.Millisecond) + prim.randAction() + } + }() + } + + // Send in the background. + wg.Add(1) + go func() { + defer wg.Done() + conn := nw.Accept() + prim.SyncSend(conn) + }() + + // Recv in the background. + wg.Add(1) + go func() { + defer wg.Done() + conn := nw.Dial() + sec.SyncRecv(conn) + }() + + sec.waitForSeqNum(uint64(M * N)) + + nw.CloseServer() + nw.CloseClient() + wg.Wait() + + if err := prim.equalsKV("a", sec); err != nil { + t.Fatal(err) + } + if err := prim.equalsKV("b", sec); err != nil { + t.Fatal(err) + } + if err := prim.equalsKV("c", sec); err != nil { + t.Fatal(err) + } + }) + + run("net failures", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) { + M := 10 + N := 1000 + sleepTime := time.Millisecond + + wg := sync.WaitGroup{} + + // Store M*N values in the background. + for i := 0; i < M; i++ { + wg.Add(1) + go func() { + defer wg.Done() + for j := 0; j < N; j++ { + time.Sleep(sleepTime) + prim.randAction() + } + }() + } + + // Send in the background. + wg.Add(1) + go func() { + defer wg.Done() + for sec.MaxSeqNum() < uint64(M*N) { + if conn := nw.Accept(); conn != nil { + prim.SyncSend(conn) + } + } + }() + + // Recv in the background. + wg.Add(1) + go func() { + defer wg.Done() + for sec.MaxSeqNum() < uint64(M*N) { + if conn := nw.Dial(); conn != nil { + sec.SyncRecv(conn) + } + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + for sec.MaxSeqNum() < uint64(M*N) { + time.Sleep(time.Duration(rand.Intn(10 * int(sleepTime)))) + if rand.Float64() < 0.5 { + nw.CloseClient() + } else { + nw.CloseServer() + } + } + }() + + sec.waitForSeqNum(prim.MaxSeqNum()) + wg.Wait() + + if err := prim.equalsKV("a", sec); err != nil { + t.Fatal(err) + } + if err := prim.equalsKV("b", sec); err != nil { + t.Fatal(err) + } + if err := prim.equalsKV("c", sec); err != nil { + t.Fatal(err) + } + }) + +} diff --git a/kvstore/store.go b/kvstore/store.go new file mode 100644 index 0000000..02a2741 --- /dev/null +++ b/kvstore/store.go @@ -0,0 +1,149 @@ +package kvstore + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. 
+*/ + +import ( + "database/sql" + "sync" + "time" + + _ "github.com/mattn/go-sqlite3" +) + +type LogInfo struct { + MaxSeqNum uint64 + MaxCreatedAt uint64 +} + +type KV struct { + primary bool + dbPath string + + db *sql.DB + + maxSeqNumStmt *sql.Stmt + logInfoStmt *sql.Stmt + logIterateStmt *sql.Stmt + + w *writer + + onStore func(string, uint64, []byte) + onDelete func(string, uint64) + + closeLock sync.Mutex + recvLock sync.Mutex +} + +func newKV( + dir string, + primary bool, + onStore func(string, uint64, []byte), + onDelete func(string, uint64), +) *KV { + kv := &KV{ + dbPath: dbPath(dir), + primary: primary, + onStore: onStore, + onDelete: onDelete, + } + + opts := `?_journal=WAL` + db, err := sql.Open("sqlite3", kv.dbPath+opts) + must(err) + + _, err = db.Exec(sqlSchema) + must(err) + + kv.maxSeqNumStmt, err = db.Prepare(sqlMaxSeqNumGet) + must(err) + kv.logInfoStmt, err = db.Prepare(sqlLogInfoGet) + must(err) + kv.logIterateStmt, err = db.Prepare(sqlLogIterate) + must(err) + + _, err = db.Exec(sqlSchema) + must(err) + + kv.db = db + + if kv.primary { + kv.w = newWriter(kv.db) + kv.w.Start(kv.MaxSeqNum()) + } + return kv +} + +func NewPrimary(dir string) *KV { + return newKV(dir, true, nil, nil) +} + +func NewSecondary( + dir string, + onStore func(collection string, id uint64, data []byte), + onDelete func(collection string, id uint64), +) *KV { + return newKV(dir, false, onStore, onDelete) +} + +func (kv *KV) Primary() bool { + return kv.primary +} + +func (kv *KV) MaxSeqNum() (seqNum uint64) { + must(kv.maxSeqNumStmt.QueryRow().Scan(&seqNum)) + return seqNum +} + +func (kv *KV) LogInfo() (info LogInfo) { + must(kv.logInfoStmt.QueryRow().Scan(&info.MaxSeqNum, &info.MaxCreatedAt)) + return +} + +func (kv *KV) Iterate(collection string, each func(id uint64, data []byte)) { + rows, err := kv.db.Query(sqlKVIterate, collection) + must(err) + defer rows.Close() + + var ( + id uint64 + data []byte + ) + + for rows.Next() { + must(rows.Scan(&id, &data)) + each(id, data) + } +} + +func (kv *KV) Close() { + kv.closeLock.Lock() + defer kv.closeLock.Unlock() + + if kv.w != nil { + kv.w.Stop() + } + + if kv.db != nil { + kv.db.Close() + kv.db = nil + } +} + +func (kv *KV) Store(collection string, id uint64, data []byte) { + kv.w.Store(collection, id, data) +} + +func (kv *KV) Delete(collection string, id uint64) { + kv.w.Delete(collection, id) +} + +func (kv *KV) CleanBefore(seconds int64) { + _, err := kv.db.Exec(sqlCleanQuery, time.Now().Unix()-seconds) + must(err) +} diff --git a/kvstore/store_test.go b/kvstore/store_test.go new file mode 100644 index 0000000..e2abe0b --- /dev/null +++ b/kvstore/store_test.go @@ -0,0 +1,125 @@ +package kvstore + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. 
+*/ + +import ( + "fmt" + "math/rand" + "os" + "reflect" + "testing" + "time" +) + +// ---------------------------------------------------------------------------- + +func (kv *KV) waitForSeqNum(x uint64) { + for { + seqNum := kv.MaxSeqNum() + if seqNum >= x { + return + } + time.Sleep(100 * time.Millisecond) + } +} + +func (kv *KV) dump(collection string) map[uint64]string { + m := map[uint64]string{} + kv.Iterate(collection, func(id uint64, data []byte) { + m[id] = string(data) + }) + return m +} + +func (kv *KV) equals(collection string, expected map[uint64]string) error { + m := kv.dump(collection) + if len(m) != len(expected) { + return fmt.Errorf("Expected %d values but found %d", len(expected), len(m)) + } + + for key, exp := range expected { + val, ok := m[key] + if !ok { + return fmt.Errorf("Value for %d not found.", key) + } + if val != exp { + return fmt.Errorf("%d: Expected %s but found %s.", key, exp, val) + } + } + return nil +} + +func (kv *KV) equalsKV(collection string, rhs *KV) error { + l1 := []record{} + kv.replay(0, func(rec record) error { + l1 = append(l1, rec) + return nil + }) + + idx := -1 + err := rhs.replay(0, func(rec record) error { + idx++ + if !reflect.DeepEqual(rec, l1[idx]) { + return fmt.Errorf("Records not equal: %d %v %v", idx, rec, l1[idx]) + } + return nil + }) + if err != nil { + return err + } + + return kv.equals(collection, rhs.dump(collection)) +} + +// Collection one of ("a", "b", "c"). +// ID one of [1,10] +var ( + randCollections = []string{"a", "b", "c"} + randIDs = []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16} +) + +func (kv *KV) randAction() { + c := randCollections[rand.Intn(len(randCollections))] + id := randIDs[rand.Intn(len(randIDs))] + + // Mostly stores. + if rand.Float64() < 0.9 { + kv.Store(c, id, randBytes()) + } else { + kv.Delete(c, id) + } +} + +// ---------------------------------------------------------------------------- + +func TestKV(t *testing.T) { + run := func(name string, inner func(t *testing.T, kv *KV)) { + dir, _ := os.MkdirTemp("", "") + defer os.RemoveAll(dir) + kv := NewPrimary(dir) + defer kv.Close() + + inner(t, kv) + } + + run("simple", func(t *testing.T, kv *KV) { + kv.Store("a", 1, _b("Hello")) + kv.Store("a", 2, _b("World")) + + kv.waitForSeqNum(2) + + err := kv.equals("a", map[uint64]string{ + 1: "Hello", + 2: "World", + }) + if err != nil { + t.Fatal(err) + } + }) +} diff --git a/kvstore/sync-recv.go b/kvstore/sync-recv.go new file mode 100644 index 0000000..fea3ea0 --- /dev/null +++ b/kvstore/sync-recv.go @@ -0,0 +1,103 @@ +package kvstore + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "encoding/binary" + "log" + "net" + "time" +) + +func (kv *KV) SyncRecv(conn net.Conn) { + defer conn.Close() + + if kv.primary { + panic("SyncRecv called on primary.") + } + + if !kv.recvLock.TryLock() { + return + } + defer kv.recvLock.Unlock() + + // It's important that we stop the writer so that all queued writes are + // committed to the database before syncing has a chance to restart. + w := newWriter(kv.db) + w.Start(kv.MaxSeqNum()) + defer w.Stop() + + headerBuf := make([]byte, recHeaderSize) + nameBuf := make([]byte, 32) + afterSeqNumBuf := make([]byte, 8) + + afterSeqNum := kv.MaxSeqNum() + expectedSeqNum := afterSeqNum + 1 + + // Send fromID to the conn. 
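+	// (The value sent is afterSeqNum: the primary will reply with every log
+	// record whose SeqNum is greater than it, so replication resumes exactly
+	// where this secondary left off.)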
+ conn.SetWriteDeadline(time.Now().Add(connTimeout)) + binary.LittleEndian.PutUint64(afterSeqNumBuf, afterSeqNum) + if _, err := conn.Write(afterSeqNumBuf); err != nil { + log.Printf("RecvWAL failed to send after sequence number: %v", err) + return + } + conn.SetWriteDeadline(time.Time{}) + + for { + conn.SetReadDeadline(time.Now().Add(connTimeout)) + + if _, err := conn.Read(headerBuf); err != nil { + log.Printf("RecvWAL failed to read header: %v", err) + return + } + rec, nameLen, dataLen := decodeRecHeader(headerBuf) + + // Heartbeat. + if rec.SeqNum == 0 { + continue + } + + if rec.SeqNum != expectedSeqNum { + log.Printf("Expected sequence number %d but got %d.", + expectedSeqNum, rec.SeqNum) + return + } + expectedSeqNum++ + + if cap(nameBuf) < nameLen { + nameBuf = make([]byte, nameLen) + } + nameBuf = nameBuf[:nameLen] + + if _, err := conn.Read(nameBuf); err != nil { + log.Printf("RecvWAL failed to read collection name: %v", err) + return + } + + rec.Collection = string(nameBuf) + + // Note that it's OK to apply changes via onStore / onDelete before the log + // entries are written because all changes are idempotent. Even if the log + // is replayed from some point in the past, the final state will be + // consistent. + + if rec.Store { + rec.Data = GetDataBuf(dataLen) + if _, err := conn.Read(rec.Data); err != nil { + log.Printf("RecvWAL failed to read data: %v", err) + return + } + kv.onStore(rec.Collection, rec.ID, rec.Data) + w.StoreAsync(rec.Collection, rec.ID, rec.Data) + } else { + kv.onDelete(rec.Collection, rec.ID) + w.DeleteAsync(rec.Collection, rec.ID) + } + } +} diff --git a/kvstore/sync-send.go b/kvstore/sync-send.go new file mode 100644 index 0000000..8b6c6e7 --- /dev/null +++ b/kvstore/sync-send.go @@ -0,0 +1,116 @@ +package kvstore + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "encoding/binary" + "log" + "net" + "time" +) + +func (kv *KV) SyncSend(conn net.Conn) { + defer conn.Close() + + var ( + seqNumBuf = make([]byte, 8) + headerBuf = make([]byte, recHeaderSize) + empty = make([]byte, recHeaderSize) + err error + ) + + // Read afterSeqNum from the conn. 
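+ // The secondary reports the last sequence number it has applied; the
+ // loop below polls for newer log records, replays them over the
+ // connection, and writes an empty header as a heartbeat while idle.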
+ conn.SetReadDeadline(time.Now().Add(connTimeout)) + if _, err := conn.Read(seqNumBuf[:8]); err != nil { + log.Printf("SyncSend failed to read afterSeqNum: %v", err) + return + } + + afterSeqNum := binary.LittleEndian.Uint64(seqNumBuf[:8]) + +POLL: + + for i := 0; i < 4; i++ { + if kv.MaxSeqNum() > afterSeqNum { + goto REPLAY + } + time.Sleep(pollInterval) + } + + goto HEARTBEAT + +HEARTBEAT: + + conn.SetWriteDeadline(time.Now().Add(connTimeout)) + if _, err := conn.Write(empty); err != nil { + log.Printf("SendWAL failed to send heartbeat: %v", err) + return + } + + goto POLL + +REPLAY: + + err = kv.replay(afterSeqNum, func(rec record) error { + conn.SetWriteDeadline(time.Now().Add(connTimeout)) + + afterSeqNum = rec.SeqNum + encodeRecordHeader(rec, headerBuf) + if _, err := conn.Write(headerBuf); err != nil { + log.Printf("SendWAL failed to send header %v", err) + return err + } + + if _, err := conn.Write([]byte(rec.Collection)); err != nil { + log.Printf("SendWAL failed to send collection name %v", err) + return err + } + + if !rec.Store { + return nil + } + + if _, err := conn.Write(rec.Data); err != nil { + log.Printf("SendWAL failed to send data %v", err) + return err + } + + return nil + }) + + if err != nil { + return + } + + goto POLL +} + +func (kv *KV) replay(afterSeqNum uint64, each func(rec record) error) error { + rec := record{} + rows, err := kv.logIterateStmt.Query(afterSeqNum) + must(err) + defer rows.Close() + + rec.Data = GetDataBuf(0) + defer RecycleDataBuf(rec.Data) + + for rows.Next() { + must(rows.Scan( + &rec.SeqNum, + &rec.Collection, + &rec.ID, + &rec.Store, + &rec.Data)) + + if err = each(rec); err != nil { + return err + } + } + return nil +} diff --git a/kvstore/types.go b/kvstore/types.go new file mode 100644 index 0000000..0fb0da5 --- /dev/null +++ b/kvstore/types.go @@ -0,0 +1,27 @@ +package kvstore + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import "sync" + +type modJob struct { + Collection string + ID uint64 + Store bool + Data []byte + Ready *sync.WaitGroup +} + +type record struct { + SeqNum uint64 + Collection string + ID uint64 + Store bool + Data []byte +} diff --git a/kvstore/util.go b/kvstore/util.go new file mode 100644 index 0000000..f3cd773 --- /dev/null +++ b/kvstore/util.go @@ -0,0 +1,23 @@ +package kvstore + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "path/filepath" +) + +func must(err error) { + if err != nil { + panic(err) + } +} + +func dbPath(dir string) string { + return filepath.Join(dir, "db") +} diff --git a/kvstore/util_test.go b/kvstore/util_test.go new file mode 100644 index 0000000..104c6f3 --- /dev/null +++ b/kvstore/util_test.go @@ -0,0 +1,29 @@ +package kvstore + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. 
+*/ + +import ( + "crypto/rand" + "encoding/hex" + mrand "math/rand" +) + +func _b(in string) []byte { + return []byte(in) +} + +func randString() string { + buf := make([]byte, 1+mrand.Intn(20)) + rand.Read(buf) + return hex.EncodeToString(buf) +} + +func randBytes() []byte { + return _b(randString()) +} diff --git a/kvstore/writer.go b/kvstore/writer.go new file mode 100644 index 0000000..00b21af --- /dev/null +++ b/kvstore/writer.go @@ -0,0 +1,194 @@ +package kvstore + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "database/sql" + "sync" + "time" +) + +type writer struct { + db *sql.DB + modQ chan modJob + wg sync.WaitGroup +} + +func newWriter(db *sql.DB) *writer { + return &writer{ + db: db, + modQ: make(chan modJob, modQSize), + } +} + +func (w *writer) Start(maxSeqNum uint64) { + w.wg.Add(1) + go w.run(maxSeqNum) +} + +func (w *writer) Stop() { + close(w.modQ) + w.wg.Wait() +} + +// Takes ownership of the incoming data. +func (w *writer) Store(collection string, id uint64, data []byte) { + job := modJob{ + Collection: collection, + ID: id, + Store: true, + Data: data, + Ready: &sync.WaitGroup{}, + } + job.Ready.Add(1) + w.modQ <- job + job.Ready.Wait() +} + +func (w *writer) Delete(collection string, id uint64) { + job := modJob{ + Collection: collection, + ID: id, + Store: false, + Ready: &sync.WaitGroup{}, + } + job.Ready.Add(1) + w.modQ <- job + job.Ready.Wait() +} + +// Takes ownership of the incoming data. +func (w *writer) StoreAsync(collection string, id uint64, data []byte) { + w.modQ <- modJob{ + Collection: collection, + ID: id, + Store: true, + Data: data, + } +} + +func (w *writer) DeleteAsync(collection string, id uint64) { + w.modQ <- modJob{ + Collection: collection, + ID: id, + Store: false, + } +} + +func (w *writer) run(maxSeqNum uint64) { + defer w.wg.Done() + + var ( + mod modJob + ok bool + tx *sql.Tx + insertData *sql.Stmt + insertKV *sql.Stmt + deleteData *sql.Stmt + deleteKV *sql.Stmt + insertLog *sql.Stmt + err error + curSeqNum = maxSeqNum + now int64 + wgs = make([]*sync.WaitGroup, 10) + ) + +BEGIN: + + insertData = nil + deleteData = nil + + wgs = wgs[:0] + + mod, ok = <-w.modQ + if !ok { + return + } + + tx, err = w.db.Begin() + must(err) + + now = time.Now().Unix() + + insertLog, err = tx.Prepare(sqlInsertLog) + must(err) + +LOOP: + + if mod.Ready != nil { + wgs = append(wgs, mod.Ready) + } + + curSeqNum++ + if mod.Store { + goto STORE + } else { + goto DELETE + } + +STORE: + + if insertData == nil { + insertData, err = tx.Prepare(sqlInsertData) + must(err) + insertKV, err = tx.Prepare(sqlInsertKV) + must(err) + } + + _, err = insertData.Exec(curSeqNum, mod.Data) + must(err) + _, err = insertKV.Exec(mod.Collection, mod.ID, curSeqNum) + must(err) + _, err = insertLog.Exec(curSeqNum, now, mod.Collection, mod.ID, true) + must(err) + + RecycleDataBuf(mod.Data) + + goto NEXT + +DELETE: + + if deleteData == nil { + deleteData, err = tx.Prepare(sqlDeleteData) + must(err) + deleteKV, err = tx.Prepare(sqlDeleteKV) + must(err) + } + + _, err = deleteData.Exec(mod.Collection, mod.ID) + must(err) + _, err = deleteKV.Exec(mod.Collection, mod.ID) + must(err) + _, err = insertLog.Exec(curSeqNum, now, mod.Collection, mod.ID, false) + must(err) + + goto NEXT + +NEXT: + + select { + case mod, ok = <-w.modQ: + if ok { + goto LOOP + } + default: + } + + goto COMMIT + +COMMIT: + + must(tx.Commit()) + + for i := 
range wgs { + wgs[i].Done() + } + + goto BEGIN +} diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..b4c60a0 --- /dev/null +++ b/main_test.go @@ -0,0 +1,32 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "math/rand" + "os" + "testing" + "time" +) + +func TestMain(m *testing.M) { + rand.Seed(time.Now().UnixNano()) + os.Exit(m.Run()) +} + +func testWithDB(t *testing.T, name string, inner func(t *testing.T, db *DB)) { + t.Run(name, func(t *testing.T) { + root, err := os.MkdirTemp("", "") + must(err) + defer os.RemoveAll(root) + + db := OpenDB(root, true) + defer db.Close() + inner(t, db) + }) +} diff --git a/mapindex.go b/mapindex.go new file mode 100644 index 0000000..0dd978a --- /dev/null +++ b/mapindex.go @@ -0,0 +1,177 @@ +package mdb + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "sync" + + "git.crumpington.com/public/mdb/keyedmutex" +) + +type MapIndex[K comparable, T any] struct { + c *Collection[T] + _name string + keyLock keyedmutex.KeyedMutex[K] + mapLock sync.Mutex + m map[K]*T + getID func(*T) uint64 + getKey func(*T) K + include func(*T) bool +} + +func NewMapIndex[K comparable, T any]( + c *Collection[T], + name string, + getKey func(*T) K, + include func(*T) bool, +) *MapIndex[K, T] { + m := &MapIndex[K, T]{ + c: c, + _name: name, + keyLock: keyedmutex.New[K](), + m: map[K]*T{}, + getID: c.getID, + getKey: getKey, + include: include, + } + + c.indices = append(c.indices, m) + c.uniqueIndices = append(c.uniqueIndices, m) + return m +} + +// ---------------------------------------------------------------------------- + +func (m *MapIndex[K, T]) load(src map[uint64]*T) error { + for _, item := range src { + if m.include != nil && !m.include(item) { + continue + } + + k := m.getKey(item) + if _, ok := m.m[k]; ok { + return ErrDuplicate + } + + m.m[k] = item + } + + return nil +} + +func (m *MapIndex[K, T]) Get(k K) (t T, ok bool) { + if item, ok := m.mapGet(k); ok { + return *item, true + } + + return t, false +} + +func (m *MapIndex[K, T]) Update(k K, update func(T) (T, error)) error { + wrapped := func(item T) (T, error) { + new, err := update(item) + if err != nil { + return item, err + } + + return new, nil + } + + if item, ok := m.mapGet(k); ok { + return m.c.Update(m.getID(item), wrapped) + } + + return ErrNotFound +} + +func (m *MapIndex[K, T]) Delete(k K) { + if item, ok := m.mapGet(k); ok { + m.c.Delete(m.getID(item)) + } +} + +// ---------------------------------------------------------------------------- + +func (m *MapIndex[K, T]) insert(item *T) { + if m.include == nil || m.include(item) { + m.mapSet(m.getKey(item), item) + } +} + +func (m *MapIndex[K, T]) update(old, new *T) { + if m.include == nil || m.include(old) { + m.mapDelete(m.getKey(old)) + } + if m.include == nil || m.include(new) { + m.mapSet(m.getKey(new), new) + } +} + +func (m *MapIndex[K, T]) delete(item *T) { + if m.include == nil || m.include(item) { + m.mapDelete(m.getKey(item)) + } +} + +// ---------------------------------------------------------------------------- + +func (idx *MapIndex[K, T]) name() string { + return idx._name +} + +func (idx *MapIndex[K, T]) lock(item *T) { + if idx.include == nil || idx.include(item) { + 
idx.keyLock.Lock(idx.getKey(item)) + } +} + +func (idx *MapIndex[K, T]) unlock(item *T) { + if idx.include == nil || idx.include(item) { + idx.keyLock.Unlock(idx.getKey(item)) + } +} + +// Should hold item lock when calling. +func (idx *MapIndex[K, T]) insertConflict(new *T) bool { + if idx.include != nil && !idx.include(new) { + return false + } + _, ok := idx.Get(idx.getKey(new)) + return ok +} + +// Should hold item lock when calling. +func (idx *MapIndex[K, T]) updateConflict(new *T) bool { + if idx.include != nil && !idx.include(new) { + return false + } + cur, ok := idx.mapGet(idx.getKey(new)) + return ok && idx.getID(cur) != idx.getID(new) +} + +// ---------------------------------------------------------------------------- + +func (idx *MapIndex[K, T]) mapSet(k K, t *T) { + idx.mapLock.Lock() + idx.m[k] = t + idx.mapLock.Unlock() +} + +func (idx *MapIndex[K, T]) mapDelete(k K) { + idx.mapLock.Lock() + delete(idx.m, k) + idx.mapLock.Unlock() +} + +func (idx *MapIndex[K, T]) mapGet(k K) (*T, bool) { + idx.mapLock.Lock() + t, ok := idx.m[k] + idx.mapLock.Unlock() + return t, ok +} diff --git a/mapindex_ex_test.go b/mapindex_ex_test.go new file mode 100644 index 0000000..368cc67 --- /dev/null +++ b/mapindex_ex_test.go @@ -0,0 +1,35 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "fmt" + "reflect" +) + +func (m *MapIndex[K, T]) Equals(rhs *MapIndex[K, T]) error { + return m.EqualsMap(rhs.m) +} + +func (m *MapIndex[K, T]) EqualsMap(data map[K]*T) error { + if len(m.m) != len(data) { + return fmt.Errorf("Expected %d items, but found %d.", len(data), len(m.m)) + } + + for key, exp := range data { + val, ok := m.Get(key) + if !ok { + return fmt.Errorf("No value for %v. Expected: %v", key, *exp) + } + if !reflect.DeepEqual(val, *exp) { + return fmt.Errorf("Value mismatch %v: %v != %v", key, val, *exp) + } + } + + return nil +} diff --git a/mapindex_test.go b/mapindex_test.go new file mode 100644 index 0000000..22c7dae --- /dev/null +++ b/mapindex_test.go @@ -0,0 +1,692 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "errors" + "fmt" + "reflect" + "testing" +) + +func TestFullMapIndex(t *testing.T) { + + // Test against the emailMap index. 
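+ // Each case returns the expected email -> user mapping. The index is
+ // checked against it, then the database is closed and reopened from the
+ // same root so the on-disk load path is exercised as well.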
+ run := func(name string, inner func(t *testing.T, db *DB) map[string]*User) { + testWithDB(t, name, func(t *testing.T, db *DB) { + expected := inner(t, db) + + if err := db.Users.emailMap.EqualsMap(expected); err != nil { + t.Fatal(err) + } + + db.Close() + db = OpenDB(db.root, true) + + if err := db.Users.emailMap.EqualsMap(expected); err != nil { + t.Fatal(err) + } + }) + } + + run("insert", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{} + + for i := uint64(1); i < 10; i++ { + user := &User{ + ID: db.Users.c.NextID(), + Email: fmt.Sprintf("a.%d@c.com", i), + Name: fmt.Sprintf("name.%d", i), + ExtID: fmt.Sprintf("EXTID.%d", i), + } + + user2, err := db.Users.c.Insert(*user) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(*user, user2) { + t.Fatal(*user, user2) + } + users[user.Email] = user + } + + return users + }) + + run("delete", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{} + + for i := uint64(1); i < 10; i++ { + user := &User{ + ID: db.Users.c.NextID(), + Email: fmt.Sprintf("a.%d@c.com", i), + Name: fmt.Sprintf("name.%d", i), + ExtID: fmt.Sprintf("EXTID.%d", i), + } + + user2, err := db.Users.c.Insert(*user) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(*user, user2) { + t.Fatal(*user, user2) + } + users[user.Email] = user + } + + var id string + for key := range users { + id = key + break + } + + delete(users, id) + db.Users.emailMap.Delete(id) + + return users + }) + + run("update non-indexed field", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{} + + for i := uint64(1); i < 10; i++ { + user := &User{ + ID: db.Users.c.NextID(), + Email: fmt.Sprintf("a.%d@c.com", i), + Name: fmt.Sprintf("name.%d", i), + ExtID: fmt.Sprintf("EXTID.%d", i), + } + + user2, err := db.Users.c.Insert(*user) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(*user, user2) { + t.Fatal(*user, user2) + } + users[user.Email] = user + } + + var id string + for key := range users { + id = key + break + } + + err := db.Users.emailMap.Update(id, func(u User) (User, error) { + u.Name = "UPDATED" + return u, nil + }) + if err != nil { + t.Fatal(err) + } + + users[id].Name = "UPDATED" + + return users + }) + + run("update indexed field", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{} + + for i := uint64(1); i < 10; i++ { + user := &User{ + ID: db.Users.c.NextID(), + Email: fmt.Sprintf("a.%d@c.com", i), + Name: fmt.Sprintf("name.%d", i), + ExtID: fmt.Sprintf("EXTID.%d", i), + } + + user2, err := db.Users.c.Insert(*user) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(*user, user2) { + t.Fatal(*user, user2) + } + users[user.Email] = user + } + + var id uint64 + var email string + for key := range users { + email = key + id = users[key].ID + break + } + + err := db.Users.c.Update(id, func(u User) (User, error) { + u.Email = "test@x.com" + return u, nil + }) + if err != nil { + t.Fatal(err) + } + + user := users[email] + user.Email = "test@x.com" + delete(users, email) + users[user.Email] = user + + return users + }) + + run("update change id error", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{} + + for i := uint64(1); i < 10; i++ { + user := &User{ + ID: db.Users.c.NextID(), + Email: fmt.Sprintf("a.%d@c.com", i), + Name: fmt.Sprintf("name.%d", i), + ExtID: fmt.Sprintf("EXTID.%d", i), + } + if _, err := db.Users.c.Insert(*user); err != nil { + t.Fatal(err) + } + users[user.Email] = user + } + + var email string + for key := 
range users { + email = key + break + } + + err := db.Users.emailMap.Update(email, func(u User) (User, error) { + u.ID++ + return u, nil + }) + if err != ErrMismatchedIDs { + t.Fatal(err) + } + + return users + }) + + run("update function error", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{} + + for i := uint64(1); i < 10; i++ { + user := &User{ + ID: db.Users.c.NextID(), + Email: fmt.Sprintf("a.%d@c.com", i), + Name: fmt.Sprintf("name.%d", i), + ExtID: fmt.Sprintf("EXTID.%d", i), + } + if _, err := db.Users.c.Insert(*user); err != nil { + t.Fatal(err) + } + users[user.Email] = user + } + + var email string + for key := range users { + email = key + break + } + + myErr := errors.New("hello") + + err := db.Users.emailMap.Update(email, func(u User) (User, error) { + return u, myErr + }) + if err != myErr { + t.Fatal(err) + } + + return users + }) + + run("update ErrAbortUpdate", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{} + + for i := uint64(1); i < 10; i++ { + user := &User{ + ID: db.Users.c.NextID(), + Email: fmt.Sprintf("a.%d@c.com", i), + Name: fmt.Sprintf("name.%d", i), + ExtID: fmt.Sprintf("EXTID.%d", i), + } + if _, err := db.Users.c.Insert(*user); err != nil { + t.Fatal(err) + } + users[user.Email] = user + } + + var email string + for key := range users { + email = key + break + } + + err := db.Users.emailMap.Update(email, func(u User) (User, error) { + return u, ErrAbortUpdate + }) + if err != nil { + t.Fatal(err) + } + + return users + }) + + run("update ErrNotFound", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{} + + for i := uint64(1); i < 10; i++ { + user := &User{ + ID: db.Users.c.NextID(), + Email: fmt.Sprintf("a.%d@c.com", i), + Name: fmt.Sprintf("name.%d", i), + ExtID: fmt.Sprintf("EXTID.%d", i), + } + if _, err := db.Users.c.Insert(*user); err != nil { + t.Fatal(err) + } + users[user.Email] = user + } + + var email string + for key := range users { + email = key + break + } + + err := db.Users.emailMap.Update(email+"x", func(u User) (User, error) { + return u, nil + }) + if err != ErrNotFound { + t.Fatal(err) + } + + return users + }) + + run("insert conflict", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{} + user := &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "a", ExtID: ""} + + if _, err := db.Users.c.Insert(*user); err != nil { + t.Fatal(err) + } + users[user.Email] = user + + user2 := User{ + ID: db.Users.c.NextID(), + Email: "a@b.com", + Name: "someone else", + ExtID: "123", + } + + _, err := db.Users.c.Insert(user2) + if !errors.Is(err, ErrDuplicate) { + t.Fatal(err) + } + + return users + }) + + run("update conflict", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{} + + user1 := &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "a"} + user2 := &User{ID: db.Users.c.NextID(), Email: "x@y.com", Name: "x"} + + users[user1.Email] = user1 + users[user2.Email] = user2 + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + err := db.Users.c.Update(user2.ID, func(u User) (User, error) { + u.Email = "a@b.com" + return u, nil + }) + if !errors.Is(err, ErrDuplicate) { + t.Fatal(err) + } + + err = db.Users.emailMap.Update("x@y.com", func(u User) (User, error) { + u.Email = "a@b.com" + return u, nil + }) + if !errors.Is(err, ErrDuplicate) { + t.Fatal(err) + } + + return users + }) +} + +func TestPartialMapIndex(t *testing.T) { + // Test against the extID map index. 
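+ // extIDMap is a partial index: its include function skips users with an
+ // empty ExtID, so the expected maps below only ever contain users that
+ // have one.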
+ run := func(name string, inner func(t *testing.T, db *DB) map[string]*User) { + testWithDB(t, name, func(t *testing.T, db *DB) { + expected := inner(t, db) + + if err := db.Users.extIDMap.EqualsMap(expected); err != nil { + t.Fatal(err) + } + + db.Close() + db = OpenDB(db.root, true) + + if err := db.Users.extIDMap.EqualsMap(expected); err != nil { + t.Fatal(err) + } + }) + } + + run("insert", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{ + "x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "x", ExtID: "x"}, + } + user1 := &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "a"} + + if _, err := db.Users.c.Insert(*users["x"]); err != nil { + t.Fatal(err) + } + if _, err := db.Users.c.Insert(*user1); err != nil { + t.Fatal(err) + } + + return users + }) + + run("insert with conflict", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{ + "x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "x", ExtID: "x"}, + } + user1 := &User{ID: db.Users.c.NextID(), Email: "a@b.com", Name: "y", ExtID: "x"} + + if _, err := db.Users.c.Insert(*users["x"]); err != nil { + t.Fatal(err) + } + if _, err := db.Users.c.Insert(*user1); !errors.Is(err, ErrDuplicate) { + t.Fatal(err) + } + + return users + }) + + run("insert and delete in index", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{ + "x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"}, + "y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"}, + "z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc", ExtID: "z"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + // Delete from index and from collection. + db.Users.extIDMap.Delete("x") + db.Users.c.Delete(users["z"].ID) + + delete(users, "x") + delete(users, "z") + + return users + }) + + run("insert and delete outside index", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{ + "x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"}, + "y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"}, + "z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + // Delete from index and from collection. + db.Users.extIDMap.Delete("x") + db.Users.c.Delete(users["z"].ID) + + delete(users, "x") + delete(users, "z") + + return users + }) + + run("update outside index", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{ + "x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"}, + "y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"}, + "z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + err := db.Users.c.Update(users["z"].ID, func(u User) (User, error) { + u.Name = "Whatever" + return u, nil + }) + if err != nil { + t.Fatal(err) + } + + delete(users, "z") // No ExtID => not in index. 
+ return users + }) + + run("update into index", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{ + "x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"}, + "y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"}, + "z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + err := db.Users.c.Update(users["z"].ID, func(u User) (User, error) { + u.ExtID = "z" + return u, nil + }) + if err != nil { + t.Fatal(err) + } + + users["z"].ExtID = "z" + + return users + }) + + run("update out of index", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{ + "x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"}, + "y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"}, + "z": {ID: db.Users.c.NextID(), Email: "a@b.com", Name: "cc"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + err := db.Users.extIDMap.Update("y", func(u User) (User, error) { + u.ExtID = "" + return u, nil + }) + if err != nil { + t.Fatal(err) + } + + err = db.Users.c.Update(users["x"].ID, func(u User) (User, error) { + u.ExtID = "" + return u, nil + }) + if err != nil { + t.Fatal(err) + } + + delete(users, "x") + delete(users, "z") + delete(users, "y") + + return users + }) + + run("update function error", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{ + "x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"}, + "y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + myErr := errors.New("hello") + + err := db.Users.extIDMap.Update("y", func(u User) (User, error) { + u.Email = "blah" + return u, myErr + }) + if err != myErr { + t.Fatal(err) + } + + return users + }) + + run("update ErrAbortUpdate", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{ + "x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"}, + "y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + err := db.Users.extIDMap.Update("y", func(u User) (User, error) { + u.Email = "blah" + return u, ErrAbortUpdate + }) + if err != nil { + t.Fatal(err) + } + + return users + }) + + run("update ErrNotFound", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{ + "x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"}, + "y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + err := db.Users.extIDMap.Update("z", func(u User) (User, error) { + u.Name = "blah" + return u, nil + }) + if !errors.Is(err, ErrNotFound) { + t.Fatal(err) + } + + return users + }) + + run("update conflict in to in", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{ + "x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"}, + "y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"}, + "z": {ID: db.Users.c.NextID(), Email: "z@z.com", Name: "zz", ExtID: "z"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + 
t.Fatal(err) + } + } + + err := db.Users.extIDMap.Update("z", func(u User) (User, error) { + u.ExtID = "x" + return u, nil + }) + if !errors.Is(err, ErrDuplicate) { + t.Fatal(err) + } + + return users + }) + + run("update conflict out to in", func(t *testing.T, db *DB) map[string]*User { + users := map[string]*User{ + "x": {ID: db.Users.c.NextID(), Email: "x@y.com", Name: "aa", ExtID: "x"}, + "y": {ID: db.Users.c.NextID(), Email: "q@r.com", Name: "bb", ExtID: "y"}, + "z": {ID: db.Users.c.NextID(), Email: "z@z.com", Name: "zz"}, + } + + for _, u := range users { + if _, err := db.Users.c.Insert(*u); err != nil { + t.Fatal(err) + } + } + + err := db.Users.c.Update(users["z"].ID, func(u User) (User, error) { + u.ExtID = "x" + return u, nil + }) + if !errors.Is(err, ErrDuplicate) { + t.Fatal(err) + } + + delete(users, "z") + + return users + }) +} + +func TestMapIndex_load_ErrDuplicate(t *testing.T) { + testWithDB(t, "", func(t *testing.T, db *DB) { + idx := NewMapIndex( + db.Users.c, + "broken", + func(u *User) string { return u.Name }, + nil) + + users := map[uint64]*User{ + 1: {ID: 1, Email: "x@y.com", Name: "aa", ExtID: "x"}, + 2: {ID: 2, Email: "b@c.com", Name: "aa", ExtID: "y"}, + } + + if err := idx.load(users); err != ErrDuplicate { + t.Fatal(err) + } + }) +} diff --git a/shipping_test.go b/shipping_test.go new file mode 100644 index 0000000..9065f01 --- /dev/null +++ b/shipping_test.go @@ -0,0 +1,186 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "math/rand" + "os" + "sync" + "testing" + "time" + + "git.crumpington.com/public/mdb/testconn" +) + +func TestShipping(t *testing.T) { + run := func(name string, inner func(t *testing.T, db1 *DB, db2 *DB, network *testconn.Network)) { + t.Run(name, func(t *testing.T) { + root1, err := os.MkdirTemp("", "") + must(err) + defer os.RemoveAll(root1) + root2, err := os.MkdirTemp("", "") + must(err) + defer os.RemoveAll(root2) + + db1 := OpenDB(root1, true) + defer db1.Close() + db2 := OpenDB(root2, false) + defer db2.Close() + + inner(t, db1, db2, testconn.NewNetwork()) + }) + } + + run("simple", func(t *testing.T, db, db2 *DB, network *testconn.Network) { + wg := sync.WaitGroup{} + wg.Add(2) + + // Send in background. + go func() { + defer wg.Done() + conn := network.Accept() + db.SyncSend(conn) + }() + + // Recv in background. + go func() { + defer wg.Done() + conn := network.Dial() + db2.SyncRecv(conn) + }() + + for i := 0; i < 100; i++ { + db.RandAction() + } + + db.WaitForSync(db2) + network.CloseClient() + wg.Wait() + + if err := db.Equals(db2); err != nil { + t.Fatal(err) + } + }) + + run("simple multiple writers", func(t *testing.T, db, db2 *DB, network *testconn.Network) { + wg := sync.WaitGroup{} + + // Send in background. + wg.Add(1) + go func() { + defer wg.Done() + conn := network.Accept() + db.SyncSend(conn) + }() + + // Recv in background. 
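+ // Replication runs concurrently with the 64 writer goroutines started
+ // below; WaitForSync then blocks until both databases report the same
+ // maximum sequence number.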
+ wg.Add(1) + go func() { + defer wg.Done() + conn := network.Dial() + db2.SyncRecv(conn) + }() + + updateWG := sync.WaitGroup{} + updateWG.Add(64) + for i := 0; i < 64; i++ { + go func() { + defer updateWG.Done() + for j := 0; j < 1024; j++ { + db.RandAction() + } + }() + } + + updateWG.Wait() + + db.WaitForSync(db2) + network.CloseClient() + wg.Wait() + + if err := db.Equals(db2); err != nil { + t.Fatal(err) + } + }) + + run("unstable network", func(t *testing.T, db, db2 *DB, network *testconn.Network) { + sleepTimeout := time.Millisecond + + updateWG := sync.WaitGroup{} + updateWG.Add(64) + for i := 0; i < 64; i++ { + go func() { + defer updateWG.Done() + for j := 0; j < 4096; j++ { + time.Sleep(sleepTimeout) + db.RandAction() + } + }() + } + + updating := &atomicBool{} + updating.Set(true) + + go func() { + updateWG.Wait() + updating.Set(false) + }() + + // Recv in background. + recving := &atomicBool{} + recving.Set(true) + go func() { + for { + // Stop when no longer updating and WAL files match. + if !updating.Get() { + if db.MaxSeqNum() == db2.MaxSeqNum() { + recving.Set(false) + return + } + } + + if conn := network.Dial(); conn != nil { + db2.SyncRecv(conn) + } + } + }() + + // Send in background. + sending := &atomicBool{} + sending.Set(true) + go func() { + for { + // Stop when no longer updating and WAL files match. + if !updating.Get() { + if db.MaxSeqNum() == db2.MaxSeqNum() { + sending.Set(false) + return + } + } + + if conn := network.Accept(); conn != nil { + db.SyncSend(conn) + } + } + }() + + // Interrupt network periodically as long as sending or receiving. + for sending.Get() || recving.Get() { + time.Sleep(time.Duration(rand.Intn(10 * int(sleepTimeout)))) + if rand.Float64() < 0.5 { + network.CloseClient() + } else { + network.CloseServer() + } + } + + if err := db.Equals(db2); err != nil { + t.Fatal(err) + } + }) +} diff --git a/testconn/net.go b/testconn/net.go new file mode 100644 index 0000000..8b0537f --- /dev/null +++ b/testconn/net.go @@ -0,0 +1,87 @@ +package testconn + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "net" + "sync" + "time" +) + +type Network struct { + lock sync.Mutex + // Current connections. 
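+ // cConn is the client end returned by Dial, sConn the server end
+ // returned by Accept. Closing either one (CloseClient / CloseServer)
+ // drops the in-memory pipe so tests can simulate network failures.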
+ cConn net.Conn + sConn net.Conn + + acceptQ chan net.Conn +} + +func NewNetwork() *Network { + return &Network{ + acceptQ: make(chan net.Conn, 1), + } +} + +func (n *Network) Dial() net.Conn { + cc, sc := net.Pipe() + func() { + n.lock.Lock() + defer n.lock.Unlock() + if n.cConn != nil { + n.cConn.Close() + n.cConn = nil + } + select { + case n.acceptQ <- sc: + n.cConn = cc + default: + cc = nil + } + }() + return cc +} + +func (n *Network) Accept() net.Conn { + var sc net.Conn + select { + case sc = <-n.acceptQ: + case <-time.After(time.Second): + return nil + } + + func() { + n.lock.Lock() + defer n.lock.Unlock() + if n.sConn != nil { + n.sConn.Close() + n.sConn = nil + } + n.sConn = sc + }() + return sc +} + +func (n *Network) CloseClient() { + n.lock.Lock() + defer n.lock.Unlock() + if n.cConn != nil { + n.cConn.Close() + n.cConn = nil + } +} + +func (n *Network) CloseServer() { + n.lock.Lock() + defer n.lock.Unlock() + if n.sConn != nil { + n.sConn.Close() + n.sConn = nil + } +} diff --git a/testdb_test.go b/testdb_test.go new file mode 100644 index 0000000..3a4fee7 --- /dev/null +++ b/testdb_test.go @@ -0,0 +1,272 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "errors" + "fmt" + "math/rand" + "net/mail" + "strings" + "time" +) + +// ---------------------------------------------------------------------------- +// Validate errors. +// ---------------------------------------------------------------------------- + +var ErrInvalidName = errors.New("invalid name") + +// ---------------------------------------------------------------------------- +// User Collection +// ---------------------------------------------------------------------------- + +type User struct { + ID uint64 + Email string + Name string + ExtID string +} + +type Users struct { + c *Collection[User] + emailMap *MapIndex[string, User] // Full map index. + emailBTree *BTreeIndex[User] // Full btree index. + nameBTree *BTreeIndex[User] // Full btree with duplicates. + extIDMap *MapIndex[string, User] // Partial map index. + extIDBTree *BTreeIndex[User] // Partial btree index. 
+} + +func userGetID(u *User) uint64 { return u.ID } + +func userSanitize(u *User) { + u.Name = strings.TrimSpace(u.Name) + e, err := mail.ParseAddress(strings.ToLower(strings.TrimSpace(u.Email))) + if err == nil { + u.Email = e.Address + } +} + +func userValidate(u *User) error { + if len(u.Name) == 0 { + return ErrInvalidName + } + return nil +} + +// ---------------------------------------------------------------------------- +// Account Collection +// ---------------------------------------------------------------------------- + +type Account struct { + ID uint64 + Name string +} + +type Accounts struct { + c *Collection[Account] + nameMap *MapIndex[string, Account] +} + +func accountGetID(a *Account) uint64 { return a.ID } + +// ---------------------------------------------------------------------------- +// Database +// ---------------------------------------------------------------------------- + +type DB struct { + *Database + root string + Users Users + Accounts Accounts +} + +func OpenDB(root string, primary bool) *DB { + db := &DB{root: root} + if primary { + db.Database = NewPrimary(root) + } else { + db.Database = NewSecondary(root) + } + + db.Users = Users{} + db.Users.c = NewCollection(db.Database, "users", userGetID) + db.Users.c.SetSanitize(userSanitize) + db.Users.c.SetValidate(userValidate) + + db.Users.emailMap = NewMapIndex( + db.Users.c, + "email", + func(u *User) string { return u.Email }, + nil) + + db.Users.emailBTree = NewBTreeIndex( + db.Users.c, + func(lhs, rhs *User) bool { return lhs.Email < rhs.Email }, + nil) + + db.Users.nameBTree = NewBTreeIndex( + db.Users.c, + func(lhs, rhs *User) bool { + if lhs.Name != rhs.Name { + return lhs.Name < rhs.Name + } + return lhs.ID < rhs.ID + }, + nil) + + db.Users.extIDMap = NewMapIndex( + db.Users.c, + "extID", + func(u *User) string { return u.ExtID }, + func(u *User) bool { return u.ExtID != "" }) + + db.Users.extIDBTree = NewBTreeIndex( + db.Users.c, + func(lhs, rhs *User) bool { return lhs.ExtID < rhs.ExtID }, + func(u *User) bool { return u.ExtID != "" }) + + db.Accounts = Accounts{} + db.Accounts.c = NewCollection(db.Database, "accounts", accountGetID) + + db.Accounts.nameMap = NewMapIndex( + db.Accounts.c, + "name", + func(a *Account) string { return a.Name }, + nil) + + db.Start() + + return db +} + +func (db *DB) Equals(rhs *DB) error { + db.WaitForSync(rhs) + + // Users: itemMap. 
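+ // (The same pattern repeats for each index below: compare the two
+ // structures and name the one that diverged in the wrapped error.)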
+ if err := db.Users.c.items.Equals(rhs.Users.c.items); err != nil { + return fmt.Errorf("%w: Users.c.items not equal", err) + } + + // Users: emailMap + if err := db.Users.emailMap.Equals(rhs.Users.emailMap); err != nil { + return fmt.Errorf("%w: Users.emailMap not equal", err) + } + + // Users: emailBTree + if err := db.Users.emailBTree.Equals(rhs.Users.emailBTree); err != nil { + return fmt.Errorf("%w: Users.emailBTree not equal", err) + } + + // Users: nameBTree + if err := db.Users.nameBTree.Equals(rhs.Users.nameBTree); err != nil { + return fmt.Errorf("%w: Users.nameBTree not equal", err) + } + + // Users: extIDMap + if err := db.Users.extIDMap.Equals(rhs.Users.extIDMap); err != nil { + return fmt.Errorf("%w: Users.extIDMap not equal", err) + } + + // Users: extIDBTree + if err := db.Users.extIDBTree.Equals(rhs.Users.extIDBTree); err != nil { + return fmt.Errorf("%w: Users.extIDBTree not equal", err) + } + + // Accounts: itemMap + if err := db.Accounts.c.items.Equals(rhs.Accounts.c.items); err != nil { + return fmt.Errorf("%w: Accounts.c.items not equal", err) + } + + // Accounts: nameMap + if err := db.Accounts.nameMap.Equals(rhs.Accounts.nameMap); err != nil { + return fmt.Errorf("%w: Accounts.nameMap not equal", err) + } + + return nil +} + +// Wait for two databases to become synchronized. +func (db *DB) WaitForSync(rhs *DB) { + for { + if db.MaxSeqNum() == rhs.MaxSeqNum() { + return + } + time.Sleep(100 * time.Millisecond) + } +} + +var ( + randIDs = []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 2, 13, 14, 15, 16} +) + +func (db *DB) RandAction() { + if rand.Float32() < 0.3 { + db.randActionAccount() + } else { + db.randActionUser() + } +} + +func (db *DB) randActionAccount() { + id := randIDs[rand.Intn(len(randIDs))] + f := rand.Float32() + + _, exists := db.Accounts.c.Get(id) + if !exists { + db.Accounts.c.Insert(Account{ + ID: id, + Name: randString(), + }) + return + } + + if f < 0.05 { + db.Accounts.c.Delete(id) + return + } + db.Accounts.c.Update(id, func(a Account) (Account, error) { + a.Name = randString() + return a, nil + }) +} + +func (db *DB) randActionUser() { + id := randIDs[rand.Intn(len(randIDs))] + f := rand.Float32() + + _, exists := db.Users.c.Get(id) + if !exists { + user := User{ + ID: id, + Email: randStringShort() + "@domain.com", + Name: randString(), + } + if f < 0.1 { + user.ExtID = randString() + } + db.Users.c.Insert(user) + return + } + + if f < 0.05 { + db.Users.c.Delete(id) + return + } + + db.Users.c.Update(id, func(a User) (User, error) { + a.Name = randString() + if f < 0.1 { + a.ExtID = randString() + } else { + a.ExtID = "" + } + a.Email = randStringShort() + "@domain.com" + return a, nil + }) +} diff --git a/types.go b/types.go new file mode 100644 index 0000000..e5e46e6 --- /dev/null +++ b/types.go @@ -0,0 +1,36 @@ +package mdb + +/* +Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +type WALStatus struct { + MaxSeqNumKV uint64 + MaxSeqNumWAL uint64 +} + +type itemIndex[T any] interface { + load(m map[uint64]*T) error + insert(*T) + update(old, new *T) // Old and new MUST have the same ID. + delete(*T) +} + +type itemUniqueIndex[T any] interface { + load(m map[uint64]*T) error + name() string + lock(*T) + unlock(*T) + insertConflict(*T) bool + updateConflict(*T) bool +} + +type dbCollection interface { + loadData() + onStore(uint64, []byte) // For WAL following. + onDelete(uint64) // For WAL following. 
+} diff --git a/util.go b/util.go new file mode 100644 index 0000000..d92ae5b --- /dev/null +++ b/util.go @@ -0,0 +1,14 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +func must(err error) { + if err != nil { + panic(err) + } +} diff --git a/util_test.go b/util_test.go new file mode 100644 index 0000000..5bcc2b0 --- /dev/null +++ b/util_test.go @@ -0,0 +1,47 @@ +package mdb + +/*Copyright (c) 2022, John David Lee +All rights reserved. + +This source code is licensed under the BSD-style license found in the +LICENSE file in the root directory of this source tree. +*/ + +import ( + "crypto/rand" + "encoding/hex" + mrand "math/rand" + "sync/atomic" +) + +func randStringShort() string { + s := randString() + if len(s) > 2 { + return s[:2] + } + return s +} + +func randString() string { + buf := make([]byte, 1+mrand.Intn(10)) + if _, err := rand.Read(buf); err != nil { + panic(err) + } + return hex.EncodeToString(buf) +} + +type atomicBool struct { + i int64 +} + +func (a *atomicBool) Get() bool { + return atomic.LoadInt64(&a.i) == 1 +} + +func (a *atomicBool) Set(b bool) { + if b { + atomic.StoreInt64(&a.i, 1) + } else { + atomic.StoreInt64(&a.i, 0) + } +}