Initial commit

This commit is contained in:
jdl
2023-10-13 11:43:27 +02:00
commit 71eb6b0c7e
121 changed files with 11493 additions and 0 deletions

25
mdb/change/binary.go Normal file
View File

@@ -0,0 +1,25 @@
package change
import (
"encoding/binary"
"io"
"git.crumpington.com/public/jldb/lib/errs"
)
// writeBin encodes each value in data to w using little-endian binary
// encoding. Any write failure is wrapped as an IO error.
func writeBin(w io.Writer, data ...any) error {
	for i := range data {
		if err := binary.Write(w, binary.LittleEndian, data[i]); err != nil {
			return errs.IO.WithErr(err)
		}
	}
	return nil
}
// readBin decodes little-endian binary values from r into the given
// pointers, in order. Any read failure is wrapped as an IO error.
func readBin(r io.Reader, ptrs ...any) error {
	for i := range ptrs {
		if err := binary.Read(r, binary.LittleEndian, ptrs[i]); err != nil {
			return errs.IO.WithErr(err)
		}
	}
	return nil
}

98
mdb/change/change.go Normal file
View File

@@ -0,0 +1,98 @@
package change
import (
"io"
"git.crumpington.com/public/jldb/lib/errs"
)
// ----------------------------------------------------------------------------
// Change
// ----------------------------------------------------------------------------

// The Change type encodes a change (store / delete) to be applied to a
// pagefile.
type Change struct {
	CollectionID uint64 // Collection the item belongs to.
	ItemID       uint64 // Item being stored or deleted.

	Store bool   // true: store Data for the item; false: delete the item.
	Data  []byte // Serialized item; meaningful only when Store is true.

	// Page ID lists carried alongside the change — presumably the pages
	// written/freed when the change is applied to the page file; confirm
	// against pfile.ApplyChanges.
	WritePageIDs []uint64
	ClearPageIDs []uint64
}
// writeTo serializes the change to w: a fixed binary header
// (collection ID, item ID, data size, page-ID counts), the page ID
// lists, then the raw item bytes for a store. A delete is encoded
// with a data size of -1 and no trailing bytes.
func (ch Change) writeTo(w io.Writer) error {
	size := int64(-1)
	if ch.Store {
		size = int64(len(ch.Data))
	}

	if err := writeBin(w,
		ch.CollectionID,
		ch.ItemID,
		size,
		uint64(len(ch.WritePageIDs)),
		uint64(len(ch.ClearPageIDs)),
		ch.WritePageIDs,
		ch.ClearPageIDs,
	); err != nil {
		return err
	}

	if !ch.Store {
		return nil
	}
	if _, err := w.Write(ch.Data); err != nil {
		return errs.IO.WithErr(err)
	}
	return nil
}
// readFrom decodes a Change from r, reusing the receiver's existing
// slice capacity where possible to avoid allocations.
//
// The wire format matches writeTo: fixed header, page ID lists, then
// the raw item bytes when the header's data size is not -1.
func (ch *Change) readFrom(r io.Reader) error {
	var pageCount, clearCount uint64
	var dataSize int64
	err := readBin(r,
		&ch.CollectionID,
		&ch.ItemID,
		&dataSize,
		&pageCount,
		&clearCount)
	if err != nil {
		return err
	}

	// Grow the page ID slices only when capacity is insufficient.
	if uint64(cap(ch.WritePageIDs)) < pageCount {
		ch.WritePageIDs = make([]uint64, pageCount)
	}
	ch.WritePageIDs = ch.WritePageIDs[:pageCount]

	if uint64(cap(ch.ClearPageIDs)) < clearCount {
		ch.ClearPageIDs = make([]uint64, clearCount)
	}
	ch.ClearPageIDs = ch.ClearPageIDs[:clearCount]

	if err = readBin(r, ch.WritePageIDs); err != nil {
		return err
	}
	if err = readBin(r, ch.ClearPageIDs); err != nil {
		return err
	}

	// A data size of -1 encodes a delete: no data bytes follow.
	ch.Store = dataSize != -1
	if ch.Store {
		if int64(cap(ch.Data)) < dataSize {
			ch.Data = make([]byte, dataSize)
		}
		ch.Data = ch.Data[:dataSize]
		// BUG FIX: r.Read may legally return fewer bytes than
		// len(ch.Data) without an error (common on net.Conn readers),
		// which would leave ch.Data partially filled with stale bytes.
		// io.ReadFull guarantees the buffer is filled or errors out.
		if _, err := io.ReadFull(r, ch.Data); err != nil {
			return errs.IO.WithErr(err)
		}
	} else {
		ch.Data = ch.Data[:0]
	}
	return nil
}

67
mdb/change/change_test.go Normal file
View File

@@ -0,0 +1,67 @@
package change
import (
"bytes"
"reflect"
"testing"
)
// AssertEqual fails the test if rhs differs from lhs in any field.
// Slice contents are compared only when both sides are non-empty, so a
// nil slice and an empty slice are treated as equal.
func (lhs Change) AssertEqual(t *testing.T, rhs Change) {
	if lhs.CollectionID != rhs.CollectionID {
		t.Fatal(lhs.CollectionID, rhs.CollectionID)
	}
	if lhs.ItemID != rhs.ItemID {
		t.Fatal(lhs.ItemID, rhs.ItemID)
	}
	if lhs.Store != rhs.Store {
		t.Fatal(lhs.Store, rhs.Store)
	}

	if len(lhs.Data) != len(rhs.Data) {
		t.Fatal(len(lhs.Data), len(rhs.Data))
	}
	if len(lhs.Data) > 0 && !reflect.DeepEqual(lhs.Data, rhs.Data) {
		t.Fatal(lhs.Data, rhs.Data)
	}

	if len(lhs.WritePageIDs) != len(rhs.WritePageIDs) {
		t.Fatal(len(lhs.WritePageIDs), len(rhs.WritePageIDs))
	}
	if len(lhs.WritePageIDs) > 0 && !reflect.DeepEqual(lhs.WritePageIDs, rhs.WritePageIDs) {
		t.Fatal(lhs.WritePageIDs, rhs.WritePageIDs)
	}

	if len(lhs.ClearPageIDs) != len(rhs.ClearPageIDs) {
		t.Fatal(len(lhs.ClearPageIDs), len(rhs.ClearPageIDs))
	}
	if len(lhs.ClearPageIDs) > 0 && !reflect.DeepEqual(lhs.ClearPageIDs, rhs.ClearPageIDs) {
		t.Fatal(lhs.ClearPageIDs, rhs.ClearPageIDs)
	}
}
// TestChangeWriteToReadFrom round-trips random changes through
// writeTo/readFrom, reusing a single output Change so the buffer-reuse
// paths in readFrom are exercised.
func TestChangeWriteToReadFrom(t *testing.T) {
	var (
		out Change
		buf bytes.Buffer
	)
	for i := 0; i < 100; i++ {
		buf.Reset()
		in := randChange()
		if err := in.writeTo(&buf); err != nil {
			t.Fatal(err)
		}
		if err := out.readFrom(&buf); err != nil {
			t.Fatal(err)
		}
		in.AssertEqual(t, out)
	}
}

35
mdb/change/encoding.go Normal file
View File

@@ -0,0 +1,35 @@
package change
import "io"
// Write encodes changes to w as a uint64 count followed by each change
// in order (see Change.writeTo for the per-change format).
func Write(changes []Change, w io.Writer) error {
	if err := writeBin(w, uint64(len(changes))); err != nil {
		return err
	}
	for i := range changes {
		if err := changes[i].writeTo(w); err != nil {
			return err
		}
	}
	return nil
}
// Read decodes a slice of changes from r, reusing the storage of the
// passed-in slice when possible. It returns the (possibly reallocated)
// slice; on error the slice read so far is returned alongside the error.
func Read(changes []Change, r io.Reader) ([]Change, error) {
	var count uint64
	if err := readBin(r, &count); err != nil {
		return changes, err
	}

	// FIX: check capacity rather than length. Checking len() forced a
	// fresh allocation whenever the caller passed a truncated slice
	// (e.g. changes[:0]), discarding every Change's internal buffers
	// that readFrom is designed to reuse.
	if uint64(cap(changes)) < count {
		changes = make([]Change, count)
	}
	changes = changes[:count]

	for i := range changes {
		if err := changes[i].readFrom(r); err != nil {
			return changes, err
		}
	}
	return changes, nil
}

View File

@@ -0,0 +1,64 @@
package change
import (
"bytes"
crand "crypto/rand"
"math/rand"
"testing"
)
// randChange builds a Change with random IDs, a ~50% chance of being a
// store (with 1-100 random data bytes), and 0-9 random write/clear
// page IDs.
func randChange() Change {
	c := Change{
		CollectionID: rand.Uint64(),
		ItemID:       rand.Uint64(),
		Store:        rand.Float32() < 0.5,
	}
	if c.Store {
		data := make([]byte, 1+rand.Intn(100))
		// FIX: the error from crypto/rand was silently discarded. It
		// should never fail on a healthy system; if it does, the test
		// fixture is meaningless, so fail loudly.
		if _, err := crand.Read(data); err != nil {
			panic(err)
		}
		c.Data = data
	}

	c.WritePageIDs = make([]uint64, rand.Intn(10))
	for i := range c.WritePageIDs {
		c.WritePageIDs[i] = rand.Uint64()
	}

	c.ClearPageIDs = make([]uint64, rand.Intn(10))
	for i := range c.ClearPageIDs {
		c.ClearPageIDs[i] = rand.Uint64()
	}

	return c
}
// randChangeSlice builds a non-empty slice of 1-10 random changes.
func randChangeSlice() []Change {
	n := 1 + rand.Intn(10)
	out := make([]Change, n)
	for i := 0; i < n; i++ {
		out[i] = randChange()
	}
	return out
}
// TestWriteRead round-trips a random change slice through Write/Read
// and verifies every element survives intact.
func TestWriteRead(t *testing.T) {
	in := randChangeSlice()

	var buf bytes.Buffer
	if err := Write(in, &buf); err != nil {
		t.Fatal(err)
	}

	out, err := Read(nil, &buf)
	if err != nil {
		t.Fatal(err)
	}

	if len(in) != len(out) {
		t.Fatal(len(in), len(out))
	}
	for i := range in {
		in[i].AssertEqual(t, out[i])
	}
}

View File

@@ -0,0 +1,24 @@
package mdb
// collectionState is the per-collection portion of a snapshot: a
// version counter used for copy-on-write detection (see
// Collection.ensureMutable) plus the state of each index, positioned
// by the ID assigned in addIndex.
type collectionState[T any] struct {
	Version uint64
	Indices []indexState[T]
}
// clone returns a copy of the collection state stamped with the given
// version; each index state is cloned individually so the copy can be
// mutated without affecting the original.
func (c *collectionState[T]) clone(version uint64) *collectionState[T] {
	out := &collectionState[T]{
		Version: version,
		Indices: make([]indexState[T], len(c.Indices)),
	}
	for i, idx := range c.Indices {
		out.Indices[i] = idx.clone()
	}
	return out
}
// Add an index returning its assigned ID. IDs are assigned densely in
// creation order, so an index's ID is also its position in Indices.
func (c *collectionState[T]) addIndex(idx indexState[T]) uint64 {
	id := uint64(len(c.Indices))
	c.Indices = append(c.Indices, idx)
	return id
}

347
mdb/collection.go Normal file
View File

@@ -0,0 +1,347 @@
package mdb
import (
"bytes"
"encoding/json"
"errors"
"hash/crc64"
"git.crumpington.com/public/jldb/lib/errs"
"unsafe"
"github.com/google/btree"
)
// Collection is a typed collection of items stored in a Database.
// Items are held in one or more in-memory btree indices; every
// collection has at least the unique ByID index. The item type T must
// have a uint64 ID as its first field (see getID / setID).
type Collection[T any] struct {
	db           *Database
	name         string
	collectionID uint64 // crc64 checksum of name (see NewCollection).

	// Lifecycle hooks; defaulted in NewCollection when the config
	// leaves them nil.
	copy     func(*T) *T
	sanitize func(*T)
	validate func(*T) error

	indices       []Index[T] // all indices, in creation order
	uniqueIndices []Index[T] // subset of indices that enforce uniqueness

	ByID Index[T] // unique primary index ordered by item ID

	buf *bytes.Buffer // scratch buffer — not used in this chunk; confirm usage elsewhere
}
// CollectionConfig supplies optional per-collection hooks; any nil
// field is given a default by NewCollection.
type CollectionConfig[T any] struct {
	Copy     func(*T) *T    // Copy an item (default: value copy; supply a deep copy if T holds reference types).
	Sanitize func(*T)       // Normalize an item before validation (default: no-op).
	Validate func(*T) error // Reject invalid items (default: accept everything).
}
// NewCollection creates a collection named name on db, filling in
// defaults for any nil config hooks, registering the collection's
// state with the database, and creating the mandatory unique ByID
// index.
//
// NOTE(review): the collection ID is the CRC64 of the name, so two
// names with colliding checksums would clash — presumably names are
// chosen to avoid this; confirm.
func NewCollection[T any](db *Database, name string, conf *CollectionConfig[T]) *Collection[T] {
	if conf == nil {
		conf = &CollectionConfig[T]{}
	}

	// Default copy: plain value copy of the struct.
	if conf.Copy == nil {
		conf.Copy = func(from *T) *T {
			to := new(T)
			*to = *from
			return to
		}
	}

	// Default sanitize: no-op.
	if conf.Sanitize == nil {
		conf.Sanitize = func(*T) {}
	}

	// Default validate: accept everything.
	if conf.Validate == nil {
		conf.Validate = func(*T) error {
			return nil
		}
	}

	c := &Collection[T]{
		db:            db,
		name:          name,
		collectionID:  crc64.Checksum([]byte(name), crc64Table),
		copy:          conf.Copy,
		sanitize:      conf.Sanitize,
		validate:      conf.Validate,
		indices:       []Index[T]{},
		uniqueIndices: []Index[T]{},
		buf:           &bytes.Buffer{},
	}

	// Register with the database before creating indices: addIndex
	// stores index state via the database's snapshot.
	db.addCollection(c.collectionID, c, &collectionState[T]{
		Indices: []indexState[T]{},
	})

	// Every collection gets a unique primary index ordered by item ID.
	c.ByID = c.addIndex(indexConfig[T]{
		Name:   "ByID",
		Unique: true,
		Compare: func(lhs, rhs *T) int {
			l := c.getID(lhs)
			r := c.getID(rhs)
			if l < r {
				return -1
			} else if l > r {
				return 1
			}
			return 0
		},
	})

	return c
}
// Name returns the collection's name as given to NewCollection.
func (c Collection[T]) Name() string {
	return c.name
}
// indexConfig describes an index to be created via Collection.addIndex.
type indexConfig[T any] struct {
	Name   string // Used in duplicate-key error messages.
	Unique bool   // Enforce uniqueness under Compare.

	// If an index isn't unique, an additional comparison by ID is added if
	// two items are otherwise equal.
	Compare func(lhs, rhs *T) int

	// If not nil, indicates if a given item should be in the index.
	Include func(item *T) bool
}
// Get looks up the item with the given ID in the snapshot, returning a
// copy of the item and whether it was found.
func (c Collection[T]) Get(tx *Snapshot, id uint64) (*T, bool) {
	probe := new(T)
	c.setID(probe, id)
	return c.ByID.Get(tx, probe)
}
// List fetches the items with the given IDs into out (reused when its
// capacity suffices), silently skipping IDs that aren't present. The
// result therefore may be shorter than ids.
func (c Collection[T]) List(tx *Snapshot, ids []uint64, out []*T) []*T {
	if len(ids) == 0 {
		return out[:0]
	}
	// Idiom fix: allocate with zero length and full capacity instead of
	// zero-filling a full-length slice and immediately truncating it.
	if cap(out) < len(ids) {
		out = make([]*T, 0, len(ids))
	} else {
		out = out[:0]
	}
	for _, id := range ids {
		if item, ok := c.Get(tx, id); ok {
			out = append(out, item)
		}
	}
	return out
}
// AddIndex: Add an index to the collection.
//
// The btree ordering is derived from conf.Compare; for non-unique
// indices ties are broken by item ID so distinct items never compare
// equal. The new index's state is registered on the database's
// current snapshot.
//
// NOTE(review): this mutates the live snapshot's state, so it appears
// to be safe only during setup before the database is shared — confirm
// callers.
func (c *Collection[T]) addIndex(conf indexConfig[T]) Index[T] {
	var less func(*T, *T) bool
	if conf.Unique {
		less = func(lhs, rhs *T) bool {
			return conf.Compare(lhs, rhs) == -1
		}
	} else {
		// Tie-break equal items by ID to keep the ordering strict.
		less = func(lhs, rhs *T) bool {
			switch conf.Compare(lhs, rhs) {
			case -1:
				return true
			case 1:
				return false
			default:
				return c.getID(lhs) < c.getID(rhs)
			}
		}
	}

	indexState := indexState[T]{
		BTree: btree.NewG(256, less),
	}

	index := Index[T]{
		collectionID: c.collectionID,
		name:         conf.Name,
		indexID:      c.getState(c.db.Snapshot()).addIndex(indexState),
		include:      conf.Include,
		copy:         c.copy,
	}

	c.indices = append(c.indices, index)
	if conf.Unique {
		c.uniqueIndices = append(c.uniqueIndices, index)
	}

	return index
}
// Insert adds a copy of userItem to the collection. It fails with
// ErrReadOnly if the snapshot isn't writable and ErrDuplicate if any
// unique index already holds an equal item. The caller's item is never
// mutated: a copy is sanitized, validated, and stored.
func (c Collection[T]) Insert(tx *Snapshot, userItem *T) error {
	if err := c.ensureMutable(tx); err != nil {
		return err
	}

	item := c.copy(userItem)
	c.sanitize(item)
	if err := c.validate(item); err != nil {
		return err
	}

	// Every unique index must be conflict-free before any mutation.
	for _, idx := range c.uniqueIndices {
		if idx.insertConflict(tx, item) {
			return ErrDuplicate.WithCollection(c.name).WithIndex(idx.name)
		}
	}

	tx.store(c.collectionID, c.getID(item), item)
	for _, idx := range c.indices {
		idx.insert(tx, item)
	}
	return nil
}
// Update replaces the stored item that has userItem's ID. It fails
// with ErrReadOnly for read-only snapshots, ErrNotFound if the item
// doesn't exist, and ErrDuplicate if the new values collide in a
// unique index.
func (c Collection[T]) Update(tx *Snapshot, userItem *T) error {
	if err := c.ensureMutable(tx); err != nil {
		return err
	}

	item := c.copy(userItem)
	c.sanitize(item)
	if err := c.validate(item); err != nil {
		return err
	}

	// The old version is needed below for index maintenance.
	old, ok := c.ByID.get(tx, item)
	if !ok {
		return ErrNotFound
	}

	for _, idx := range c.uniqueIndices {
		if idx.updateConflict(tx, item) {
			return ErrDuplicate.WithCollection(c.name).WithIndex(idx.name)
		}
	}

	tx.store(c.collectionID, c.getID(item), item)
	for _, idx := range c.indices {
		idx.update(tx, old, item)
	}
	return nil
}
// Upsert inserts the item, falling back to an update when the insert
// reports a duplicate.
func (c Collection[T]) Upsert(tx *Snapshot, item *T) error {
	switch err := c.Insert(tx, item); {
	case err == nil:
		return nil
	case errors.Is(err, ErrDuplicate):
		return c.Update(tx, item)
	default:
		return err
	}
}
// Delete removes the item with the given ID, returning ErrReadOnly for
// read-only snapshots and ErrNotFound if the item doesn't exist.
func (c Collection[T]) Delete(tx *Snapshot, itemID uint64) error {
	err := c.ensureMutable(tx)
	if err != nil {
		return err
	}
	return c.deleteItem(tx, itemID)
}
// getByID fetches the item with the given ID via the internal
// (lowercase, non-copying) index lookup.
func (c Collection[T]) getByID(tx *Snapshot, itemID uint64) (*T, bool) {
	probe := new(T)
	c.setID(probe, itemID)
	return c.ByID.get(tx, probe)
}
// ensureMutable verifies the snapshot is writable and, on the first
// write to this collection within the transaction, clones the
// collection state (copy-on-write) so earlier snapshots stay
// unchanged.
func (c Collection[T]) ensureMutable(tx *Snapshot) error {
	if !tx.writable() {
		return ErrReadOnly
	}
	state := c.getState(tx)
	// A version mismatch means this tx hasn't cloned the state yet.
	if state.Version != tx.version {
		tx.collections[c.collectionID] = state.clone(tx.version)
	}
	return nil
}
// For initial data loading.
//
// insertItem decodes an item from its serialized JSON form and inserts
// it into every index. Unlike Insert it skips copy/sanitize/validate
// and does not write to the snapshot's store.
//
// NOTE(review): itemID is unused — the ID comes from the decoded JSON.
// Presumably the two always agree; confirm against callers (loadItem).
func (c Collection[T]) insertItem(tx *Snapshot, itemID uint64, data []byte) error {
	item := new(T)
	if err := json.Unmarshal(data, item); err != nil {
		return errs.Encoding.WithErr(err).WithCollection(c.name)
	}

	// Check for insert conflict.
	for _, index := range c.uniqueIndices {
		if index.insertConflict(tx, item) {
			return ErrDuplicate
		}
	}

	// Do the insert.
	for _, index := range c.indices {
		index.insert(tx, item)
	}

	return nil
}
// deleteItem removes the item from the snapshot's store and from every
// index, returning ErrNotFound if the ID isn't present.
func (c Collection[T]) deleteItem(tx *Snapshot, itemID uint64) error {
	item, ok := c.getByID(tx, itemID)
	if !ok {
		return ErrNotFound
	}

	tx.delete(c.collectionID, itemID)
	for _, idx := range c.indices {
		idx.delete(tx, item)
	}
	return nil
}
// upsertItem inserts or updates the item with itemID and the given serialized
// form. It's called by applyChange when applying a replicated store change.
//
// Any existing version of the item is removed from the store and all
// indices before the new version is decoded and inserted.
//
// NOTE(review): the new version is indexed but never tx.store()d here,
// unlike Insert — confirm this is intentional (the replicated record
// presumably already carries the page-file changes).
func (c Collection[T]) upsertItem(tx *Snapshot, itemID uint64, data []byte) error {
	item, ok := c.getByID(tx, itemID)
	if ok {
		tx.delete(c.collectionID, itemID)
		for i := range c.indices {
			c.indices[i].delete(tx, item)
		}
	}

	item = new(T)
	if err := json.Unmarshal(data, item); err != nil {
		return errs.Encoding.WithErr(err).WithCollection(c.name)
	}

	// Do the insert.
	for _, index := range c.indices {
		index.insert(tx, item)
	}

	return nil
}
// getID extracts the item's ID by reinterpreting the struct pointer as
// *uint64. This assumes T's first field is the uint64 ID —
// NOTE(review): verify this convention holds for every collection type.
func (c Collection[T]) getID(t *T) uint64 {
	return *((*uint64)(unsafe.Pointer(t)))
}
// setID writes id into the item's first field (same layout assumption
// as getID: T's first field is the uint64 ID).
func (c Collection[T]) setID(t *T, id uint64) {
	*((*uint64)(unsafe.Pointer(t))) = id
}
// getState returns this collection's state within the snapshot. The
// unchecked type assertion relies on NewCollection having registered a
// *collectionState[T] under this collection's ID.
func (c Collection[T]) getState(tx *Snapshot) *collectionState[T] {
	return tx.collections[c.collectionID].(*collectionState[T])
}

View File

@@ -0,0 +1,76 @@
package mdb
import (
"git.crumpington.com/public/jldb/lib/errs"
"log"
"os"
"os/exec"
"testing"
"time"
)
// TestCrashConsistency repeatedly runs a writer process against a
// database directory, kills it mid-write, reopens the database, and
// verifies the stored CRC matches a freshly computed one.
func TestCrashConsistency(t *testing.T) {
	if testing.Short() {
		// FIX: message typo ("Sipping" -> "Skipping").
		t.Skip("Skipping test in short mode.")
	}

	// Build the test binary.
	err := exec.Command(
		"go", "build",
		"-o", "testing/crashconsistency/p",
		"testing/crashconsistency/main.go").Run()
	if err != nil {
		t.Fatal(err)
	}
	defer os.RemoveAll("testing/crashconsistency/p")

	rootDir := t.TempDir()
	defer os.RemoveAll(rootDir)

	for i := 0; i < 32; i++ {
		cmd := exec.Command("testing/crashconsistency/p", rootDir)
		if err := cmd.Start(); err != nil {
			t.Fatal(err)
		}

		// Let the writer run briefly, then kill it mid-flight.
		time.Sleep(time.Second / 2)
		for {
			if err := cmd.Process.Kill(); err != nil {
				log.Printf("Kill failed: %v", err)
				time.Sleep(time.Second)
				continue
			}
			break
		}
		// FIX: reap the killed child so it doesn't linger as a zombie
		// for the remainder of the test. The error (killed) is expected
		// and deliberately ignored.
		_ = cmd.Wait()

		var (
			db  DataDB
			err error
		)

		// The killed process may still hold the file lock briefly; retry
		// until the lock is released.
		for {
			db, err = OpenDataDB(rootDir)
			if err == nil {
				break
			}
			if errs.Locked.Is(err) {
				log.Printf("Locked.")
				time.Sleep(time.Second / 10)
				continue
			}
			t.Fatal(err)
		}

		tx := db.Snapshot()
		computed := db.ComputeCRC(tx)
		stored := db.ReadCRC(tx)
		if computed != stored {
			t.Fatal(stored, computed)
		}
		db.Close()
	}
}

5
mdb/crc.go Normal file
View File

@@ -0,0 +1,5 @@
package mdb
import "hash/crc64"
// crc64Table is the shared ECMA-polynomial CRC64 table; NewCollection
// uses it to derive collection IDs from collection names.
var crc64Table = crc64.MakeTable(crc64.ECMA)

67
mdb/db-primary.go Normal file
View File

@@ -0,0 +1,67 @@
package mdb
/*
func (db *Database) openPrimary() (err error) {
wal, err := cwal.Open(db.walRootDir, cwal.Config{
SegMinCount: db.conf.WALSegMinCount,
SegMaxAgeSec: db.conf.WALSegMaxAgeSec,
})
pFile, err := pfile.Open(db.pageFilePath,
pFile, err := openPageFileAndReplayWAL(db.rootDir)
if err != nil {
return err
}
defer pFile.Close()
pfHeader, err := pFile.ReadHeader()
if err != nil {
return err
}
tx := db.Snapshot()
tx.seqNum = pfHeader.SeqNum
tx.updatedAt = pfHeader.UpdatedAt
pIndex, err := pagefile.NewIndex(pFile)
if err != nil {
return err
}
err = pFile.IterateAllocated(pIndex, func(cID, iID uint64, data []byte) error {
return db.loadItem(tx, cID, iID, data)
})
if err != nil {
return err
}
w, err := cwal.OpenWriter(db.walRootDir, &cwal.WriterConfig{
SegMinCount: db.conf.WALSegMinCount,
SegMaxAgeSec: db.conf.WALSegMaxAgeSec,
})
if err != nil {
return err
}
db.done.Add(1)
go txAggregator{
Stop: db.stop,
Done: db.done,
ModChan: db.modChan,
W: w,
Index: pIndex,
Snapshot: db.snapshot,
}.Run()
db.done.Add(1)
go (&fileWriter{
Stop: db.stop,
Done: db.done,
PageFilePath: db.pageFilePath,
WALRootDir: db.walRootDir,
}).Run()
return nil
}
*/

118
mdb/db-rep.go Normal file
View File

@@ -0,0 +1,118 @@
package mdb
import (
"git.crumpington.com/public/jldb/lib/errs"
"git.crumpington.com/public/jldb/lib/wal"
"git.crumpington.com/public/jldb/mdb/change"
"git.crumpington.com/public/jldb/mdb/pfile"
"log"
"net"
"os"
)
// repSendState streams this database's page file to conn — the
// sending side of a replication state transfer.
func (db *Database) repSendState(conn net.Conn) error {
	f, err := pfile.Open(pageFilePath(db.rootDir))
	if err != nil {
		return err
	}
	defer f.Close()
	return f.Send(conn, db.conf.NetTimeout)
}
// repRecvState downloads a page file from conn into a temporary ".dl"
// file, then atomically renames it into place so a partial download
// can never be mistaken for a valid page file.
func (db *Database) repRecvState(conn net.Conn) error {
	dst := pageFilePath(db.rootDir)
	tmp := dst + ".dl"

	if err := pfile.Recv(conn, tmp, db.conf.NetTimeout); err != nil {
		return err
	}
	if err := os.Rename(tmp, dst); err != nil {
		return errs.Unexpected.WithErr(err)
	}
	return nil
}
// repInitStorage opens the database's page file and stores the handle
// on db.pf — presumably invoked by the replication layer during
// startup; confirm against the rep package's callbacks.
func (db *Database) repInitStorage() (err error) {
	db.pf, err = pfile.Open(pageFilePath(db.rootDir))
	return err
}
// repReplay applies a single WAL record's changes to the page file
// only (no in-memory state is touched). The decoded change slice is
// kept on db.changes for reuse across records.
func (db *Database) repReplay(rec wal.Record) error {
	changes, err := change.Read(db.changes[:0], rec.Reader)
	db.changes = changes
	if err != nil {
		return err
	}
	return db.pf.ApplyChanges(db.changes)
}
// repLoadFromStorage rebuilds the in-memory state from the page file:
// it builds the page index, then inserts every allocated item into the
// collections of the current snapshot before publishing it.
func (db *Database) repLoadFromStorage() (err error) {
	db.idx, err = pfile.NewIndex(db.pf)
	if err != nil {
		return err
	}

	tx := db.snapshot.Load()

	err = pfile.IterateAllocated(db.pf, db.idx, func(cID, iID uint64, data []byte) error {
		return db.loadItem(tx, cID, iID, data)
	})
	if err != nil {
		return err
	}

	// Publish the populated snapshot.
	db.snapshot.Store(tx)
	return nil
}
// loadItem routes a stored item to its collection's loader. Items
// belonging to unknown collections are logged and skipped rather than
// failing the whole load.
func (db *Database) loadItem(tx *Snapshot, cID, iID uint64, data []byte) error {
	if c, ok := db.collections[cID]; ok {
		return c.insertItem(tx, iID, data)
	}
	log.Printf("Failed to find collection %d for item in page file.", cID)
	return nil
}
// repApply applies a committed WAL record. Changes are always applied
// to the page file; on a secondary they are additionally applied to
// the in-memory snapshot.
//
// NOTE(review): the primary skips the in-memory step — presumably its
// memory was already updated when the transaction committed; confirm
// in the commit path.
func (db *Database) repApply(rec wal.Record) (err error) {
	db.changes, err = change.Read(db.changes[:0], rec.Reader)
	if err != nil {
		return err
	}

	if err := db.pf.ApplyChanges(db.changes); err != nil {
		return err
	}

	if db.rep.Primary() {
		return nil
	}

	// For secondary, we need to also apply changes to memory.
	tx := db.snapshot.Load().begin()
	for _, change := range db.changes {
		if err = db.applyChange(tx, change); err != nil {
			return err
		}
	}

	// Stamp the new snapshot with the record's position before
	// publishing it.
	tx.seqNum = rec.SeqNum
	tx.timestampMS = rec.TimestampMS
	db.snapshot.Store(tx)
	return nil
}
// applyChange applies a single replicated change to the in-memory
// snapshot. Changes for unknown collections are silently ignored.
func (db *Database) applyChange(tx *Snapshot, change change.Change) error {
	c, ok := db.collections[change.CollectionID]
	if !ok {
		return nil
	}

	if !change.Store {
		// The only error this could return is NotFound. We'll ignore that error here.
		c.deleteItem(tx, change.ItemID)
		return nil
	}
	return c.upsertItem(tx, change.ItemID, change.Data)
}

129
mdb/db-secondary.go Normal file
View File

@@ -0,0 +1,129 @@
package mdb
/*
func (db *Database) openSecondary() (err error) {
if db.shouldLoadFromPrimary() {
if err := db.loadFromPrimary(); err != nil {
return err
}
}
log.Printf("Opening page-file...")
pFile, err := openPageFileAndReplayWAL(db.rootDir)
if err != nil {
return err
}
defer pFile.Close()
pfHeader, err := pFile.ReadHeader()
if err != nil {
return err
}
log.Printf("Building page-file index...")
pIndex, err := pagefile.NewIndex(pFile)
if err != nil {
return err
}
tx := db.Snapshot()
tx.seqNum = pfHeader.SeqNum
tx.updatedAt = pfHeader.UpdatedAt
log.Printf("Loading data into memory...")
err = pFile.IterateAllocated(pIndex, func(cID, iID uint64, data []byte) error {
return db.loadItem(tx, cID, iID, data)
})
if err != nil {
return err
}
log.Printf("Creating writer...")
w, err := cswal.OpenWriter(db.walRootDir, &cswal.WriterConfig{
SegMinCount: db.conf.WALSegMinCount,
SegMaxAgeSec: db.conf.WALSegMaxAgeSec,
})
if err != nil {
return err
}
db.done.Add(1)
go (&walFollower{
Stop: db.stop,
Done: db.done,
W: w,
Client: NewClient(db.conf.PrimaryURL, db.conf.ReplicationPSK, db.conf.NetTimeout),
}).Run()
db.done.Add(1)
go (&follower{
Stop: db.stop,
Done: db.done,
WALRootDir: db.walRootDir,
SeqNum: pfHeader.SeqNum,
ApplyChanges: db.applyChanges,
}).Run()
db.done.Add(1)
go (&fileWriter{
Stop: db.stop,
Done: db.done,
PageFilePath: db.pageFilePath,
WALRootDir: db.walRootDir,
}).Run()
return nil
}
func (db *Database) shouldLoadFromPrimary() bool {
if _, err := os.Stat(db.walRootDir); os.IsNotExist(err) {
log.Printf("WAL doesn't exist.")
return true
}
if _, err := os.Stat(db.pageFilePath); os.IsNotExist(err) {
log.Printf("Page-file doesn't exist.")
return true
}
return false
}
func (db *Database) loadFromPrimary() error {
client := NewClient(db.conf.PrimaryURL, db.conf.ReplicationPSK, db.conf.NetTimeout)
defer client.Disconnect()
log.Printf("Loading data from primary...")
if err := os.RemoveAll(db.pageFilePath); err != nil {
log.Printf("Failed to remove page-file: %s", err)
return errs.IO.WithErr(err) // Caller can retry.
}
if err := os.RemoveAll(db.walRootDir); err != nil {
log.Printf("Failed to remove WAL: %s", err)
return errs.IO.WithErr(err) // Caller can retry.
}
err := client.DownloadPageFile(db.pageFilePath+".tmp", db.pageFilePath)
if err != nil {
log.Printf("Failed to get page-file from primary: %s", err)
return err // Caller can retry.
}
pfHeader, err := pagefile.ReadHeader(db.pageFilePath)
if err != nil {
log.Printf("Failed to read page-file sequence number: %s", err)
return err // Caller can retry.
}
if err = cswal.CreateEx(db.walRootDir, pfHeader.SeqNum+1); err != nil {
log.Printf("Failed to initialize WAL: %s", err)
return err // Caller can retry.
}
return nil
}
*/

852
mdb/db-testcases_test.go Normal file
View File

@@ -0,0 +1,852 @@
package mdb
import (
"errors"
"fmt"
"reflect"
"strings"
"testing"
)
// DBTestCase is a named sequence of steps run against a test database;
// each step performs one update and declares the expected outcome.
type DBTestCase struct {
	Name  string
	Steps []DBTestStep
}

// DBTestStep is a single mutation plus its expected result.
type DBTestStep struct {
	Name                string
	Update              func(t *testing.T, db TestDB, tx *Snapshot) error
	ExpectedUpdateError error   // Error the Update call must produce (nil = must succeed).
	State               DBState // Expected index contents after the step.
}

// DBState lists the expected contents of each index, in index order.
type DBState struct {
	UsersByID      []User
	UsersByEmail   []User
	UsersByName    []User
	UsersByBlocked []User
	DataByID       []UserDataItem
	DataByName     []UserDataItem
}
var testDBTestCases = []DBTestCase{{
Name: "Insert update",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
return db.Users.Insert(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Update",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user, ok := db.Users.ByID.Get(tx, &User{ID: 1})
if !ok {
return ErrNotFound
}
user.Name = "Bob"
user.Email = "b@c.com"
return db.Users.Update(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Bob", Email: "b@c.com"}},
UsersByEmail: []User{{ID: 1, Name: "Bob", Email: "b@c.com"}},
UsersByName: []User{{ID: 1, Name: "Bob", Email: "b@c.com"}},
},
}},
}, {
Name: "Insert delete",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
return db.Users.Insert(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Delete",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
return db.Users.Delete(tx, 1)
},
State: DBState{},
}},
}, {
Name: "Insert duplicate one tx (ID)",
Steps: []DBTestStep{{
Name: "Insert with duplicate",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
if err := db.Users.Insert(tx, user); err != nil {
return err
}
user2 := &User{ID: 1, Name: "Bob", Email: "b@c.com"}
return db.Users.Insert(tx, user2)
},
ExpectedUpdateError: ErrDuplicate,
State: DBState{},
}},
}, {
Name: "Insert duplicate one tx (email)",
Steps: []DBTestStep{{
Name: "Insert with duplicate",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
if err := db.Users.Insert(tx, user); err != nil {
return err
}
user2 := &User{ID: 2, Name: "Bob", Email: "a@b.com"}
return db.Users.Insert(tx, user2)
},
ExpectedUpdateError: ErrDuplicate,
State: DBState{},
}},
}, {
Name: "Insert duplicate two txs (ID)",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
return db.Users.Insert(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Insert duplicate",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Bob", Email: "b@c.com"}
return db.Users.Insert(tx, user)
},
ExpectedUpdateError: ErrDuplicate,
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}},
}, {
Name: "Insert duplicate two txs (email)",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
return db.Users.Insert(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Insert duplicate",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 2, Name: "Bob", Email: "a@b.com"}
return db.Users.Insert(tx, user)
},
ExpectedUpdateError: ErrDuplicate,
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}},
}, {
Name: "Insert read-only snapshot",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
return db.Users.Insert(db.Snapshot(), user)
},
ExpectedUpdateError: ErrReadOnly,
}},
}, {
Name: "Insert partial index",
Steps: []DBTestStep{{
Name: "Insert Alice",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 5, Name: "Alice", Email: "a@b.com"}
return db.Users.Insert(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Insert Bob",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 2, Name: "Bob", Email: "b@c.com", Blocked: true}
return db.Users.Insert(tx, user)
},
State: DBState{
UsersByID: []User{
{ID: 2, Name: "Bob", Email: "b@c.com", Blocked: true},
{ID: 5, Name: "Alice", Email: "a@b.com"},
},
UsersByEmail: []User{
{ID: 5, Name: "Alice", Email: "a@b.com"},
{ID: 2, Name: "Bob", Email: "b@c.com", Blocked: true},
},
UsersByName: []User{
{ID: 5, Name: "Alice", Email: "a@b.com"},
{ID: 2, Name: "Bob", Email: "b@c.com", Blocked: true},
},
UsersByBlocked: []User{
{ID: 2, Name: "Bob", Email: "b@c.com", Blocked: true},
},
},
}},
}, {
Name: "Update not found",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 5, Name: "Alice", Email: "a@b.com"}
return db.Users.Insert(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Update",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 4, Name: "Alice", Email: "x@y.com"}
return db.Users.Update(tx, user)
},
ExpectedUpdateError: ErrNotFound,
State: DBState{
UsersByID: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}},
},
}},
}, {
Name: "Update read-only snapshot",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
return db.Users.Insert(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Update",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user, ok := db.Users.ByID.Get(tx, &User{ID: 1})
if !ok {
return ErrNotFound
}
user.Name = "Bob"
user.Email = "b@c.com"
return db.Users.Update(db.Snapshot(), user)
},
ExpectedUpdateError: ErrReadOnly,
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}},
}, {
Name: "Insert into two collections",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
if err := db.Users.Insert(tx, user); err != nil {
return err
}
data := &UserDataItem{ID: 1, UserID: user.ID, Name: "Item1", Data: "xyz"}
return db.UserData.Insert(tx, data)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
DataByID: []UserDataItem{{ID: 1, UserID: 1, Name: "Item1", Data: "xyz"}},
DataByName: []UserDataItem{{ID: 1, UserID: 1, Name: "Item1", Data: "xyz"}},
},
}},
}, {
Name: "Update into index",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
return db.Users.Insert(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Update",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}
return db.Users.Update(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}},
UsersByBlocked: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}},
},
}},
}, {
Name: "Update out of index",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}
return db.Users.Insert(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}},
UsersByBlocked: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}},
},
}, {
Name: "Update",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
return db.Users.Update(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}},
}, {
Name: "Update duplicate one tx",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user1 := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
if err := db.Users.Insert(tx, user1); err != nil {
return err
}
user2 := &User{ID: 2, Name: "Bob", Email: "b@c.com"}
if err := db.Users.Insert(tx, user2); err != nil {
return err
}
user2.Email = "a@b.com"
return db.Users.Update(tx, user2)
},
ExpectedUpdateError: ErrDuplicate,
State: DBState{},
}},
}, {
Name: "Update duplicate two txs",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user1 := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
if err := db.Users.Insert(tx, user1); err != nil {
return err
}
user2 := &User{ID: 2, Name: "Bob", Email: "b@c.com"}
return db.Users.Insert(tx, user2)
},
State: DBState{
UsersByID: []User{
{ID: 1, Name: "Alice", Email: "a@b.com"},
{ID: 2, Name: "Bob", Email: "b@c.com"},
},
UsersByEmail: []User{
{ID: 1, Name: "Alice", Email: "a@b.com"},
{ID: 2, Name: "Bob", Email: "b@c.com"},
},
UsersByName: []User{
{ID: 1, Name: "Alice", Email: "a@b.com"},
{ID: 2, Name: "Bob", Email: "b@c.com"},
},
},
}, {
Name: "Update",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
u, ok := db.Users.ByID.Get(tx, &User{ID: 2})
if !ok {
return ErrNotFound
}
u.Email = "a@b.com"
return db.Users.Update(tx, u)
},
ExpectedUpdateError: ErrDuplicate,
State: DBState{
UsersByID: []User{
{ID: 1, Name: "Alice", Email: "a@b.com"},
{ID: 2, Name: "Bob", Email: "b@c.com"},
},
UsersByEmail: []User{
{ID: 1, Name: "Alice", Email: "a@b.com"},
{ID: 2, Name: "Bob", Email: "b@c.com"},
},
UsersByName: []User{
{ID: 1, Name: "Alice", Email: "a@b.com"},
{ID: 2, Name: "Bob", Email: "b@c.com"},
},
},
}},
}, {
Name: "Delete read only",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
return db.Users.Insert(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Delete",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
return db.Users.Delete(db.Snapshot(), 1)
},
ExpectedUpdateError: ErrReadOnly,
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}},
}, {
Name: "Delete not found",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
return db.Users.Insert(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Delete",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
return db.Users.Delete(tx, 2)
},
ExpectedUpdateError: ErrNotFound,
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}},
}, {
Name: "Index general",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
user := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
return db.Users.Insert(tx, user)
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Get found",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
expected := &User{ID: 1, Name: "Alice", Email: "a@b.com"}
u, ok := db.Users.ByID.Get(tx, &User{ID: 1})
if !ok {
return ErrNotFound
}
if !reflect.DeepEqual(u, expected) {
return errors.New("Not equal (id)")
}
u, ok = db.Users.ByEmail.Get(tx, &User{Email: "a@b.com"})
if !ok {
return ErrNotFound
}
if !reflect.DeepEqual(u, expected) {
return errors.New("Not equal (email)")
}
return nil
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Get not found",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
if _, ok := db.Users.ByID.Get(tx, &User{ID: 2}); ok {
return errors.New("Found (id)")
}
if _, ok := db.Users.ByEmail.Get(tx, &User{Email: "x@b.com"}); ok {
return errors.New("Found (email)")
}
return nil
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Has (true)",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
if ok := db.Users.ByID.Has(tx, &User{ID: 1}); !ok {
return errors.New("Not found (id)")
}
if ok := db.Users.ByEmail.Has(tx, &User{Email: "a@b.com"}); !ok {
return errors.New("Not found (email)")
}
return nil
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}, {
Name: "Has (false)",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
if ok := db.Users.ByID.Has(tx, &User{ID: 2}); ok {
return errors.New("Found (id)")
}
if ok := db.Users.ByEmail.Has(tx, &User{Email: "x@b.com"}); ok {
return errors.New("Found (email)")
}
return nil
},
State: DBState{
UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}},
},
}},
}, {
Name: "Mutate while iterating",
Steps: []DBTestStep{{
Name: "Insert",
Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
for i := 0; i < 4; i++ {
user := &User{
ID: uint64(i) + 1,
Name: fmt.Sprintf("User%d", i),
Email: fmt.Sprintf("user.%d@x.com", i),
}
if err := db.Users.Insert(tx, user); err != nil {
return err
}
}
return nil
},
State: DBState{
UsersByID: []User{
{ID: 1, Name: "User0", Email: "user.0@x.com"},
{ID: 2, Name: "User1", Email: "user.1@x.com"},
{ID: 3, Name: "User2", Email: "user.2@x.com"},
{ID: 4, Name: "User3", Email: "user.3@x.com"},
},
UsersByEmail: []User{
{ID: 1, Name: "User0", Email: "user.0@x.com"},
{ID: 2, Name: "User1", Email: "user.1@x.com"},
{ID: 3, Name: "User2", Email: "user.2@x.com"},
{ID: 4, Name: "User3", Email: "user.3@x.com"},
},
UsersByName: []User{
{ID: 1, Name: "User0", Email: "user.0@x.com"},
{ID: 2, Name: "User1", Email: "user.1@x.com"},
{ID: 3, Name: "User2", Email: "user.2@x.com"},
{ID: 4, Name: "User3", Email: "user.3@x.com"},
},
},
}, {
Name: "Modify while iterating",
Update: func(t *testing.T, db TestDB, tx *Snapshot) (err error) {
first := true
pivot := User{Name: "User1"}
db.Users.ByName.AscendAfter(tx, &pivot, func(u *User) bool {
u.Name += "Mod"
if err = db.Users.Update(tx, u); err != nil {
return false
}
if first {
first = false
return true
}
prev, ok := db.Users.ByID.Get(tx, &User{ID: u.ID - 1})
if !ok {
err = errors.New("Previous user not found")
return false
}
if !strings.HasSuffix(prev.Name, "Mod") {
err = errors.New("Incorrect user name: " + prev.Name)
return false
}
return true
})
return nil
},
State: DBState{
UsersByID: []User{
{ID: 1, Name: "User0", Email: "user.0@x.com"},
{ID: 2, Name: "User1Mod", Email: "user.1@x.com"},
{ID: 3, Name: "User2Mod", Email: "user.2@x.com"},
{ID: 4, Name: "User3Mod", Email: "user.3@x.com"},
},
UsersByEmail: []User{
{ID: 1, Name: "User0", Email: "user.0@x.com"},
{ID: 2, Name: "User1Mod", Email: "user.1@x.com"},
{ID: 3, Name: "User2Mod", Email: "user.2@x.com"},
{ID: 4, Name: "User3Mod", Email: "user.3@x.com"},
},
UsersByName: []User{
{ID: 1, Name: "User0", Email: "user.0@x.com"},
{ID: 2, Name: "User1Mod", Email: "user.1@x.com"},
{ID: 3, Name: "User2Mod", Email: "user.2@x.com"},
{ID: 4, Name: "User3Mod", Email: "user.3@x.com"},
},
},
}, {
Name: "Iterate after modifying",
Update: func(t *testing.T, db TestDB, tx *Snapshot) (err error) {
u := &User{ID: 5, Name: "User4Mod", Email: "user.4@x.com"}
if err := db.Users.Insert(tx, u); err != nil {
return err
}
first := true
db.Users.ByName.DescendAfter(tx, &User{Name: "User5Mod"}, func(u *User) bool {
u.Name = strings.TrimSuffix(u.Name, "Mod")
if err = db.Users.Update(tx, u); err != nil {
return false
}
if first {
first = false
return true
}
prev, ok := db.Users.ByID.Get(tx, &User{ID: u.ID + 1})
if !ok {
err = errors.New("Previous user not found")
return false
}
if strings.HasSuffix(prev.Name, "Mod") {
err = errors.New("Incorrect user name: " + prev.Name)
return false
}
return true
})
return nil
},
State: DBState{
UsersByID: []User{
{ID: 1, Name: "User0", Email: "user.0@x.com"},
{ID: 2, Name: "User1", Email: "user.1@x.com"},
{ID: 3, Name: "User2", Email: "user.2@x.com"},
{ID: 4, Name: "User3", Email: "user.3@x.com"},
{ID: 5, Name: "User4", Email: "user.4@x.com"},
},
UsersByEmail: []User{
{ID: 1, Name: "User0", Email: "user.0@x.com"},
{ID: 2, Name: "User1", Email: "user.1@x.com"},
{ID: 3, Name: "User2", Email: "user.2@x.com"},
{ID: 4, Name: "User3", Email: "user.3@x.com"},
{ID: 5, Name: "User4", Email: "user.4@x.com"},
},
UsersByName: []User{
{ID: 1, Name: "User0", Email: "user.0@x.com"},
{ID: 2, Name: "User1", Email: "user.1@x.com"},
{ID: 3, Name: "User2", Email: "user.2@x.com"},
{ID: 4, Name: "User3", Email: "user.3@x.com"},
{ID: 5, Name: "User4", Email: "user.4@x.com"},
},
},
}},
}}

138
mdb/db-testlist_test.go Normal file
View File

@@ -0,0 +1,138 @@
package mdb
import (
"fmt"
"reflect"
"testing"
)
// TestDBList exercises Index.List on the UserData.ByName index: items for a
// given user are listed in (UserID, Name) order, with and without a Limit.
func TestDBList(t *testing.T) {
	db := NewTestDBPrimary(t, t.TempDir())

	var (
		user1 = User{
			ID:    NewID(),
			Name:  "User1",
			Email: "user1@gmail.com",
		}
		user2 = User{
			ID:    NewID(),
			Name:  "User2",
			Email: "user2@gmail.com",
		}
		user3 = User{
			ID:    NewID(),
			Name:  "User3",
			Email: "user3@gmail.com",
		}
		user1Data = make([]UserDataItem, 10)
		user2Data = make([]UserDataItem, 4)
		user3Data = make([]UserDataItem, 8)
	)

	// Seed users and their data items in a single transaction.
	//
	// NOTE(review): user3 is never inserted into Users, although user3Data
	// items are. Presumably intentional (UserData has no foreign-key
	// constraint) — verify.
	err := db.Update(func(tx *Snapshot) error {
		if err := db.Users.Insert(tx, &user1); err != nil {
			return err
		}
		if err := db.Users.Insert(tx, &user2); err != nil {
			return err
		}
		for i := range user1Data {
			user1Data[i] = UserDataItem{
				ID:     NewID(),
				UserID: user1.ID,
				Name:   fmt.Sprintf("Name1: %d", i),
				Data:   fmt.Sprintf("Data: %d", i),
			}
			if err := db.UserData.Insert(tx, &user1Data[i]); err != nil {
				return err
			}
		}
		for i := range user2Data {
			user2Data[i] = UserDataItem{
				ID:     NewID(),
				UserID: user2.ID,
				Name:   fmt.Sprintf("Name2: %d", i),
				Data:   fmt.Sprintf("Data: %d", i),
			}
			if err := db.UserData.Insert(tx, &user2Data[i]); err != nil {
				return err
			}
		}
		for i := range user3Data {
			user3Data[i] = UserDataItem{
				ID:     NewID(),
				UserID: user3.ID,
				Name:   fmt.Sprintf("Name3: %d", i),
				Data:   fmt.Sprintf("Data: %d", i),
			}
			if err := db.UserData.Insert(tx, &user3Data[i]); err != nil {
				return err
			}
		}
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}

	type TestCase struct {
		Name     string
		Args     ListArgs[UserDataItem]
		Expected []UserDataItem
	}

	cases := []TestCase{
		{
			Name: "User1 all",
			Args: ListArgs[UserDataItem]{
				// Zero Name sorts before every "Name1: N", so this pivot
				// starts at user1's first item.
				After: &UserDataItem{
					UserID: user1.ID,
				},
				While: func(item *UserDataItem) bool {
					return item.UserID == user1.ID
				},
			},
			Expected: user1Data,
		}, {
			Name: "User1 limited",
			Args: ListArgs[UserDataItem]{
				After: &UserDataItem{
					UserID: user1.ID,
				},
				While: func(item *UserDataItem) bool {
					return item.UserID == user1.ID
				},
				Limit: 4,
			},
			Expected: user1Data[:4],
		},
	}

	for _, tc := range cases {
		t.Run(tc.Name, func(t *testing.T) {
			tx := db.Snapshot()
			l := db.UserData.ByName.List(tx, tc.Args, nil)
			if len(l) != len(tc.Expected) {
				t.Fatal(tc.Name, l)
			}
			for i := range l {
				if !reflect.DeepEqual(*l[i], tc.Expected[i]) {
					t.Fatal(tc.Name, l)
				}
			}
		})
	}
}

164
mdb/db-testrunner_test.go Normal file
View File

@@ -0,0 +1,164 @@
package mdb
import (
"errors"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
)
// TestDBRunTests runs every declarative DB test case in testDBTestCases,
// each as its own parallel subtest.
func TestDBRunTests(t *testing.T) {
	// FIX: removed a stray t.Helper() call — this is a test function, not a
	// helper, and marking it as a helper misattributes failure locations.
	for _, testCase := range testDBTestCases {
		testCase := testCase // Pre-Go-1.22 loop-variable capture.
		t.Run(testCase.Name, func(t *testing.T) {
			t.Parallel()
			testRunner_testCase(t, testCase)
		})
	}
}
// testRunner_testCase runs one declarative test case end-to-end: it starts a
// primary and an HTTP-replicated secondary, applies each step's Update,
// checks the expected state after each step (live and from retained
// snapshots), waits for the secondary to catch up, then reloads the primary
// from disk and verifies primary/secondary equality.
func testRunner_testCase(t *testing.T, testCase DBTestCase) {
	rootDir := t.TempDir()
	db := NewTestDBPrimary(t, rootDir)

	// Expose the primary's replication endpoint over a test HTTP server.
	mux := http.NewServeMux()
	mux.HandleFunc("/rep/", db.Handle)
	testServer := httptest.NewServer(mux)
	defer testServer.Close()

	rootDir2 := t.TempDir()
	db2 := NewTestSecondaryDB(t, rootDir2, testServer.URL+"/rep/")
	defer db2.Close()

	snapshots := make([]*Snapshot, 0, len(testCase.Steps))

	// Run each step and its associated check function.
	for _, step := range testCase.Steps {
		t.Run(step.Name, func(t *testing.T) {
			err := db.Update(func(tx *Snapshot) error {
				return step.Update(t, db, tx)
			})
			if !errors.Is(err, step.ExpectedUpdateError) {
				t.Fatal(err, step.ExpectedUpdateError)
			}
			snapshot := db.Snapshot()
			snapshots = append(snapshots, snapshot)
			testRunner_checkState(t, db, snapshot, step.State)
		})
	}

	// Run each step's check function again with stored snapshot. This
	// verifies that retained snapshots are immutable under later writes.
	for i, step := range testCase.Steps {
		snapshot := snapshots[i]
		t.Run(step.Name+"-checkSnapshot", func(t *testing.T) {
			testRunner_checkState(t, db, snapshot, step.State)
		})
	}

	// Busy-wait until the secondary's applied sequence number matches the
	// primary's.
	pInfo := db.Info()
	for {
		info := db2.Info()
		if info.SeqNum == pInfo.SeqNum {
			break
		}
		time.Sleep(time.Millisecond)
	}

	// TODO: Why is this necessary? (Catching up on SeqNum alone appears
	// insufficient before checking the secondary's state.)
	time.Sleep(time.Second)

	finalStep := testCase.Steps[len(testCase.Steps)-1]

	secondarySnapshot := db2.Snapshot()
	t.Run("Check secondary", func(t *testing.T) {
		testRunner_checkState(t, db2, secondarySnapshot, finalStep.State)
	})

	if err := db.Close(); err != nil {
		t.Fatal(err)
	}

	// Reload the primary from its on-disk state.
	db = NewTestDBPrimary(t, rootDir)
	snapshot := db.Snapshot()

	// Run the final step's check function again with a newly loaded db.
	t.Run("Check after reload", func(t *testing.T) {
		testRunner_checkState(t, db, snapshot, finalStep.State)
	})

	t.Run("Check that primary and secondary are equal", func(t *testing.T) {
		db.AssertEqual(t, db2.Database)
	})

	db.Close()
}

// testRunner_checkState asserts that every index of every test collection
// matches the expected DBState, including Min/Max boundary checks.
func testRunner_checkState(
	t *testing.T,
	db TestDB,
	tx *Snapshot,
	state DBState,
) {
	t.Helper()

	checkSlicesEqual(t, "UsersByID", db.Users.ByID.Dump(tx), state.UsersByID)
	checkSlicesEqual(t, "UsersByEmail", db.Users.ByEmail.Dump(tx), state.UsersByEmail)
	checkSlicesEqual(t, "UsersByName", db.Users.ByName.Dump(tx), state.UsersByName)
	checkSlicesEqual(t, "UsersByBlocked", db.Users.ByBlocked.Dump(tx), state.UsersByBlocked)
	checkSlicesEqual(t, "DataByID", db.UserData.ByID.Dump(tx), state.DataByID)
	checkSlicesEqual(t, "DataByName", db.UserData.ByName.Dump(tx), state.DataByName)

	checkMinMaxEqual(t, "UsersByID", tx, db.Users.ByID, state.UsersByID)
	checkMinMaxEqual(t, "UsersByEmail", tx, db.Users.ByEmail, state.UsersByEmail)
	checkMinMaxEqual(t, "UsersByName", tx, db.Users.ByName, state.UsersByName)
	checkMinMaxEqual(t, "UsersByBlocked", tx, db.Users.ByBlocked, state.UsersByBlocked)
	checkMinMaxEqual(t, "DataByID", tx, db.UserData.ByID, state.DataByID)
	checkMinMaxEqual(t, "DataByName", tx, db.UserData.ByName, state.DataByName)
}
// checkSlicesEqual fails the test if actual and expected differ in length or
// in any element (compared with reflect.DeepEqual). name identifies the
// index being checked in failure output.
func checkSlicesEqual[T any](t *testing.T, name string, actual, expected []T) {
	t.Helper()
	if len(actual) != len(expected) {
		// Improved diagnostics: report which side has which length.
		t.Fatalf("%s: length mismatch: got %d, want %d", name, len(actual), len(expected))
	}
	for i := range actual {
		if !reflect.DeepEqual(actual[i], expected[i]) {
			t.Fatalf("%s: element %d mismatch: got %v, want %v", name, i, actual[i], expected[i])
		}
	}
}
// checkMinMaxEqual verifies that index.Min/Max agree with the first and last
// elements of expected, and that both report !ok when expected is empty.
func checkMinMaxEqual[T any](t *testing.T, name string, tx *Snapshot, index Index[T], expected []T) {
	// FIX: added t.Helper() (its sibling checkSlicesEqual already had it) so
	// failures point at the caller, and upgraded the bare t.Fatal calls to
	// messages that include the index name and the expected value.
	t.Helper()

	if len(expected) == 0 {
		if min, ok := index.Min(tx); ok {
			t.Fatalf("%s: unexpected min: %v", name, min)
		}
		if max, ok := index.Max(tx); ok {
			t.Fatalf("%s: unexpected max: %v", name, max)
		}
		return
	}

	min, ok := index.Min(tx)
	if !ok {
		t.Fatalf("%s: no min", name)
	}
	max, ok := index.Max(tx)
	if !ok {
		t.Fatalf("%s: no max", name)
	}
	if !reflect.DeepEqual(*min, expected[0]) {
		t.Fatalf("%s: min: got %v, want %v", name, *min, expected[0])
	}
	if !reflect.DeepEqual(*max, expected[len(expected)-1]) {
		t.Fatalf("%s: max: got %v, want %v", name, *max, expected[len(expected)-1])
	}
}

105
mdb/db-txaggregator.go Normal file
View File

@@ -0,0 +1,105 @@
package mdb
import (
"bytes"
"git.crumpington.com/public/jldb/mdb/change"
)
// txMod is a single update request submitted to the transaction aggregator:
// Update is applied to a writable snapshot, and the batch outcome is
// delivered on Resp (buffered with capacity 1; see New).
type txMod struct {
	Update func(tx *Snapshot) error
	Resp   chan error
}
// runTXAggreagtor is the primary's single writer goroutine. It batches
// concurrently submitted update functions into one snapshot clone, persists
// the resulting change set to the WAL via the replicator, applies it to the
// page-file index, and publishes the new read-only snapshot. Each batched
// mod is notified of the batch's final outcome.
//
// NOTE(review): the name contains a typo ("Aggreagtor"); kept as-is because
// it is called elsewhere in the package.
func (db *Database) runTXAggreagtor() {
	defer db.done.Done()

	var (
		tx          *Snapshot
		mod         txMod
		seqNum      int64
		timestampMS int64
		err         error
		buf         = &bytes.Buffer{}
		toNotify    = make([]chan error, 0, db.conf.MaxConcurrentUpdates)
	)

READ_FIRST:
	toNotify = toNotify[:0]
	select {
	case mod = <-db.modChan:
		goto BEGIN
	case <-db.stop:
		goto END
	}

BEGIN:
	// Start a writable snapshot from the current published snapshot.
	tx = db.snapshot.Load().begin()
	goto APPLY_MOD

CLONE:
	// Later mods in the batch run on a fresh clone so a failing mod can be
	// rolled back without discarding earlier successful mods.
	tx = tx.clone()
	goto APPLY_MOD

APPLY_MOD:
	if err = mod.Update(tx); err != nil {
		mod.Resp <- err
		goto ROLLBACK
	}
	toNotify = append(toNotify, mod.Resp)
	goto NEXT

ROLLBACK:
	if len(toNotify) == 0 {
		// Nothing succeeded yet: abandon the batch entirely.
		goto READ_FIRST
	}
	tx = tx.rollback()
	goto NEXT

NEXT:
	// Opportunistically absorb any pending mod into this batch; otherwise
	// flush what we have.
	select {
	case mod = <-db.modChan:
		goto CLONE
	default:
		goto WRITE
	}

WRITE:
	db.idx.StageChanges(tx.changes)

	buf.Reset()
	// BUG FIX: the original called UnstageChanges twice when change.Write
	// failed — once in the Write error branch and again in the shared
	// err != nil branch below. Unstage exactly once on any failure.
	err = change.Write(tx.changes, buf)
	if err == nil {
		seqNum, timestampMS, err = db.rep.Append(int64(buf.Len()), buf)
	}

	if err != nil {
		db.idx.UnstageChanges(tx.changes)
	} else {
		db.idx.ApplyChanges(tx.changes)
		tx.seqNum = seqNum
		tx.timestampMS = timestampMS
		tx.setReadOnly()
		db.snapshot.Store(tx)
	}

	// Notify every successfully applied mod of the batch outcome.
	for i := range toNotify {
		toNotify[i] <- err
	}
	goto READ_FIRST

END:
}

36
mdb/db-userdata_test.go Normal file
View File

@@ -0,0 +1,36 @@
package mdb
import (
"cmp"
"strings"
)
// UserDataItem is a single data record owned by a user.
//
// NOTE(review): ID must remain the first field — Index.getID reads the
// leading 8 bytes of the struct via unsafe to obtain the item ID.
type UserDataItem struct {
	ID     uint64
	UserID uint64
	Name   string
	Data   string
}

// UserData is the test collection of UserDataItem records.
type UserData struct {
	*Collection[UserDataItem]
	ByName Index[UserDataItem] // Unique index on (UserID, Name).
}
// NewUserDataCollection builds the UserData collection together with its
// unique (UserID, Name) index.
func NewUserDataCollection(db *Database) UserData {
	// Order by owner first, then by item name within that owner.
	byName := func(lhs, rhs *UserDataItem) int {
		if c := cmp.Compare(lhs.UserID, rhs.UserID); c != 0 {
			return c
		}
		return strings.Compare(lhs.Name, rhs.Name)
	}

	ud := UserData{Collection: NewCollection[UserDataItem](db, "UserData", nil)}
	ud.ByName = NewUniqueIndex(ud.Collection, "ByName", byName)
	return ud
}

50
mdb/db-users_test.go Normal file
View File

@@ -0,0 +1,50 @@
package mdb
import "strings"
// User is a single user record.
//
// NOTE(review): ID must remain the first field — Index.getID reads the
// leading 8 bytes of the struct via unsafe to obtain the item ID.
type User struct {
	ID      uint64
	Name    string
	Email   string
	Admin   bool
	Blocked bool
}

// Users is the test collection of User records.
type Users struct {
	*Collection[User]
	ByEmail   Index[User] // Unique index on (Email).
	ByName    Index[User] // Index on (Name).
	ByBlocked Index[User] // Partial index on (Blocked,Email).
}
// NewUserCollection builds the Users collection and its three secondary
// indexes: unique ByEmail, non-unique ByName, and the ByBlocked partial
// index (blocked users only, ordered by email).
func NewUserCollection(db *Database) Users {
	byEmail := func(lhs, rhs *User) int {
		return strings.Compare(lhs.Email, rhs.Email)
	}
	byName := func(lhs, rhs *User) int {
		return strings.Compare(lhs.Name, rhs.Name)
	}
	isBlocked := func(item *User) bool {
		return item.Blocked
	}

	u := Users{Collection: NewCollection[User](db, "Users", nil)}
	u.ByEmail = NewUniqueIndex(u.Collection, "ByEmail", byEmail)
	u.ByName = NewIndex(u.Collection, "ByName", byName)
	u.ByBlocked = NewPartialIndex(u.Collection, "ByBlocked", byEmail, isBlocked)
	return u
}

184
mdb/db.go Normal file
View File

@@ -0,0 +1,184 @@
package mdb
import (
"fmt"
"git.crumpington.com/public/jldb/lib/errs"
"git.crumpington.com/public/jldb/lib/rep"
"git.crumpington.com/public/jldb/mdb/change"
"git.crumpington.com/public/jldb/mdb/pfile"
"net/http"
"os"
"sync"
"sync/atomic"
"time"
)
// Config holds all database configuration, for both primary and secondary
// nodes.
type Config struct {
	RootDir string // Root directory for all database state.
	Primary bool   // True for the writable primary node.

	ReplicationPSK string // Pre-shared key authenticating replication.

	NetTimeout time.Duration // Default is 1 minute.

	// WAL settings.
	WALSegMinCount  int64 // Minimum Change sets in a segment. Default is 1024.
	WALSegMaxAgeSec int64 // Maximum age of a segment. Default is 1 hour.
	WALSegGCAgeSec  int64 // Segment age for garbage collection. Default is 7 days.

	// Necessary for secondary.
	PrimaryEndpoint string

	// MaxConcurrentUpdates restricts the number of concurrently running updates,
	// and also limits the maximum number of changes that may be aggregated in
	// the WAL.
	//
	// Default is 32.
	MaxConcurrentUpdates int
}
// repConfig derives the replicator configuration from the database config,
// rooting replication state under the rep/ subdirectory.
func (c Config) repConfig() rep.Config {
	return rep.Config{
		RootDir:         repDirPath(c.RootDir),
		Primary:         c.Primary,
		ReplicationPSK:  c.ReplicationPSK,
		NetTimeout:      c.NetTimeout,
		WALSegMinCount:  c.WALSegMinCount,
		WALSegMaxAgeSec: c.WALSegMaxAgeSec,
		WALSegGCAgeSec:  c.WALSegGCAgeSec,
		PrimaryEndpoint: c.PrimaryEndpoint,
	}
}

// Database is a replicated in-memory database backed by a page file and a
// WAL-based replicator.
type Database struct {
	rep *rep.Replicator

	rootDir string
	conf    Config

	pf  *pfile.File  // Backing page file.
	idx *pfile.Index // Page-file index (staged/applied change tracking).

	changes []change.Change

	// The Snapshot stored here is read-only. It will be replaced as needed by
	// the txAggregator (primary), or the follower (secondary).
	snapshot *atomic.Pointer[Snapshot]

	collections map[uint64]collection

	stop chan struct{}
	done *sync.WaitGroup

	txModPool chan txMod // Bounds concurrent updates (see Update).
	modChan   chan txMod // Feeds the transaction aggregator.
}
// New creates an unopened Database from conf, applying the
// MaxConcurrentUpdates default (32) and pre-allocating the update-request
// pool. Call Open before use.
func New(conf Config) *Database {
	if conf.MaxConcurrentUpdates <= 0 {
		conf.MaxConcurrentUpdates = 32
	}

	db := &Database{
		rootDir:     conf.RootDir,
		conf:        conf,
		snapshot:    &atomic.Pointer[Snapshot]{},
		collections: map[uint64]collection{},
		stop:        make(chan struct{}),
		done:        &sync.WaitGroup{},
		txModPool:   make(chan txMod, conf.MaxConcurrentUpdates),
		modChan:     make(chan txMod),
	}

	db.snapshot.Store(newSnapshot())

	// Fill the pool: each Update borrows one txMod and returns it when done.
	// Resp is buffered so the aggregator never blocks delivering a result.
	for i := 0; i < conf.MaxConcurrentUpdates; i++ {
		db.txModPool <- txMod{Resp: make(chan error, 1)}
	}

	return db
}
// Open creates the root directory if needed and starts the replicator; on
// the primary it also starts the transaction-aggregator goroutine.
func (db *Database) Open() (err error) {
	if err := os.MkdirAll(db.rootDir, 0700); err != nil {
		return errs.IO.WithErr(err)
	}

	db.rep, err = rep.Open(
		rep.App{
			SendState:       db.repSendState,
			RecvState:       db.repRecvState,
			InitStorage:     db.repInitStorage,
			Replay:          db.repReplay,
			LoadFromStorage: db.repLoadFromStorage,
			Apply:           db.repApply,
		},
		db.conf.repConfig())
	if err != nil {
		return err
	}

	if db.conf.Primary {
		db.done.Add(1)
		go db.runTXAggreagtor()
	}

	return nil
}

// Close stops background goroutines and releases references. It is
// idempotent: a second call returns nil immediately.
//
// NOTE(review): snapshot and collections are nilled without synchronization;
// a concurrent Snapshot() racing Close could dereference nil — confirm that
// callers guarantee no concurrent use during shutdown.
func (db *Database) Close() error {
	select {
	case <-db.stop:
		return nil
	default:
	}

	close(db.stop)
	db.rep.Close()
	db.done.Wait()
	db.snapshot = nil
	db.collections = nil
	return nil
}
// Snapshot returns the current read-only snapshot.
func (db *Database) Snapshot() *Snapshot {
	return db.snapshot.Load()
}

// Update runs update inside a writable transaction on the primary. The
// aggregator may batch it with concurrent updates; the returned error is
// either the update's own failure or the outcome of the batch write.
func (db *Database) Update(update func(tx *Snapshot) error) error {
	if !db.conf.Primary {
		return errs.ReadOnly.WithMsg("cannot update secondary directly")
	}

	// Borrowing a txMod from the pool bounds concurrent updates.
	mod := <-db.txModPool
	mod.Update = update
	db.modChan <- mod
	err := <-mod.Resp
	db.txModPool <- mod
	return err
}

// Info reports in-memory and WAL sequence/timestamp information.
//
// NOTE(review): FileSeqNum and FileTimestampMS are never populated here —
// confirm whether that is intentional.
func (db *Database) Info() Info {
	tx := db.Snapshot()
	repInfo := db.rep.Info()
	return Info{
		SeqNum:             tx.seqNum,
		TimestampMS:        tx.timestampMS,
		WALFirstSeqNum:     repInfo.WALFirstSeqNum,
		WALLastSeqNum:      repInfo.WALLastSeqNum,
		WALLastTimestampMS: repInfo.WALLastTimestampMS,
	}
}

// addCollection registers a collection under id and installs its initial
// state on the current snapshot. Duplicate IDs are a programmer error and
// panic.
func (db *Database) addCollection(id uint64, c collection, collectionState any) {
	if _, ok := db.collections[id]; ok {
		panic(fmt.Sprintf("Collection %s uses duplicate ID %d.", c.Name(), id))
	}
	db.collections[id] = c
	db.snapshot.Load().addCollection(id, collectionState)
}

// Handle serves replication HTTP requests; mount it under the replication
// path prefix.
func (db *Database) Handle(w http.ResponseWriter, r *http.Request) {
	db.rep.Handle(w, r)
}

54
mdb/db_test.go Normal file
View File

@@ -0,0 +1,54 @@
package mdb
import (
"testing"
"time"
)
// TestDB bundles a Database with the test collections used by the suite.
type TestDB struct {
	*Database
	Users    Users
	UserData UserData
}
// NewTestDBPrimary creates and opens a primary TestDB rooted at rootDir,
// failing the test on any error.
func NewTestDBPrimary(t *testing.T, rootDir string) TestDB {
	t.Helper() // FIX: report Open failures at the caller's line.

	db := New(Config{
		RootDir:        rootDir,
		Primary:        true,
		NetTimeout:     8 * time.Second,
		ReplicationPSK: "123",
	})

	testDB := TestDB{
		Database: db,
		Users:    NewUserCollection(db),
		UserData: NewUserDataCollection(db),
	}

	if err := testDB.Open(); err != nil {
		t.Fatal(err)
	}
	return testDB
}
// NewTestSecondaryDB creates and opens a secondary TestDB rooted at rootDir
// that replicates from primaryURL, failing the test on any error.
func NewTestSecondaryDB(t *testing.T, rootDir, primaryURL string) TestDB {
	t.Helper() // FIX: report Open failures at the caller's line.

	db := New(Config{
		RootDir:         rootDir,
		PrimaryEndpoint: primaryURL,
		NetTimeout:      8 * time.Second,
		ReplicationPSK:  "123",
	})

	testDB := TestDB{
		Database: db,
		Users:    NewUserCollection(db),
		UserData: NewUserDataCollection(db),
	}

	if err := testDB.Open(); err != nil {
		t.Fatal(err)
	}
	return testDB
}

59
mdb/equality_test.go Normal file
View File

@@ -0,0 +1,59 @@
package mdb
import (
"fmt"
"reflect"
"testing"
)
// AssertEqual fails the test unless this index holds exactly the same items
// (compared with reflect.DeepEqual) in tx1 and tx2.
func (i Index[T]) AssertEqual(t *testing.T, tx1, tx2 *Snapshot) {
	t.Helper()
	state1 := i.getState(tx1)
	state2 := i.getState(tx2)
	if state1.BTree.Len() != state2.BTree.Len() {
		t.Fatalf("(%s) Unequal lengths: %d != %d",
			i.name,
			state1.BTree.Len(),
			state2.BTree.Len())
	}

	// Record only the first mismatch; t.Fatal is deferred until after the
	// iteration callback returns.
	errStr := ""
	i.Ascend(tx1, func(item1 *T) bool {
		item2, ok := i.Get(tx2, item1)
		if !ok {
			errStr = fmt.Sprintf("Indices don't match. %v not found.", item1)
			return false
		}
		if !reflect.DeepEqual(item1, item2) {
			errStr = fmt.Sprintf("%v != %v", item1, item2)
			return false
		}
		return true
	})

	if errStr != "" {
		t.Fatal(errStr)
	}
}

// AssertEqual fails the test unless the collection's primary index and every
// secondary index are equal across tx1 and tx2.
func (c *Collection[T]) AssertEqual(t *testing.T, tx1, tx2 *Snapshot) {
	t.Helper()
	c.ByID.AssertEqual(t, tx1, tx2)
	for _, idx := range c.indices {
		idx.AssertEqual(t, tx1, tx2)
	}
}
// AssertEqual fails the test unless db and db2 contain identical data in
// every collection and index.
func (db *Database) AssertEqual(t *testing.T, db2 *Database) {
	t.Helper()

	tx1 := db.Snapshot()
	// BUG FIX: the original took BOTH snapshots from db (tx2 :=
	// db.Snapshot()), so the comparison against db2 trivially passed.
	tx2 := db2.Snapshot()

	for _, c := range db.collections {
		cc := c.(interface {
			AssertEqual(t *testing.T, tx1, tx2 *Snapshot)
		})
		cc.AssertEqual(t, tx1, tx2)
	}
}

11
mdb/errors.go Normal file
View File

@@ -0,0 +1,11 @@
package mdb
import (
"git.crumpington.com/public/jldb/lib/errs"
)
// Aliases for the shared error values used throughout mdb.
var (
	ErrNotFound  = errs.NotFound  // Item does not exist.
	ErrReadOnly  = errs.ReadOnly  // Write attempted through a read-only snapshot or secondary.
	ErrDuplicate = errs.Duplicate // Unique-index conflict.
)

100
mdb/filewriter.go Normal file
View File

@@ -0,0 +1,100 @@
package mdb
/*
// The fileWriter writes changes from the WAL to the data file. It's run by the
// primary, and, for the primary, is the only way the pagefile is modified.
type fileWriter struct {
Stop chan struct{}
Done *sync.WaitGroup
PageFilePath string
WALRootDir string
}
func (w *fileWriter) Run() {
defer w.Done.Done()
for {
w.runOnce()
select {
case <-w.Stop:
return
default:
time.Sleep(time.Second)
}
}
}
func (w *fileWriter) runOnce() {
f, err := pagefile.Open(w.PageFilePath)
if err != nil {
w.logf("Failed to open page file: %v", err)
return
}
defer f.Close()
header, err := w.readHeader(f)
if err != nil {
w.logf("Failed to get header from page file: %v", err)
return
}
it, err := cswal.NewIterator(w.WALRootDir, header.SeqNum+1)
if err != nil {
w.logf("Failed to get WAL iterator: %v", err)
return
}
defer it.Close()
for {
hasNext := it.Next(time.Second)
select {
case <-w.Stop:
return
default:
}
if !hasNext {
if it.Error() != nil {
w.logf("Iteration error: %v", it.Error())
return
}
continue
}
rec := it.Record()
if err := w.applyChanges(f, rec); err != nil {
w.logf("Failed to apply changes: %v", err)
return
}
}
}
func (w *fileWriter) readHeader(f *pagefile.File) (pagefile.Header, error) {
defer f.RLock()()
return f.ReadHeader()
}
func (w *fileWriter) applyChanges(f *pagefile.File, rec *cswal.Record) error {
defer f.WLock()()
if err := f.ApplyChanges(rec.Changes); err != nil {
w.logf("Failed to apply changes to page file: %v", err)
return err
}
header := pagefile.Header{
SeqNum: rec.SeqNum,
UpdatedAt: rec.CreatedAt,
}
if err := f.WriteHeader(header); err != nil {
w.logf("Failed to write page file header: %v", err)
return err
}
return nil
}
func (w *fileWriter) logf(pattern string, args ...interface{}) {
log.Printf("[FILE-WRITER] "+pattern, args...)
}
*/

68
mdb/follower.go Normal file
View File

@@ -0,0 +1,68 @@
package mdb
/*
type follower struct {
Stop chan struct{}
Done *sync.WaitGroup
WALRootDir string
SeqNum uint64 // Current max applied sequence number.
ApplyChanges func(rec *cswal.Record) error
seqNum uint64 // Current max applied sequence number.
}
func (f *follower) Run() {
defer f.Done.Done()
f.seqNum = f.SeqNum
for {
f.runOnce()
select {
case <-f.Stop:
return
default:
// Something went wrong.
time.Sleep(time.Second)
}
}
}
func (f *follower) runOnce() {
it, err := cswal.NewIterator(f.WALRootDir, f.seqNum+1)
if err != nil {
f.logf("Failed to get WAL iterator: %v", errs.FmtDetails(err))
return
}
defer it.Close()
for {
hasNext := it.Next(time.Second)
select {
case <-f.Stop:
return
default:
}
if !hasNext {
if it.Error() != nil {
f.logf("Iteration error: %v", errs.FmtDetails(it.Error()))
return
}
continue
}
rec := it.Record()
if err := f.ApplyChanges(rec); err != nil {
f.logf("Failed to apply changes: %s", errs.FmtDetails(err))
return
}
f.seqNum = rec.SeqNum
}
}
func (f *follower) logf(pattern string, args ...interface{}) {
log.Printf("[FOLLOWER] "+pattern, args...)
}
*/

13
mdb/functions.go Normal file
View File

@@ -0,0 +1,13 @@
package mdb
import (
"path/filepath"
)
// pageFilePath returns the location of the page file under rootDir.
func pageFilePath(rootDir string) string {
	const pageFileName = "pagefile"
	return filepath.Join(rootDir, pageFileName)
}
// repDirPath returns the replicator's state directory under rootDir.
func repDirPath(rootDir string) string {
	const repDirName = "rep"
	return filepath.Join(rootDir, repDirName)
}

12
mdb/functions_test.go Normal file
View File

@@ -0,0 +1,12 @@
package mdb
import (
"testing"
)
// TestPageFilePath verifies the page-file location under a root directory.
func TestPageFilePath(t *testing.T) {
	// FIX: use a distinct local name — the original shadowed the
	// pageFilePath function with a same-named variable, which would break a
	// second call in this scope and confuses readers.
	got := pageFilePath("/tmp")
	if got != "/tmp/pagefile" {
		t.Fatal(got)
	}
}

8
mdb/id.go Normal file
View File

@@ -0,0 +1,8 @@
package mdb
import "git.crumpington.com/public/jldb/lib/idgen"
// NewID safely generates a new unique ID (delegates to idgen.Next).
func NewID() uint64 {
	return idgen.Next()
}

11
mdb/index-internal.go Normal file
View File

@@ -0,0 +1,11 @@
package mdb
import "github.com/google/btree"
// indexState is the per-snapshot state of a single index: a btree of item
// pointers ordered by the index comparator.
type indexState[T any] struct {
	BTree *btree.BTreeG[*T]
}

// clone returns a lazy copy of the state; btree.Clone shares nodes
// copy-on-write, so this is cheap until either tree is modified.
func (i indexState[T]) clone() indexState[T] {
	return indexState[T]{BTree: i.BTree.Clone()}
}

236
mdb/index.go Normal file
View File

@@ -0,0 +1,236 @@
package mdb
import (
"unsafe"
"github.com/google/btree"
)
// NewIndex adds a non-unique index on c ordered by compare.
func NewIndex[T any](
	c *Collection[T],
	name string,
	compare func(lhs, rhs *T) int,
) Index[T] {
	return c.addIndex(indexConfig[T]{
		Name:    name,
		Unique:  false,
		Compare: compare,
		Include: nil,
	})
}

// NewPartialIndex adds a non-unique index that only contains items for
// which include returns true.
func NewPartialIndex[T any](
	c *Collection[T],
	name string,
	compare func(lhs, rhs *T) int,
	include func(*T) bool,
) Index[T] {
	return c.addIndex(indexConfig[T]{
		Name:    name,
		Unique:  false,
		Compare: compare,
		Include: include,
	})
}

// NewUniqueIndex adds a unique index on c ordered by compare; inserts and
// updates that would produce two items comparing equal fail with a
// duplicate error.
func NewUniqueIndex[T any](
	c *Collection[T],
	name string,
	compare func(lhs, rhs *T) int,
) Index[T] {
	return c.addIndex(indexConfig[T]{
		Name:    name,
		Unique:  true,
		Compare: compare,
		Include: nil,
	})
}

// NewUniquePartialIndex adds a unique index restricted to items for which
// include returns true.
func NewUniquePartialIndex[T any](
	c *Collection[T],
	name string,
	compare func(lhs, rhs *T) int,
	include func(*T) bool,
) Index[T] {
	return c.addIndex(indexConfig[T]{
		Name:    name,
		Unique:  true,
		Compare: compare,
		Include: include,
	})
}
// ----------------------------------------------------------------------------
// Index provides ordered access to a collection's items under a comparator.
// All public accessors return defensive copies (via i.copy) so callers can
// never mutate btree-owned items in place.
type Index[T any] struct {
	name         string
	collectionID uint64
	indexID      uint64
	include      func(*T) bool // Non-nil only for partial indexes.
	copy         func(*T) *T
}

// Get returns a copy of the item matching in under the index comparator.
func (i Index[T]) Get(tx *Snapshot, in *T) (item *T, ok bool) {
	tPtr, ok := i.get(tx, in)
	if !ok {
		return item, false
	}
	return i.copy(tPtr), true
}

// get returns the stored pointer without copying. Internal use only.
func (i Index[T]) get(tx *Snapshot, in *T) (*T, bool) {
	return i.btree(tx).Get(in)
}

// Has reports whether an item matching in exists.
func (i Index[T]) Has(tx *Snapshot, in *T) bool {
	return i.btree(tx).Has(in)
}

// Min returns a copy of the smallest item in the index.
func (i Index[T]) Min(tx *Snapshot) (item *T, ok bool) {
	tPtr, ok := i.btree(tx).Min()
	if !ok {
		return item, false
	}
	return i.copy(tPtr), true
}

// Max returns a copy of the largest item in the index.
func (i Index[T]) Max(tx *Snapshot) (item *T, ok bool) {
	tPtr, ok := i.btree(tx).Max()
	if !ok {
		return item, false
	}
	return i.copy(tPtr), true
}

// Ascend iterates over all items in ascending order, passing a copy of each
// item to each; iteration stops when each returns false.
func (i Index[T]) Ascend(tx *Snapshot, each func(*T) bool) {
	i.btreeForIter(tx).Ascend(func(t *T) bool {
		return each(i.copy(t))
	})
}

// AscendAfter iterates in ascending order starting at the first item >= after.
func (i Index[T]) AscendAfter(tx *Snapshot, after *T, each func(*T) bool) {
	i.btreeForIter(tx).AscendGreaterOrEqual(after, func(t *T) bool {
		return each(i.copy(t))
	})
}

// Descend iterates over all items in descending order.
func (i Index[T]) Descend(tx *Snapshot, each func(*T) bool) {
	i.btreeForIter(tx).Descend(func(t *T) bool {
		return each(i.copy(t))
	})
}

// DescendAfter iterates in descending order starting at the last item <= after.
func (i Index[T]) DescendAfter(tx *Snapshot, after *T, each func(*T) bool) {
	i.btreeForIter(tx).DescendLessOrEqual(after, func(t *T) bool {
		return each(i.copy(t))
	})
}
// ListArgs configures Index.List.
type ListArgs[T any] struct {
	Desc  bool          // True for descending order, otherwise ascending.
	After *T            // If after is given, iterate after (and including) the value.
	While func(*T) bool // Continue iterating until While is false.
	Limit int           // Maximum number of items to return. 0 => All.
}
// List collects items selected by args into out (reused when its capacity
// suffices) and returns the result. Iteration is ascending unless args.Desc,
// starts at args.After when given, and stops when args.While returns false
// or Limit items have been collected. A negative Limit returns nil.
func (i Index[T]) List(tx *Snapshot, args ListArgs[T], out []*T) []*T {
	if args.Limit < 0 {
		return nil
	}

	if args.While == nil {
		args.While = func(*T) bool { return true }
	}

	// Pre-size the result: the limit when given, otherwise a small default.
	// (FIX: the original computed size but never used it.)
	size := args.Limit
	if size == 0 {
		size = 32 // Why not?
	}
	items := out[:0]
	if cap(items) < size {
		items = make([]*T, 0, size)
	}

	each := func(item *T) bool {
		if !args.While(item) {
			return false
		}
		items = append(items, item)
		return args.Limit == 0 || len(items) < args.Limit
	}

	if args.Desc {
		if args.After != nil {
			i.DescendAfter(tx, args.After, each)
		} else {
			i.Descend(tx, each)
		}
	} else {
		if args.After != nil {
			i.AscendAfter(tx, args.After, each)
		} else {
			i.Ascend(tx, each)
		}
	}

	return items
}
// ----------------------------------------------------------------------------
// insertConflict reports whether inserting item would violate this index:
// an item comparing equal already exists.
func (i Index[T]) insertConflict(tx *Snapshot, item *T) bool {
	return i.btree(tx).Has(item)
}

// updateConflict reports whether updating to item would collide with a
// DIFFERENT item (same index key, different primary ID).
func (i Index[T]) updateConflict(tx *Snapshot, item *T) bool {
	current, ok := i.btree(tx).Get(item)
	return ok && i.getID(current) != i.getID(item)
}

// This should only be called after insertConflict. Additionally, the caller
// should ensure that the index has been properly cloned for write before
// writing.
func (i Index[T]) insert(tx *Snapshot, item *T) {
	// Partial indexes skip items rejected by the include predicate.
	if i.include != nil && !i.include(item) {
		return
	}
	i.btree(tx).ReplaceOrInsert(item)
}
// update replaces old with updated in the index. The caller must have
// cloned the index for writing and checked updateConflict first.
func (i Index[T]) update(tx *Snapshot, old, updated *T) {
	// FIX: renamed the second parameter from `new`, which shadowed the
	// builtin (unexported method; callers are positional and unaffected).
	i.btree(tx).Delete(old)
	// The insert call will also check the include function if available.
	i.insert(tx, updated)
}
// delete removes item from the index. The caller must have cloned the index
// for writing first.
func (i Index[T]) delete(tx *Snapshot, item *T) {
	i.btree(tx).Delete(item)
}
// ----------------------------------------------------------------------------
func (i Index[T]) getState(tx *Snapshot) indexState[T] {
return tx.collections[i.collectionID].(*collectionState[T]).Indices[i.indexID]
}
// Get the current btree for get/has/update/delete, etc.
func (i Index[T]) btree(tx *Snapshot) *btree.BTreeG[*T] {
return i.getState(tx).BTree
}
// btreeForIter returns a btree suitable for iteration. When the snapshot
// is writable and this collection was already cloned at the snapshot's
// version, a clone is returned so in-flight writes can't disturb the
// iteration.
func (i Index[T]) btreeForIter(tx *Snapshot) *btree.BTreeG[*T] {
	cState := tx.collections[i.collectionID].(*collectionState[T])
	tree := cState.Indices[i.indexID].BTree
	if !tx.writable() || cState.Version != tx.version {
		return tree
	}
	return tree.Clone()
}
// getID returns the item's ID by reinterpreting the first 8 bytes of *t
// as a uint64 — i.e. it relies on T's first field being its uint64 ID.
func (i Index[T]) getID(t *T) uint64 {
	return *((*uint64)(unsafe.Pointer(t)))
}

9
mdb/index_test.go Normal file
View File

@@ -0,0 +1,9 @@
package mdb
// Dump returns every item in the index in ascending order, dereferenced
// into a flat slice. Intended for tests.
func (i Index[T]) Dump(tx *Snapshot) []T {
	var out []T
	i.Ascend(tx, func(item *T) bool {
		out = append(out, *item)
		return true
	})
	return out
}

11
mdb/info.go Normal file
View File

@@ -0,0 +1,11 @@
package mdb
// Info reports the database's persistence positions: the live in-memory
// state, the on-disk page file, and the WAL's retained sequence range.
type Info struct {
	SeqNum      int64 // In-memory sequence number.
	TimestampMS int64 // In-memory timestamp.

	FileSeqNum      int64 // Page file sequence number.
	FileTimestampMS int64 // Page file timestamp.

	WALFirstSeqNum     int64 // WAL min sequence number.
	WALLastSeqNum      int64 // WAL max sequence number.
	WALLastTimestampMS int64 // WAL timestamp.
}

57
mdb/pfile/alloclist.go Normal file
View File

@@ -0,0 +1,57 @@
package pfile
import "slices"
// allocList maps each (collectionID, itemID) pair to the ordered list of
// page IDs holding that item's data.
type allocList map[[2]uint64][]uint64

// newAllocList returns a pointer to an empty allocList.
func newAllocList() *allocList {
	al := make(allocList)
	return &al
}
// Create starts a fresh single-page list for the item, replacing any
// existing entry.
func (al allocList) Create(collectionID, itemID, page uint64) {
	al[al.key(collectionID, itemID)] = []uint64{page}
}
// Push is used to add pages to the storage when loading. It will append
// pages to the appropriate list, or return false if the list isn't found.
func (al allocList) Push(collectionID, itemID, page uint64) bool {
	key := al.key(collectionID, itemID)
	pages, ok := al[key]
	if !ok {
		return false
	}
	al[key] = append(pages, page)
	return true
}
// Store replaces the item's page list with a copy of pages.
func (al allocList) Store(collectionID, itemID uint64, pages []uint64) {
	al[al.key(collectionID, itemID)] = slices.Clone(pages)
}
// Remove deletes the item's entry and returns its page list (nil when the
// item had no entry).
func (al allocList) Remove(collectionID, itemID uint64) []uint64 {
	key := al.key(collectionID, itemID)
	pages, ok := al[key]
	if ok {
		delete(al, key)
	}
	return pages
}
// Iterate invokes each for every entry (map order), stopping at and
// returning the first non-nil error.
func (al allocList) Iterate(
	each func(collectionID, itemID uint64, pages []uint64) error,
) error {
	var err error
	for key, pages := range al {
		if err = each(key[0], key[1], pages); err != nil {
			break
		}
	}
	return err
}
// Len returns the number of items currently tracked.
func (al allocList) Len() int {
	return len(al)
}

// key builds the map key for a (collectionID, itemID) pair.
func (al allocList) key(collectionID, itemID uint64) [2]uint64 {
	return [2]uint64{collectionID, itemID}
}

172
mdb/pfile/alloclist_test.go Normal file
View File

@@ -0,0 +1,172 @@
package pfile
import (
"errors"
"reflect"
"testing"
)
// Assert fails the test unless al's contents exactly equal state: same
// key set and identical page slices.
func (al allocList) Assert(t *testing.T, state map[[2]uint64][]uint64) {
	t.Helper()

	if len(al) != len(state) {
		t.Fatalf("Expected %d items, but found %d.", len(state), len(al))
	}

	for key, expected := range state {
		val, ok := al[key]
		if !ok {
			t.Fatalf("Expected to find key %v.", key)
		}
		if !reflect.DeepEqual(val, expected) {
			t.Fatalf("For %v, expected %v but got %v.", key, expected, val)
		}
	}
}
// With stores pages under (collectionID, itemID) and returns al so test
// fixtures can be built by chaining.
func (al *allocList) With(collectionID, itemID uint64, pages ...uint64) *allocList {
	al.Store(collectionID, itemID, pages)
	return al
}
// Equals reports whether al and rhs hold identical contents.
func (al *allocList) Equals(rhs *allocList) bool {
	if len(*al) != len(*rhs) {
		return false
	}
	for key, expected := range *rhs {
		if !reflect.DeepEqual((*al)[key], expected) {
			return false
		}
	}
	return true
}
// TestAllocList applies a sequence of mutations to one shared list (each
// case builds on the previous one) and checks the full state and length
// after every step.
func TestAllocList(t *testing.T) {
	const (
		CREATE = "CREATE"
		PUSH   = "PUSH"
		STORE  = "STORE"
		REMOVE = "REMOVE"
	)

	type TestCase struct {
		Name        string
		Action      string
		Key         [2]uint64
		Page        uint64
		Pages       []uint64 // For STORE command.
		Expected    *allocList
		ExpectedLen int
	}

	testCases := []TestCase{{
		Name:        "Create something",
		Action:      CREATE,
		Key:         [2]uint64{1, 1},
		Page:        1,
		Expected:    newAllocList().With(1, 1, 1),
		ExpectedLen: 1,
	}, {
		Name:        "Push onto something",
		Action:      PUSH,
		Key:         [2]uint64{1, 1},
		Page:        2,
		Expected:    newAllocList().With(1, 1, 1, 2),
		ExpectedLen: 1,
	}, {
		Name:        "Push onto something again",
		Action:      PUSH,
		Key:         [2]uint64{1, 1},
		Page:        3,
		Expected:    newAllocList().With(1, 1, 1, 2, 3),
		ExpectedLen: 1,
	}, {
		Name:        "Store something",
		Action:      STORE,
		Key:         [2]uint64{2, 2},
		Pages:       []uint64{4, 5, 6},
		Expected:    newAllocList().With(1, 1, 1, 2, 3).With(2, 2, 4, 5, 6),
		ExpectedLen: 2,
	}, {
		Name:        "Remove something",
		Action:      REMOVE,
		Key:         [2]uint64{1, 1},
		Expected:    newAllocList().With(2, 2, 4, 5, 6),
		ExpectedLen: 1,
	}}

	al := newAllocList()

	for _, tc := range testCases {
		// Apply the case's mutation to the shared list.
		switch tc.Action {
		case CREATE:
			al.Create(tc.Key[0], tc.Key[1], tc.Page)
		case PUSH:
			al.Push(tc.Key[0], tc.Key[1], tc.Page)
		case STORE:
			al.Store(tc.Key[0], tc.Key[1], tc.Pages)
		case REMOVE:
			al.Remove(tc.Key[0], tc.Key[1])
		default:
			t.Fatalf("Unknown action: %s", tc.Action)
		}

		if !al.Equals(tc.Expected) {
			t.Fatal(tc.Name, al, tc.Expected)
		}
		if al.Len() != tc.ExpectedLen {
			t.Fatal(tc.Name, al.Len(), tc.ExpectedLen)
		}
	}
}
// Iterate must propagate the callback's error unchanged.
func TestAllocListIterate_eachError(t *testing.T) {
	al := newAllocList().With(1, 1, 2, 3, 4, 5)
	sentinel := errors.New("xxx")

	err := al.Iterate(func(_, _ uint64, _ []uint64) error {
		return sentinel
	})
	if err != sentinel {
		t.Fatal(err)
	}
}
// TestAllocListIterate checks that Iterate visits every entry with its
// page list. (Both fixtures use collectionID == itemID, so keying the
// expectations by collectionID alone is sufficient.)
func TestAllocListIterate(t *testing.T) {
	al := newAllocList().With(1, 1, 2, 3, 4, 5).With(2, 2, 6, 7)
	expected := map[uint64][]uint64{
		1: {2, 3, 4, 5},
		2: {6, 7},
	}

	err := al.Iterate(func(collectionID, itemID uint64, pageIDs []uint64) error {
		e, ok := expected[collectionID]
		if !ok {
			t.Fatalf("Not found: %d", collectionID)
		}
		if !reflect.DeepEqual(e, pageIDs) {
			t.Fatalf("%v != %v", pageIDs, e)
		}
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}
}
// TestAllocListPushNoHead: Push succeeds for an existing entry and
// returns false when no head entry exists for the key.
func TestAllocListPushNoHead(t *testing.T) {
	al := newAllocList().With(1, 1, 2, 3, 4, 5).With(2, 2, 6, 7)
	if !al.Push(1, 1, 8) {
		t.Fatal("Failed to push onto head page")
	}
	if al.Push(1, 2, 9) {
		t.Fatal("Pushed with no head.")
	}
}

58
mdb/pfile/change_test.go Normal file
View File

@@ -0,0 +1,58 @@
package pfile
import (
crand "crypto/rand"
"git.crumpington.com/public/jldb/mdb/change"
"math/rand"
)
// randomChangeList builds 1-8 random changes. About 95% are stores with
// 1 to 4*pageDataSize random bytes; the rest are deletes (Store=false).
//
// The local is named ch: the original `change := change.Change{...}`
// shadowed the imported change package inside the loop body.
func randomChangeList() (changes []change.Change) {
	count := 1 + rand.Intn(8)
	for i := 0; i < count; i++ {
		ch := change.Change{
			CollectionID: 1 + uint64(rand.Int63n(10)),
			ItemID:       1 + uint64(rand.Int63n(10)),
		}
		if rand.Float32() < 0.95 {
			ch.Data = randBytes(1 + rand.Intn(pageDataSize*4))
			ch.Store = true
		}
		changes = append(changes, ch)
	}
	return changes
}
// changeListBuilder accumulates changes for tests via chained calls.
type changeListBuilder []change.Change

// Clear empties the builder in place (keeping capacity) and returns it
// for chaining.
func (b *changeListBuilder) Clear() *changeListBuilder {
	*b = (*b)[:0]
	return b
}
// Store appends a store change for (cID, iID) carrying dataSize random
// bytes; returns b for chaining. Panics if the entropy source fails
// (matching randBytes) instead of silently using zeroed data.
func (b *changeListBuilder) Store(cID, iID, dataSize uint64) *changeListBuilder {
	data := make([]byte, dataSize)
	if _, err := crand.Read(data); err != nil {
		panic(err)
	}
	*b = append(*b, change.Change{
		CollectionID: cID,
		ItemID:       iID,
		Store:        true,
		Data:         data,
	})
	return b
}
// Delete appends a delete (Store=false) change; returns b for chaining.
func (b *changeListBuilder) Delete(cID, iID uint64) *changeListBuilder {
	*b = append(*b, change.Change{
		CollectionID: cID,
		ItemID:       iID,
		Store:        false,
	})
	return b
}

// Build returns the accumulated change list.
func (b *changeListBuilder) Build() []change.Change {
	return *b
}

67
mdb/pfile/freelist.go Normal file
View File

@@ -0,0 +1,67 @@
package pfile
import "container/heap"
// ----------------------------------------------------------------------------
// The intHeap is used to store the free list.
// ----------------------------------------------------------------------------
// intHeap is a min-heap of page IDs, driven via container/heap.
type intHeap []uint64

func (h intHeap) Len() int           { return len(h) }
func (h intHeap) Less(i, j int) bool { return h[i] < h[j] }
func (h intHeap) Swap(i, j int)      { h[i], h[j] = h[j], h[i] }

// Push appends x. Pointer receiver: it changes the slice's length.
func (h *intHeap) Push(x any) {
	*h = append(*h, x.(uint64))
}

// Pop removes and returns the final element. Pointer receiver: it
// shrinks the slice. (container/heap arranges for the final element to
// be the minimum before calling this.)
func (h *intHeap) Pop() any {
	last := len(*h) - 1
	x := (*h)[last]
	*h = (*h)[:last]
	return x
}
// ----------------------------------------------------------------------------
// Free list
// ----------------------------------------------------------------------------
// freeList hands out page IDs for writing: recycled pages first (smallest
// ID first, via the min-heap), then brand-new pages from nextPage upward.
type freeList struct {
	h        intHeap
	nextPage uint64
}

// newFreeList creates a new free list that will return available pages from
// smallest to largest. If there are no available pages, it will return new
// pages starting from nextPage.
func newFreeList(pageCount uint64) *freeList {
	return &freeList{
		h:        []uint64{},
		nextPage: pageCount,
	}
}
// Push returns pages to the free list so they can be handed out again.
func (f *freeList) Push(pages ...uint64) {
	for _, p := range pages {
		heap.Push(&f.h, p)
	}
}
// Pop fills out with exactly count page IDs: recycled pages (smallest
// first) while the heap has any, then fresh pages minted from nextPage.
func (f *freeList) Pop(count int, out []uint64) []uint64 {
	out = out[:0]
	for len(out) < count {
		if f.h.Len() > 0 {
			out = append(out, heap.Pop(&f.h).(uint64))
			continue
		}
		out = append(out, f.nextPage)
		f.nextPage++
	}
	return out
}

View File

@@ -0,0 +1,90 @@
package pfile
import (
"math/rand"
"reflect"
"testing"
)
// Assert fails the test unless the free list holds exactly pageIDs
// (order-insensitive: the heap's internal slice layout is arbitrary).
func (fl *freeList) Assert(t *testing.T, pageIDs ...uint64) {
	t.Helper()

	if len(fl.h) != len(pageIDs) {
		t.Fatalf("FreeList: Expected %d pages but got %d.\n%v != %v",
			len(pageIDs), len(fl.h), fl.h, pageIDs)
	}

	containsPageID := func(pageID uint64) bool {
		for _, v := range fl.h {
			if v == pageID {
				return true
			}
		}
		return false
	}

	for _, pageID := range pageIDs {
		if !containsPageID(pageID) {
			t.Fatalf("Page not free: %d", pageID)
		}
	}
}
// TestFreeList runs sequential put/alloc steps against one freeList,
// verifying recycled pages come back smallest-first before new pages are
// minted. All expectations are relative to the random base page p0.
func TestFreeList(t *testing.T) {
	t.Parallel()

	p0 := uint64(1 + rand.Int63())

	type TestCase struct {
		Name     string
		Put      []uint64
		Alloc    int
		Expected []uint64
	}

	testCases := []TestCase{
		{
			Name:     "Alloc first page",
			Put:      []uint64{},
			Alloc:    1,
			Expected: []uint64{p0},
		}, {
			Name:     "Alloc second page",
			Put:      []uint64{},
			Alloc:    1,
			Expected: []uint64{p0 + 1},
		}, {
			Name:     "Put second page",
			Put:      []uint64{p0 + 1},
			Alloc:    0,
			Expected: []uint64{},
		}, {
			Name:     "Alloc 2 pages",
			Put:      []uint64{},
			Alloc:    2,
			Expected: []uint64{p0 + 1, p0 + 2},
		}, {
			Name:     "Put back and alloc pages",
			Put:      []uint64{p0},
			Alloc:    3,
			Expected: []uint64{p0, p0 + 3, p0 + 4},
		}, {
			Name:     "Put back large and alloc",
			Put:      []uint64{p0, p0 + 2, p0 + 4, p0 + 442},
			Alloc:    4,
			Expected: []uint64{p0, p0 + 2, p0 + 4, p0 + 442},
		},
	}

	fl := newFreeList(p0)
	var pages []uint64

	for _, tc := range testCases {
		fl.Push(tc.Put...)
		pages = fl.Pop(tc.Alloc, pages)
		if !reflect.DeepEqual(pages, tc.Expected) {
			t.Fatal(tc.Name, pages, tc.Expected)
		}
	}
}

1
mdb/pfile/header.go Normal file
View File

@@ -0,0 +1 @@
package pfile

105
mdb/pfile/index.go Normal file
View File

@@ -0,0 +1,105 @@
package pfile
import (
"git.crumpington.com/public/jldb/lib/errs"
"git.crumpington.com/public/jldb/mdb/change"
)
// Index tracks page allocation state for a page file: which pages are
// free, and which pages belong to which (collectionID, itemID).
type Index struct {
	fList *freeList
	aList allocList

	// Scratch state reused by StageChanges to deduplicate changes.
	seen map[[2]uint64]struct{}
	mask []bool
}
// NewIndex rebuilds the allocation index by scanning every page in f:
// head pages start an item's page list, data pages extend it, and free
// pages go to the free list. Returns a Corrupt error if a data page
// appears before its head page.
func NewIndex(f *File) (*Index, error) {
	idx := &Index{
		fList: newFreeList(0),
		aList: *newAllocList(),
		seen:  map[[2]uint64]struct{}{},
		mask:  []bool{},
	}

	pageCount := uint64(0)

	err := f.iterate(func(pageID uint64, page dataPage) error {
		pageCount = pageID + 1
		header := page.Header()
		switch header.PageType {
		case pageTypeHead:
			idx.aList.Create(header.CollectionID, header.ItemID, pageID)
		case pageTypeData:
			if !idx.aList.Push(header.CollectionID, header.ItemID, pageID) {
				return errs.Corrupt.WithMsg("encountered data page with no corresponding head page")
			}
		case pageTypeFree:
			idx.fList.Push(pageID)
		}
		return nil
	})
	if err != nil {
		return nil, err
	}

	// Mint fresh page IDs past the last scanned page. Previously the free
	// list minted from 0, which collides with pages already allocated in
	// a non-empty file.
	idx.fList.nextPage = pageCount

	return idx, nil
}
// StageChanges prepares a batch of changes for writing: each active
// change that stores data receives fresh WritePageIDs from the free list,
// and the item's existing allocation (if any) is detached from the alloc
// list into ClearPageIDs. Only the LAST change per (collection, item)
// pair is staged; earlier duplicates are masked out.
//
// A staged batch must be followed by ApplyChanges (commit) or
// UnstageChanges (rollback) to restore a consistent index.
func (idx *Index) StageChanges(changes []change.Change) {
	clear(idx.seen)

	if cap(idx.mask) < len(changes) {
		idx.mask = make([]bool, len(changes))
	}
	idx.mask = idx.mask[:len(changes)]

	// Walk backwards so only the final change for each key is active.
	for i := len(changes) - 1; i >= 0; i-- {
		key := [2]uint64{changes[i].CollectionID, changes[i].ItemID}
		if _, ok := idx.seen[key]; ok {
			idx.mask[i] = false
			continue
		}
		idx.seen[key] = struct{}{}
		idx.mask[i] = true
	}

	for i, active := range idx.mask {
		if !active {
			continue
		}
		if changes[i].Store {
			count := idx.getPageCountForData(len(changes[i].Data))
			changes[i].WritePageIDs = idx.fList.Pop(count, changes[i].WritePageIDs)
		}
		// The item's current pages, if any, will be cleared on apply.
		if pages := idx.aList.Remove(changes[i].CollectionID, changes[i].ItemID); pages != nil {
			changes[i].ClearPageIDs = pages
		}
	}
}
// UnstageChanges reverses StageChanges: staged write pages return to the
// free list, and detached page lists are restored to the alloc list. The
// changes' staged slices are truncated in place.
func (idx *Index) UnstageChanges(changes []change.Change) {
	for i := range changes {
		ch := &changes[i]
		if len(ch.WritePageIDs) > 0 {
			idx.fList.Push(ch.WritePageIDs...)
			ch.WritePageIDs = ch.WritePageIDs[:0]
		}
		if len(ch.ClearPageIDs) > 0 {
			idx.aList.Store(ch.CollectionID, ch.ItemID, ch.ClearPageIDs)
			ch.ClearPageIDs = ch.ClearPageIDs[:0]
		}
	}
}
// ApplyChanges commits staged changes to the in-memory index: written
// pages become the item's allocation, and cleared pages become free.
func (idx *Index) ApplyChanges(changes []change.Change) {
	for i := range changes {
		ch := &changes[i]
		if len(ch.WritePageIDs) > 0 {
			idx.aList.Store(ch.CollectionID, ch.ItemID, ch.WritePageIDs)
		}
		if len(ch.ClearPageIDs) > 0 {
			idx.fList.Push(ch.ClearPageIDs...)
		}
	}
}
// getPageCountForData returns the number of pages needed to hold
// dataSize bytes: ceil(dataSize / pageDataSize).
func (idx *Index) getPageCountForData(dataSize int) int {
	return (dataSize + pageDataSize - 1) / pageDataSize
}

139
mdb/pfile/index_test.go Normal file
View File

@@ -0,0 +1,139 @@
package pfile
import (
"testing"
)
// IndexState is the expected content of an Index's two lists.
type IndexState struct {
	FreeList  []uint64
	AllocList map[[2]uint64][]uint64
}

// Assert fails the test unless idx's free and alloc lists match state.
func (idx *Index) Assert(t *testing.T, state IndexState) {
	t.Helper()
	idx.fList.Assert(t, state.FreeList...)
	idx.aList.Assert(t, state.AllocList)
}
// TestIndex walks an Index through stage / unstage / apply cycles against
// a fresh page file, asserting the free and alloc lists after each step.
func TestIndex(t *testing.T) {
	pf, idx := newForTesting(t)
	defer pf.Close()

	idx.Assert(t, IndexState{
		FreeList:  []uint64{},
		AllocList: map[[2]uint64][]uint64{},
	})

	p0 := uint64(0)

	l := (&changeListBuilder{}).
		Store(1, 1, pageDataSize+1).
		Build()

	// Staging assigns pages to the changes but leaves both lists empty.
	idx.StageChanges(l)
	idx.Assert(t, IndexState{
		FreeList:  []uint64{},
		AllocList: map[[2]uint64][]uint64{},
	})

	// Unstage a change: free-list gets pages back.
	idx.UnstageChanges(l)
	idx.Assert(t, IndexState{
		FreeList:  []uint64{p0, p0 + 1},
		AllocList: map[[2]uint64][]uint64{},
	})

	// Stage a change: free-list entries are used again.
	l = (*changeListBuilder)(&l).
		Clear().
		Store(1, 1, pageDataSize+1).
		Store(2, 2, pageDataSize-1).
		Store(3, 3, pageDataSize).
		Build()
	idx.StageChanges(l)
	idx.Assert(t, IndexState{
		FreeList:  []uint64{},
		AllocList: map[[2]uint64][]uint64{},
	})

	// Apply changes: alloc-list is updated.
	idx.ApplyChanges(l)
	idx.Assert(t, IndexState{
		FreeList: []uint64{},
		AllocList: map[[2]uint64][]uint64{
			{1, 1}: {p0, p0 + 1},
			{2, 2}: {p0 + 2},
			{3, 3}: {p0 + 3},
		},
	})

	// Clear some things.
	l = (*changeListBuilder)(&l).
		Clear().
		Store(1, 1, pageDataSize).
		Delete(2, 2).
		Build()
	idx.StageChanges(l)
	idx.Assert(t, IndexState{
		FreeList: []uint64{},
		AllocList: map[[2]uint64][]uint64{
			{3, 3}: {p0 + 3},
		},
	})

	// Ustaging will push the staged page p0+4 into the free list.
	idx.UnstageChanges(l)
	idx.Assert(t, IndexState{
		FreeList: []uint64{p0 + 4},
		AllocList: map[[2]uint64][]uint64{
			{1, 1}: {p0, p0 + 1},
			{2, 2}: {p0 + 2},
			{3, 3}: {p0 + 3},
		},
	})

	// Re-stage the same batch; the staged items leave the lists again.
	idx.StageChanges(l)
	idx.Assert(t, IndexState{
		FreeList: []uint64{},
		AllocList: map[[2]uint64][]uint64{
			{3, 3}: {p0 + 3},
		},
	})

	idx.ApplyChanges(l)
	idx.Assert(t, IndexState{
		FreeList: []uint64{p0, p0 + 1, p0 + 2},
		AllocList: map[[2]uint64][]uint64{
			{1, 1}: {p0 + 4},
			{3, 3}: {p0 + 3},
		},
	})

	// Duplicate updates.
	l = (*changeListBuilder)(&l).
		Clear().
		Store(2, 2, pageDataSize).
		Store(3, 3, pageDataSize+1).
		Store(3, 3, pageDataSize).
		Build()
	idx.StageChanges(l)
	idx.Assert(t, IndexState{
		FreeList: []uint64{p0 + 2},
		AllocList: map[[2]uint64][]uint64{
			{1, 1}: {p0 + 4},
		},
	})
}

18
mdb/pfile/iterate.go Normal file
View File

@@ -0,0 +1,18 @@
package pfile
import "bytes"
// IterateAllocated reads each allocated item's full data from the page
// file and passes it to each. The buffer is reused between calls, so
// callers must copy data they retain.
func IterateAllocated(
	pf *File,
	idx *Index,
	each func(collectionID, itemID uint64, data []byte) error,
) error {
	var buf bytes.Buffer
	return idx.aList.Iterate(func(collectionID, itemID uint64, pages []uint64) error {
		buf.Reset()
		if err := pf.readData(pages[0], &buf); err != nil {
			return err
		}
		return each(collectionID, itemID, buf.Bytes())
	})
}

64
mdb/pfile/main_test.go Normal file
View File

@@ -0,0 +1,64 @@
package pfile
import (
"bytes"
crand "crypto/rand"
"git.crumpington.com/public/jldb/lib/wal"
"git.crumpington.com/public/jldb/mdb/change"
"path/filepath"
"testing"
)
// newForTesting opens a fresh page file (plus its index) in a temp dir,
// failing the test on any error. Callers should defer pf.Close.
func newForTesting(t *testing.T) (*File, *Index) {
	t.Helper()
	tmpDir := t.TempDir()
	filePath := filepath.Join(tmpDir, "pagefile")

	pf, err := Open(filePath)
	if err != nil {
		t.Fatal(err)
	}

	idx, err := NewIndex(pf)
	if err != nil {
		t.Fatal(err)
	}

	return pf, idx
}
// randBytes returns size cryptographically random bytes, panicking if
// the system entropy source fails.
func randBytes(size int) []byte {
	b := make([]byte, size)
	if _, err := crand.Read(b); err != nil {
		panic(err)
	}
	return b
}
// changesToRec serializes changes into a WAL record, panicking on a
// serialization error.
func changesToRec(changes []change.Change) wal.Record {
	buf := &bytes.Buffer{}
	if err := change.Write(changes, buf); err != nil {
		panic(err)
	}
	return wal.Record{
		DataSize: int64(buf.Len()),
		Reader:   buf,
	}
}
// TestChangesToRec round-trips a change list through changesToRec and
// change.Read, verifying the decode succeeds and matches the original.
// (Previously the Read error was discarded and nothing was asserted.)
func TestChangesToRec(t *testing.T) {
	changes := []change.Change{
		{
			CollectionID: 2,
			ItemID:       3,
			Store:        true,
			Data:         []byte{2, 3, 4},
			WritePageIDs: []uint64{0, 1},
			ClearPageIDs: []uint64{2, 3},
		},
	}
	rec := changesToRec(changes)

	c2, err := change.Read([]change.Change{}, rec.Reader)
	if err != nil {
		t.Fatal(err)
	}
	if len(c2) != 1 {
		t.Fatalf("Expected 1 change, got %d.", len(c2))
	}

	got := c2[0]
	if got.CollectionID != 2 || got.ItemID != 3 || !got.Store {
		t.Fatalf("Decoded header mismatch: %+v", got)
	}
	if !bytes.Equal(got.Data, changes[0].Data) {
		t.Fatalf("Data mismatch: %v != %v", got.Data, changes[0].Data)
	}
}

70
mdb/pfile/page.go Normal file
View File

@@ -0,0 +1,70 @@
package pfile
import (
"hash/crc32"
"git.crumpington.com/public/jldb/lib/errs"
"unsafe"
)
// ----------------------------------------------------------------------------
const (
	// On-disk layout: fixed 512-byte pages, each a 40-byte header
	// followed by up to pageDataSize bytes of payload.
	pageSize       = 512
	pageHeaderSize = 40
	pageDataSize   = pageSize - pageHeaderSize

	// Page types: unused, first page of an item, continuation page.
	pageTypeFree = 0
	pageTypeHead = 1
	pageTypeData = 2
)
// emptyPage is a zeroed page (PageType == pageTypeFree) sealed with a
// valid CRC; it is written over pages that are being cleared.
var emptyPage = func() dataPage {
	p := newDataPage()
	h := p.Header()
	h.CRC = p.ComputeCRC()
	return p
}()
// ----------------------------------------------------------------------------
// pageHeader is the fixed header at the start of every page. Field order
// defines the on-disk layout (pages are accessed via an unsafe cast).
type pageHeader struct {
	CRC          uint32 // IEEE CRC-32 checksum.
	PageType     uint32 // One of the PageType* constants.
	CollectionID uint64 //
	ItemID       uint64
	DataSize     uint64 // Total item data size; read from the head page.
	NextPage     uint64 // Next page ID in the chain; 0 terminates.
}
// ----------------------------------------------------------------------------
// dataPage is one fixed-size page: a pageHeader followed by payload.
type dataPage []byte

// newDataPage allocates a zeroed page.
func newDataPage() dataPage {
	p := dataPage(make([]byte, pageSize))
	return p
}
// Header reinterprets the page's leading bytes as a *pageHeader. The
// returned pointer aliases the page's backing array.
func (p dataPage) Header() *pageHeader {
	return (*pageHeader)(unsafe.Pointer(&p[0]))
}

// ComputeCRC checksums everything after the 4-byte CRC field itself.
func (p dataPage) ComputeCRC() uint32 {
	return crc32.ChecksumIEEE(p[4:])
}

// Data returns the page's payload region (aliases the page).
func (p dataPage) Data() []byte {
	return p[pageHeaderSize:]
}

// Write copies data into the payload region and returns the number of
// bytes copied (at most pageDataSize).
func (p dataPage) Write(data []byte) int {
	return copy(p[pageHeaderSize:], data)
}
// Validate returns a Corrupt error unless the stored CRC matches the
// page's computed checksum.
func (p dataPage) Validate() error {
	if p.Header().CRC == p.ComputeCRC() {
		return nil
	}
	return errs.Corrupt.WithMsg("CRC mismatch on data page.")
}

103
mdb/pfile/page_test.go Normal file
View File

@@ -0,0 +1,103 @@
package pfile
import (
"bytes"
crand "crypto/rand"
"git.crumpington.com/public/jldb/lib/errs"
"math/rand"
"testing"
)
// randomPage builds a page with a random type (roughly a third each of
// free, head, and data), random IDs, and a random payload, then seals it
// with a valid CRC. DataSize is a raw random uint64 for head/data pages;
// only up to pageDataSize payload bytes are actually filled.
func randomPage(t *testing.T) dataPage {
	p := newDataPage()
	h := p.Header()

	x := rand.Float32()
	if x > 0.66 {
		h.PageType = pageTypeFree
		h.DataSize = 0
	} else if x < 0.33 {
		h.PageType = pageTypeHead
		h.DataSize = rand.Uint64()
	} else {
		h.PageType = pageTypeData
		h.DataSize = rand.Uint64()
	}

	h.CollectionID = rand.Uint64()
	h.ItemID = rand.Uint64()

	dataSize := h.DataSize
	if h.DataSize > pageDataSize {
		dataSize = pageDataSize
	}

	if _, err := crand.Read(p.Data()[:dataSize]); err != nil {
		t.Fatal(err)
	}

	h.CRC = p.ComputeCRC()
	return p
}
// ----------------------------------------------------------------------------
// TestPageValidate checks that a freshly-sealed page validates, that
// flipping any single byte is detected as corruption, and that restoring
// the byte makes the page valid again.
func TestPageValidate(t *testing.T) {
	for i := 0; i < 100; i++ {
		p := randomPage(t)

		// Should be valid initially.
		if err := p.Validate(); err != nil {
			t.Fatal(err)
		}

		// Inner loop uses j: it previously shadowed the outer i.
		for j := 0; j < pageSize; j++ {
			p[j]++
			if err := p.Validate(); !errs.Corrupt.Is(err) {
				t.Fatal(err)
			}
			p[j]--
		}

		// Should be valid again after restoring each byte.
		if err := p.Validate(); err != nil {
			t.Fatal(err)
		}
	}
}
// TestPageEmptyIsValid confirms the prebuilt empty page has a valid CRC.
func TestPageEmptyIsValid(t *testing.T) {
	if err := emptyPage.Validate(); err != nil {
		t.Fatal(err)
	}
}

// TestPageWrite round-trips random payloads (sometimes larger than one
// page) through Write, checking the copied prefix and the sealed CRC.
func TestPageWrite(t *testing.T) {
	for i := 0; i < 100; i++ {
		page := newDataPage()
		h := page.Header()
		h.PageType = pageTypeData
		h.CollectionID = rand.Uint64()
		h.ItemID = rand.Uint64()
		h.DataSize = uint64(1 + rand.Int63n(2*pageDataSize))

		data := make([]byte, h.DataSize)
		crand.Read(data)

		n := page.Write(data)
		h.CRC = page.ComputeCRC()

		if n > pageDataSize || n < 1 {
			t.Fatal(n)
		}
		if !bytes.Equal(data[:n], page.Data()[:n]) {
			t.Fatal(data[:n], page.Data()[:n])
		}
		if err := page.Validate(); err != nil {
			t.Fatal(err)
		}
	}
}

307
mdb/pfile/pagefile.go Normal file
View File

@@ -0,0 +1,307 @@
package pfile
import (
"bufio"
"bytes"
"compress/gzip"
"encoding/binary"
"io"
"git.crumpington.com/public/jldb/lib/errs"
"git.crumpington.com/public/jldb/mdb/change"
"net"
"os"
"sync"
"time"
)
// File is a page-oriented data file. lock serializes access; page is a
// scratch page reused by the read and write paths.
type File struct {
	lock sync.RWMutex
	f    *os.File
	page dataPage
}
// Open opens the page file at path, creating it if necessary.
func Open(path string) (*File, error) {
	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0600)
	if err != nil {
		return nil, errs.IO.WithErr(err)
	}
	return &File{
		f:    f,
		page: newDataPage(),
	}, nil
}
// Close closes the underlying file, first waiting for in-flight readers
// and writers to release the lock.
func (pf *File) Close() error {
	pf.lock.Lock()
	defer pf.lock.Unlock()
	if err := pf.f.Close(); err != nil {
		return errs.IO.WithErr(err)
	}
	return nil
}
// ----------------------------------------------------------------------------
// Writing
// ----------------------------------------------------------------------------
// ApplyChanges applies a batch of staged changes to the file under the
// write lock. See applyChanges.
func (pf *File) ApplyChanges(changes []change.Change) error {
	pf.lock.Lock()
	defer pf.lock.Unlock()
	return pf.applyChanges(changes)
}
// applyChanges writes each change's data pages, overwrites each cleared
// page with the empty page, then syncs the file. The caller must hold
// the write lock.
func (pf *File) applyChanges(changes []change.Change) error {
	for i := range changes {
		ch := changes[i]

		if len(ch.WritePageIDs) > 0 {
			if err := pf.writeChangePages(ch); err != nil {
				return err
			}
		}

		for _, pageID := range ch.ClearPageIDs {
			if err := pf.writePage(emptyPage, pageID); err != nil {
				return err
			}
		}
	}

	if err := pf.f.Sync(); err != nil {
		return errs.IO.WithErr(err)
	}
	return nil
}
// writeChangePages writes change.Data across change.WritePageIDs: the
// first page is the head page (carrying the total DataSize); subsequent
// pages are data pages. Each page's NextPage links to the following ID
// in the list (0 on the last page).
//
// Returns Unexpected when the page list and data length disagree.
func (pf *File) writeChangePages(change change.Change) error {
	page := pf.page
	header := page.Header()

	header.PageType = pageTypeHead
	header.CollectionID = change.CollectionID
	header.ItemID = change.ItemID
	header.DataSize = uint64(len(change.Data))

	pageIDs := change.WritePageIDs
	data := change.Data

	// Loop on the unwritten remainder (data), not change.Data: the
	// original condition never shrank, so surplus page IDs were silently
	// filled with empty pages instead of being reported below.
	for len(data) > 0 && len(pageIDs) > 0 {
		pageID := pageIDs[0]
		pageIDs = pageIDs[1:]

		if len(pageIDs) > 0 {
			header.NextPage = pageIDs[0]
		} else {
			header.NextPage = 0
		}

		n := page.Write(data)
		data = data[n:]

		page.Header().CRC = page.ComputeCRC()

		if err := pf.writePage(page, pageID); err != nil {
			return err
		}

		// All but first page has pageTypeData.
		header.PageType = pageTypeData
	}

	if len(pageIDs) > 0 {
		return errs.Unexpected.WithMsg("Too many pages provided for given data.")
	}
	if len(data) > 0 {
		return errs.Unexpected.WithMsg("Not enough pages for given data.")
	}

	return nil
}
// writePage writes page at the file offset corresponding to page id.
func (pf *File) writePage(page dataPage, id uint64) error {
	if _, err := pf.f.WriteAt(page, int64(id*pageSize)); err != nil {
		return errs.IO.WithErr(err)
	}
	return nil
}
// ----------------------------------------------------------------------------
// Reading
// ----------------------------------------------------------------------------
// iterate reads every page in the file sequentially, validating each and
// passing it to each. The scratch page is reused between callbacks, so
// each must not retain it.
func (pf *File) iterate(each func(pageID uint64, page dataPage) error) error {
	pf.lock.RLock()
	defer pf.lock.RUnlock()

	page := pf.page

	fi, err := pf.f.Stat()
	if err != nil {
		return errs.IO.WithErr(err)
	}

	fileSize := fi.Size()
	if fileSize%pageSize != 0 {
		return errs.Corrupt.WithMsg("File size isn't a multiple of page size.")
	}
	maxPage := uint64(fileSize / pageSize)

	if _, err := pf.f.Seek(0, io.SeekStart); err != nil {
		return errs.IO.WithErr(err)
	}

	r := bufio.NewReaderSize(pf.f, 1024*1024)

	for pageID := uint64(0); pageID < maxPage; pageID++ {
		// io.ReadFull: a bare Read may return fewer than pageSize bytes,
		// which would silently misalign every subsequent page.
		if _, err := io.ReadFull(r, page); err != nil {
			return errs.IO.WithErr(err)
		}
		if err := page.Validate(); err != nil {
			return err
		}
		if err := each(pageID, page); err != nil {
			return err
		}
	}

	return nil
}
// readData reads an item's full payload into buf, starting from the head
// page at id and following NextPage links (0 terminates the chain).
// Returns Corrupt when the chain's bytes don't add up to the head page's
// DataSize.
//
// NOTE(review): uses the shared scratch pf.page without taking pf.lock;
// callers appear to serialize access — confirm.
func (pf *File) readData(id uint64, buf *bytes.Buffer) error {
	page := pf.page

	// The head page.
	if err := pf.readPage(page, id); err != nil {
		return err
	}

	remaining := int(page.Header().DataSize)

	for {
		// Only the still-needed prefix of the payload counts.
		data := page.Data()
		if len(data) > remaining {
			data = data[:remaining]
		}
		buf.Write(data)
		remaining -= len(data)

		if page.Header().NextPage == 0 {
			break
		}
		if err := pf.readPage(page, page.Header().NextPage); err != nil {
			return err
		}
	}

	if remaining != 0 {
		return errs.Corrupt.WithMsg("Incorrect data size. %d remaining.", remaining)
	}

	return nil
}
// readPage reads page id from disk into p and validates its CRC.
func (pf *File) readPage(p dataPage, id uint64) error {
	if _, err := pf.f.ReadAt(p, int64(id*pageSize)); err != nil {
		return errs.IO.WithErr(err)
	}
	return p.Validate()
}
// ----------------------------------------------------------------------------
// Send / Recv
// ----------------------------------------------------------------------------
// Send streams the entire page file over conn: an 8-byte little-endian
// size prefix followed by the gzip-compressed (level 3) contents. The
// write deadline is refreshed before each network write, and the stream
// is flushed per chunk so the receiver makes progress.
func (pf *File) Send(conn net.Conn, timeout time.Duration) error {
	pf.lock.RLock()
	defer pf.lock.RUnlock()

	if _, err := pf.f.Seek(0, io.SeekStart); err != nil {
		return errs.IO.WithErr(err)
	}

	fi, err := pf.f.Stat()
	if err != nil {
		return errs.IO.WithErr(err)
	}

	remaining := fi.Size()

	conn.SetWriteDeadline(time.Now().Add(timeout))
	// Wrap as an IO error for consistency with the rest of the file
	// (previously returned unwrapped).
	if err := binary.Write(conn, binary.LittleEndian, remaining); err != nil {
		return errs.IO.WithErr(err)
	}

	buf := make([]byte, 1024*1024)

	w, err := gzip.NewWriterLevel(conn, 3)
	if err != nil {
		return errs.Unexpected.WithErr(err)
	}
	defer w.Close()

	for remaining > 0 {
		n, err := pf.f.Read(buf)
		if err != nil {
			return errs.IO.WithErr(err)
		}

		conn.SetWriteDeadline(time.Now().Add(timeout))
		if _, err := w.Write(buf[:n]); err != nil {
			return errs.IO.WithErr(err)
		}
		remaining -= int64(n)

		// A flush failure means the connection is dead; it was
		// previously ignored.
		if err := w.Flush(); err != nil {
			return errs.IO.WithErr(err)
		}
	}

	return nil
}
// Recv receives a page file streamed by Send into filePath: an 8-byte
// little-endian size prefix, then the gzip-compressed contents. The conn
// is always closed, and the file is synced before returning.
func Recv(conn net.Conn, filePath string, timeout time.Duration) error {
	defer conn.Close()

	f, err := os.Create(filePath)
	if err != nil {
		return errs.IO.WithErr(err)
	}
	defer f.Close()

	remaining := uint64(0)
	if err := binary.Read(conn, binary.LittleEndian, &remaining); err != nil {
		return err
	}

	r, err := gzip.NewReader(conn)
	if err != nil {
		return errs.Unexpected.WithErr(err)
	}
	defer r.Close()

	buf := make([]byte, 1024*1024)

	for remaining > 0 {
		// Never request more than remains: ReadFull on an over-sized
		// buffer would otherwise depend on hitting ErrUnexpectedEOF at
		// the end of the stream.
		chunk := buf
		if remaining < uint64(len(chunk)) {
			chunk = chunk[:remaining]
		}

		conn.SetReadDeadline(time.Now().Add(timeout))

		n, err := io.ReadFull(r, chunk)
		if n == 0 && err != nil {
			return errs.IO.WithErr(err)
		}

		remaining -= uint64(n)

		if _, err := f.Write(chunk[:n]); err != nil {
			return errs.IO.WithErr(err)
		}
	}

	if err := f.Sync(); err != nil {
		return errs.IO.WithErr(err)
	}

	return nil
}

View File

@@ -0,0 +1,94 @@
package pfile
import (
"bytes"
"os"
"path/filepath"
"testing"
)
// FileState describes expected page-file contents for assertions.
// NOTE(review): Assert below takes a pFileState (record_test.go), so this
// type appears unused here — confirm before removing.
type FileState struct {
	SeqNum uint64
	Data   map[[2]uint64][]byte
}

// Assert rebuilds an index from the file on disk and fails the test
// unless the allocated items and their data exactly match state.Data.
//
// NOTE(review): holds RLock while NewIndex -> iterate takes RLock again;
// recursive read-locking can deadlock if a writer is waiting — confirm
// this is only used without concurrent writers.
func (pf *File) Assert(t *testing.T, state pFileState) {
	t.Helper()
	pf.lock.RLock()
	defer pf.lock.RUnlock()

	idx, err := NewIndex(pf)
	if err != nil {
		t.Fatal(err)
	}

	data := map[[2]uint64][]byte{}

	err = IterateAllocated(pf, idx, func(cID, iID uint64, fileData []byte) error {
		data[[2]uint64{cID, iID}] = bytes.Clone(fileData)
		return nil
	})
	if err != nil {
		t.Fatal(err)
	}

	if len(data) != len(state.Data) {
		t.Fatalf("Expected %d items but got %d.", len(state.Data), len(data))
	}

	for key, expected := range state.Data {
		val, ok := data[key]
		if !ok {
			t.Fatalf("No data found for key %v.", key)
		}
		if !bytes.Equal(val, expected) {
			t.Fatalf("Incorrect data for key %v.", key)
		}
	}
}
// TestFileStateUpdateRandom applies 255 random change batches through the
// full stage/write/apply pipeline, mirroring the expected logical state
// in a map and asserting the on-disk file matches after every batch.
func TestFileStateUpdateRandom(t *testing.T) {
	t.Parallel()
	tmpDir := t.TempDir()
	walDir := filepath.Join(tmpDir, "wal")
	pageFilePath := filepath.Join(tmpDir, "pagefile")

	if err := os.MkdirAll(walDir, 0700); err != nil {
		t.Fatal(err)
	}

	pf, err := Open(pageFilePath)
	if err != nil {
		t.Fatal(err)
	}

	idx, err := NewIndex(pf)
	if err != nil {
		t.Fatal(err)
	}

	state := pFileState{
		Data: map[[2]uint64][]byte{},
	}

	for i := uint64(1); i < 256; i++ {
		changes := randomChangeList()
		idx.StageChanges(changes)
		if err := pf.ApplyChanges(changes); err != nil {
			t.Fatal(err)
		}
		idx.ApplyChanges(changes)

		// Mirror each change into the expected state (last write wins).
		for _, ch := range changes {
			if !ch.Store {
				delete(state.Data, [2]uint64{ch.CollectionID, ch.ItemID})
			} else {
				state.Data[[2]uint64{ch.CollectionID, ch.ItemID}] = ch.Data
			}
		}

		pf.Assert(t, state)
	}
}

57
mdb/pfile/record_test.go Normal file
View File

@@ -0,0 +1,57 @@
package pfile
import (
"bytes"
"git.crumpington.com/public/jldb/lib/wal"
"git.crumpington.com/public/jldb/mdb/change"
)
// ----------------------------------------------------------------------------
// pFileState is the expected logical content of a page file, keyed by
// (collectionID, itemID).
type pFileState struct {
	Data map[[2]uint64][]byte
}

// ----------------------------------------------------------------------------

// recBuilder accumulates changes and turns them into one WAL record.
type recBuilder struct {
	changes []change.Change
	rec     wal.Record
}

// NewRecBuilder starts a builder whose record carries the given sequence
// number and timestamp, with an empty change list.
func NewRecBuilder(seqNum, timestamp int64) *recBuilder {
	return &recBuilder{
		rec: wal.Record{
			SeqNum:      seqNum,
			TimestampMS: timestamp,
		},
		changes: []change.Change{},
	}
}
// Store appends an upsert change carrying data; returns b for chaining.
func (b *recBuilder) Store(cID, iID uint64, data string) *recBuilder {
	b.changes = append(b.changes, change.Change{
		CollectionID: cID,
		ItemID:       iID,
		Store:        true,
		Data:         []byte(data),
	})
	return b
}

// Delete appends a delete change; returns b for chaining.
func (b *recBuilder) Delete(cID, iID uint64) *recBuilder {
	b.changes = append(b.changes, change.Change{
		CollectionID: cID,
		ItemID:       iID,
		Store:        false,
	})
	return b
}
// Record serializes the accumulated changes and returns the WAL record
// with DataSize and Reader populated. Panics on a serialization error
// (previously ignored; writes to a bytes.Buffer shouldn't fail).
func (b *recBuilder) Record() wal.Record {
	buf := &bytes.Buffer{}
	if err := change.Write(b.changes, buf); err != nil {
		panic(err)
	}
	b.rec.DataSize = int64(buf.Len())
	b.rec.Reader = buf
	return b.rec
}

View File

@@ -0,0 +1,62 @@
package pfile
/*
func TestSendRecv(t *testing.T) {
tmpDir := t.TempDir()
filePath1 := filepath.Join(tmpDir, "1")
filePath2 := filepath.Join(tmpDir, "2")
defer os.RemoveAll(tmpDir)
f1, err := os.Create(filePath1)
if err != nil {
t.Fatal(err)
}
size := rand.Int63n(1024 * 1024 * 128)
buf := make([]byte, size)
crand.Read(buf)
if _, err := f1.Write(buf); err != nil {
t.Fatal(err)
}
if err := f1.Close(); err != nil {
t.Fatal(err)
}
c1, c2 := net.Pipe()
errChan := make(chan error)
go func() {
err := Send(filePath1, c1, time.Second)
if err != nil {
log.Printf("Send error: %v", err)
}
errChan <- err
}()
go func() {
err := Recv(filePath2, c2, time.Second)
if err != nil {
log.Printf("Recv error: %v", err)
}
errChan <- err
}()
if err := <-errChan; err != nil {
t.Fatal(err)
}
if err := <-errChan; err != nil {
t.Fatal(err)
}
buf2, err := os.ReadFile(filePath2)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(buf, buf2) {
t.Fatal("Not equal.")
}
}
*/

106
mdb/snapshot.go Normal file
View File

@@ -0,0 +1,106 @@
package mdb
import (
"bytes"
"encoding/json"
"git.crumpington.com/public/jldb/mdb/change"
"sync/atomic"
)
// Snapshot is a point-in-time view of the database. A writable snapshot
// (an open transaction) has a non-nil parent; a read-only one has none.
type Snapshot struct {
	parent atomic.Pointer[Snapshot]

	// The Snapshot's version is incremented each time it's cloned.
	version uint64

	// The snapshot's seqNum is set when it becomes active (read-only).
	seqNum      int64
	timestampMS int64

	collections map[uint64]any // Map from collection ID to *collectionState[T].
	changes     []change.Change
}
// newSnapshot returns an empty snapshot with no parent (read-only).
func newSnapshot() *Snapshot {
	return &Snapshot{
		collections: map[uint64]any{},
		changes:     []change.Change{},
	}
}

// addCollection registers a collection's state under its ID.
func (s *Snapshot) addCollection(id uint64, c any) {
	s.collections[id] = c
}

// writable reports whether this snapshot is an open transaction: only a
// snapshot with a parent may be modified.
func (s *Snapshot) writable() bool {
	return s.parent.Load() != nil
}

// setReadOnly detaches the snapshot from its parent and truncates its
// pending change list, making it an active (immutable) snapshot.
func (s *Snapshot) setReadOnly() {
	s.parent.Store(nil)
	s.changes = s.changes[:0]
}
// store records a JSON-encoded upsert of item for (cID, iID), reusing the
// change's existing Data buffer when it has capacity. Panics if item
// can't be JSON-encoded.
func (s *Snapshot) store(cID, iID uint64, item any) {
	change := s.appendChange(cID, iID)
	change.Store = true

	buf := bytes.NewBuffer(change.Data[:0])
	if err := json.NewEncoder(buf).Encode(item); err != nil {
		panic(err)
	}
	change.Data = buf.Bytes()
}

// delete records a deletion of (cID, iID).
func (s *Snapshot) delete(cID, iID uint64) {
	change := s.appendChange(cID, iID)
	change.Store = false
}

// appendChange grows s.changes by one entry — reusing a previously
// allocated entry (and its slices) when capacity allows — resets its
// fields, and returns a pointer into the slice.
func (s *Snapshot) appendChange(cID, iID uint64) *change.Change {
	if len(s.changes) == cap(s.changes) {
		s.changes = append(s.changes, change.Change{})
	} else {
		s.changes = s.changes[:len(s.changes)+1]
	}

	change := &s.changes[len(s.changes)-1]
	change.CollectionID = cID
	change.ItemID = iID
	change.Store = false
	change.ClearPageIDs = change.ClearPageIDs[:0]
	change.WritePageIDs = change.WritePageIDs[:0]
	change.Data = change.Data[:0]
	return change
}
// begin opens a transaction: a writable clone with an empty change list.
func (s *Snapshot) begin() *Snapshot {
	c := s.clone()
	c.changes = c.changes[:0]
	return c
}

// clone makes a writable child: a shallow copy of the collection map, a
// bumped version, and parent set to s.
// NOTE(review): c.changes shares s.changes' backing array; rollback
// depends on that aliasing — confirm before changing.
func (s *Snapshot) clone() *Snapshot {
	collections := make(map[uint64]any, len(s.collections))
	for k, v := range s.collections {
		collections[k] = v
	}

	c := &Snapshot{
		version:     s.version + 1,
		collections: collections,
		changes:     s.changes[:],
	}
	c.parent.Store(s)
	return c
}
// rollback abandons this transaction and returns its parent, or nil when
// the snapshot is already read-only.
func (s *Snapshot) rollback() *Snapshot {
	parent := s.parent.Load()
	if parent == nil {
		return nil
	}
	// Don't throw away allocated changes: hand the shared backing array
	// back to the parent, truncated to the parent's own length.
	parent.changes = s.changes[:len(parent.changes)]
	return parent
}

View File

@@ -0,0 +1,41 @@
package mdb
import (
"log"
"os"
"sync/atomic"
"testing"
"time"
)
// TestDBIsolation runs concurrent writers for ~8 seconds while a reader
// repeatedly verifies that every snapshot's stored CRC matches the CRC
// recomputed from its data — i.e. readers never see a half-applied
// transaction.
func TestDBIsolation(t *testing.T) {
	if testing.Short() {
		t.Skip("Skipping test in short mode.")
	}

	rootDir := t.TempDir()
	defer os.RemoveAll(rootDir)

	db, err := OpenDataDB(rootDir)
	if err != nil {
		t.Fatal(err)
	}

	done := &atomic.Bool{}

	go func() {
		defer done.Store(true)
		db.ModifyFor(8 * time.Second)
	}()

	count := 0
	for !done.Load() {
		count++
		tx := db.Snapshot()
		computed := db.ComputeCRC(tx)
		stored := db.ReadCRC(tx)
		if computed != stored {
			t.Fatal(stored, computed)
		}
	}

	log.Printf("Read: %d", count)
}

151
mdb/testdb_test.go Normal file
View File

@@ -0,0 +1,151 @@
package mdb
import (
"crypto/rand"
"errors"
"hash/crc32"
"log"
mrand "math/rand"
"runtime"
"slices"
"sync"
"sync/atomic"
"time"
)
type DataItem struct {
ID uint64
Data []byte
}
type DataCollection struct {
*Collection[DataItem]
}
// NewDataCollection registers the "Data" collection on db. The Copy
// hook deep-copies the payload slice so snapshots don't share it.
func NewDataCollection(db *Database) DataCollection {
	cfg := &CollectionConfig[DataItem]{
		Copy: func(src *DataItem) *DataItem {
			dst := new(DataItem)
			*dst = *src
			dst.Data = slices.Clone(src.Data)
			return dst
		},
	}
	return DataCollection{Collection: NewCollection(db, "Data", cfg)}
}
// CRCItem stores a CRC32 checksum under a fixed ID.
type CRCItem struct {
	ID    uint64 // Always 1
	CRC32 uint32
}
// CRCCollection wraps the generic Collection for CRCItem values.
type CRCCollection struct {
	*Collection[CRCItem]
}
// NewCRCCollection registers the "CRC" collection on db using the
// default (nil) configuration.
func NewCRCCollection(db *Database) CRCCollection {
	c := NewCollection[CRCItem](db, "CRC", nil)
	return CRCCollection{Collection: c}
}
// DataDB bundles a Database with its Data and CRC collections for
// isolation testing.
type DataDB struct {
	*Database
	Datas DataCollection
	CRCs  CRCCollection
}
// OpenDataDB creates a primary database rooted at rootDir, registers
// the Data and CRC collections, and opens it.
func OpenDataDB(rootDir string) (DataDB, error) {
	db := New(Config{RootDir: rootDir, Primary: true})
	testdb := DataDB{
		Database: db,
		Datas:    NewDataCollection(db),
		CRCs:     NewCRCCollection(db),
	}
	if err := db.Open(); err != nil {
		return testdb, err
	}
	return testdb, nil
}
// ModifyFor runs one writer goroutine per CPU, each calling modifyOnce
// in a loop until dt elapses, then logs the total modification count.
func (db DataDB) ModifyFor(dt time.Duration) {
	var (
		wg    sync.WaitGroup
		count int64
	)
	for n := runtime.NumCPU(); n > 0; n-- {
		wg.Add(1)
		go func() {
			defer wg.Done()
			start := time.Now()
			for time.Since(start) < dt {
				atomic.AddInt64(&count, 1)
				db.modifyOnce()
			}
		}()
	}
	wg.Wait()
	log.Printf("Modified: %d", count)
}
// modifyOnce performs one transaction: it rewrites data items 1-9 with
// fresh random payloads and upserts the matching CRC. With probability
// 0.1 it instead stores a deliberately wrong CRC (1) and returns an
// error, which the Update machinery must roll back entirely — a reader
// must never observe the bad CRC.
func (db DataDB) modifyOnce() {
	isErr := mrand.Float64() < 0.1
	err := db.Update(func(tx *Snapshot) error {
		h := crc32.NewIEEE()
		// IDs 1 through 9 inclusive (loop bound is exclusive of 10).
		for dataID := uint64(1); dataID < 10; dataID++ {
			d := DataItem{
				ID:   dataID,
				Data: make([]byte, 256),
			}
			rand.Read(d.Data)
			h.Write(d.Data)
			if err := db.Datas.Upsert(tx, &d); err != nil {
				return err
			}
		}
		crc := CRCItem{
			ID: 1,
		}
		if !isErr {
			crc.CRC32 = h.Sum32()
			return db.CRCs.Upsert(tx, &crc)
		}
		// Error path: write a bogus CRC then fail the transaction so it
		// must be rolled back.
		crc.CRC32 = 1
		if err := db.CRCs.Upsert(tx, &crc); err != nil {
			return err
		}
		return errors.New("ERROR")
	})
	// The injected error must round-trip: error out iff we asked for one.
	if isErr != (err != nil) {
		panic(err)
	}
}
// ComputeCRC recomputes the CRC32 over the payloads of data items 1-9
// as visible in snapshot tx; items not yet written are skipped.
func (db DataDB) ComputeCRC(tx *Snapshot) uint32 {
	h := crc32.NewIEEE()
	for id := uint64(1); id < 10; id++ {
		item, ok := db.Datas.ByID.Get(tx, &DataItem{ID: id})
		if !ok {
			continue
		}
		h.Write(item.Data)
	}
	return h.Sum32()
}
// ReadCRC returns the stored CRC32 from snapshot tx, or 0 when the CRC
// record doesn't exist yet.
func (db DataDB) ReadCRC(tx *Snapshot) uint32 {
	rec, ok := db.CRCs.ByID.Get(tx, &CRCItem{ID: 1})
	if !ok {
		return 0
	}
	return rec.CRC32
}

View File

@@ -0,0 +1,162 @@
package main
import (
"crypto/rand"
"errors"
"hash/crc32"
"git.crumpington.com/public/jldb/mdb"
"log"
mrand "math/rand"
"os"
"runtime"
"slices"
"sync"
"sync/atomic"
"time"
)
// DataItem is a test record: an ID plus an arbitrary byte payload.
type DataItem struct {
	ID   uint64
	Data []byte
}
// DataCollection wraps the generic mdb collection for DataItem values.
type DataCollection struct {
	*mdb.Collection[DataItem]
}
// NewDataCollection registers the "Data" collection on db. The Copy
// hook deep-copies the payload slice so snapshots don't share it.
func NewDataCollection(db *mdb.Database) DataCollection {
	cfg := &mdb.CollectionConfig[DataItem]{
		Copy: func(src *DataItem) *DataItem {
			dst := &DataItem{}
			*dst = *src
			dst.Data = slices.Clone(src.Data)
			return dst
		},
	}
	return DataCollection{Collection: mdb.NewCollection(db, "Data", cfg)}
}
// CRCItem stores a CRC32 checksum under a fixed ID.
type CRCItem struct {
	ID    uint64 // Always 1
	CRC32 uint32
}
// CRCCollection wraps the generic mdb collection for CRCItem values.
type CRCCollection struct {
	*mdb.Collection[CRCItem]
}
// NewCRCCollection registers the "CRC" collection on db using the
// default (nil) configuration.
func NewCRCCollection(db *mdb.Database) CRCCollection {
	return CRCCollection{
		Collection: mdb.NewCollection[CRCItem](db, "CRC", nil),
	}
}
// DataDB bundles an mdb.Database with its Data and CRC collections.
type DataDB struct {
	*mdb.Database
	Datas DataCollection
	CRCs  CRCCollection
}
// OpenDataDB creates a primary database rooted at rootDir, registers
// the Data and CRC collections, and opens it.
func OpenDataDB(rootDir string) (DataDB, error) {
	db := mdb.New(mdb.Config{
		RootDir: rootDir,
		Primary: true,
	})
	testdb := DataDB{
		Database: db,
		Datas:    NewDataCollection(db),
		CRCs:     NewCRCCollection(db),
	}
	return testdb, db.Open()
}
// ModifyFor runs one writer goroutine per CPU, each calling modifyOnce
// in a loop until dt elapses, then logs the total modification count.
func (db DataDB) ModifyFor(dt time.Duration) {
	wg := sync.WaitGroup{}
	var count int64
	for i := 0; i < runtime.NumCPU(); i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			t0 := time.Now()
			for time.Since(t0) < dt {
				atomic.AddInt64(&count, 1)
				db.modifyOnce()
			}
		}()
	}
	// wg.Wait establishes happens-before, so the plain read of count is safe.
	wg.Wait()
	log.Printf("Modified: %d", count)
}
// modifyOnce performs one transaction: it rewrites data items 1-9 with
// fresh random payloads and upserts the matching CRC. With probability
// 0.1 it instead stores a deliberately wrong CRC (1) and returns an
// error, which the Update machinery must roll back entirely.
func (db DataDB) modifyOnce() {
	isErr := mrand.Float64() < 0.1
	err := db.Update(func(tx *mdb.Snapshot) error {
		h := crc32.NewIEEE()
		// IDs 1 through 9 inclusive (loop bound is exclusive of 10).
		for dataID := uint64(1); dataID < 10; dataID++ {
			d := DataItem{
				ID:   dataID,
				Data: make([]byte, 256),
			}
			rand.Read(d.Data)
			h.Write(d.Data)
			if err := db.Datas.Upsert(tx, &d); err != nil {
				return err
			}
		}
		crc := CRCItem{
			ID: 1,
		}
		if !isErr {
			crc.CRC32 = h.Sum32()
			return db.CRCs.Upsert(tx, &crc)
		}
		// Error path: write a bogus CRC then fail the transaction so it
		// must be rolled back.
		crc.CRC32 = 1
		if err := db.CRCs.Upsert(tx, &crc); err != nil {
			return err
		}
		return errors.New("ERROR")
	})
	// The injected error must round-trip: error out iff we asked for one.
	if isErr != (err != nil) {
		panic(err)
	}
}
// ComputeCRC recomputes the CRC32 over the payloads of data items 1-9
// as visible in snapshot tx; items not yet written are skipped.
func (db DataDB) ComputeCRC(tx *mdb.Snapshot) uint32 {
	h := crc32.NewIEEE()
	for dataID := uint64(1); dataID < 10; dataID++ {
		d, ok := db.Datas.ByID.Get(tx, &DataItem{ID: dataID})
		if !ok {
			continue
		}
		h.Write(d.Data)
	}
	return h.Sum32()
}
// ReadCRC returns the stored CRC32 from snapshot tx, or 0 when the CRC
// record doesn't exist yet.
func (db DataDB) ReadCRC(tx *mdb.Snapshot) uint32 {
	r, ok := db.CRCs.ByID.Get(tx, &CRCItem{ID: 1})
	if !ok {
		return 0
	}
	return r.CRC32
}
// main opens (or creates) a DataDB at the directory named by the first
// command-line argument and hammers it with writers for one minute.
func main() {
	// Guard the argv access: without this, a missing argument panics
	// with an index-out-of-range instead of a usage message.
	if len(os.Args) < 2 {
		log.Fatalf("usage: %s <root-dir>", os.Args[0])
	}
	db, err := OpenDataDB(os.Args[1])
	if err != nil {
		log.Fatal(err)
	}
	db.ModifyFor(time.Minute)
}

92
mdb/txaggregator.go Normal file
View File

@@ -0,0 +1,92 @@
package mdb
/*
type txAggregator struct {
Stop chan struct{}
Done *sync.WaitGroup
ModChan chan txMod
W *cswal.Writer
Index *pagefile.Index
Snapshot *atomic.Pointer[Snapshot]
}
func (p txAggregator) Run() {
defer p.Done.Done()
defer p.W.Close()
var (
tx *Snapshot
mod txMod
rec cswal.Record
err error
toNotify = make([]chan error, 0, 1024)
)
READ_FIRST:
toNotify = toNotify[:0]
select {
case mod = <-p.ModChan:
goto BEGIN
case <-p.Stop:
goto END
}
BEGIN:
tx = p.Snapshot.Load().begin()
goto APPLY_MOD
CLONE:
tx = tx.clone()
goto APPLY_MOD
APPLY_MOD:
if err = mod.Update(tx); err != nil {
mod.Resp <- err
goto ROLLBACK
}
toNotify = append(toNotify, mod.Resp)
goto NEXT
ROLLBACK:
if len(toNotify) == 0 {
goto READ_FIRST
}
tx = tx.rollback()
goto NEXT
NEXT:
select {
case mod = <-p.ModChan:
goto CLONE
default:
goto WRITE
}
WRITE:
rec, err = writeChangesToWAL(tx.changes, p.Index, p.W)
if err == nil {
tx.seqNum = rec.SeqNum
tx.updatedAt = rec.CreatedAt
tx.setReadOnly()
p.Snapshot.Store(tx)
}
for i := range toNotify {
toNotify[i] <- err
}
goto READ_FIRST
END:
}
*/

8
mdb/types.go Normal file
View File

@@ -0,0 +1,8 @@
package mdb
type collection interface {
Name() string
insertItem(tx *Snapshot, itemID uint64, data []byte) error
upsertItem(tx *Snapshot, itemID uint64, data []byte) error
deleteItem(tx *Snapshot, itemID uint64) error
}

35
mdb/walfollower.go Normal file
View File

@@ -0,0 +1,35 @@
package mdb
/*
type walFollower struct {
Stop chan struct{}
Done *sync.WaitGroup
W *cswal.Writer
Client *Client
}
func (f *walFollower) Run() {
go func() {
<-f.Stop
f.Client.Close()
}()
defer f.Done.Done()
for {
f.runOnce()
select {
case <-f.Stop:
return
default:
time.Sleep(time.Second)
}
}
}
func (f *walFollower) runOnce() {
if err := f.Client.StreamWAL(f.W); err != nil {
log.Printf("[WAL-FOLLOWER] Recv failed: %s", err)
}
}
*/