Cleanup + testing

master
jdl 2022-07-29 21:36:42 +02:00
parent deefaff7ee
commit caf7ecf366
30 changed files with 526 additions and 1322 deletions

View File

@ -9,6 +9,7 @@ import (
type BTreeIndex[T any] struct { type BTreeIndex[T any] struct {
c *Collection[T] c *Collection[T]
name string
modLock sync.Mutex modLock sync.Mutex
bt atomic.Value // *btree.BTreeG[*T] bt atomic.Value // *btree.BTreeG[*T]
getID func(*T) uint64 getID func(*T) uint64
@ -18,12 +19,14 @@ type BTreeIndex[T any] struct {
func NewBTreeIndex[T any]( func NewBTreeIndex[T any](
c *Collection[T], c *Collection[T],
name string,
less func(*T, *T) bool, less func(*T, *T) bool,
include func(*T) bool, include func(*T) bool,
) *BTreeIndex[T] { ) *BTreeIndex[T] {
t := &BTreeIndex[T]{ t := &BTreeIndex[T]{
c: c, c: c,
name: name,
getID: c.getID, getID: c.getID,
less: less, less: less,
include: include, include: include,

View File

@ -278,6 +278,7 @@ func TestBTreeIndex_load_ErrDuplicate(t *testing.T) {
testWithDB(t, "", func(t *testing.T, db *DB) { testWithDB(t, "", func(t *testing.T, db *DB) {
idx := NewBTreeIndex( idx := NewBTreeIndex(
db.Users.c, db.Users.c,
"extid",
func(lhs, rhs *User) bool { return lhs.ExtID < rhs.ExtID }, func(lhs, rhs *User) bool { return lhs.ExtID < rhs.ExtID },
nil) nil)

View File

@ -61,10 +61,8 @@ func (db *Database) Start() {
wg.Wait() wg.Wait()
} }
func (db *Database) WALStatus() (ws WALStatus) { func (db *Database) MaxSeqNum() uint64 {
ws.MaxSeqNumKV = db.kv.WALMaxSeqNum() return db.kv.MaxSeqNum()
ws.MaxSeqNumWAL = db.kv.MaxSeqNum()
return
} }
func (db *Database) Close() { func (db *Database) Close() {

View File

@ -1,13 +1 @@
package mdb package mdb
import "time"
func (db *Database) WaitForWAL() {
for {
status := db.WALStatus()
if status.MaxSeqNumWAL == status.MaxSeqNumKV {
return
}
time.Sleep(100 * time.Millisecond)
}
}

View File

@ -17,8 +17,6 @@ func TestItemMap(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
db.WaitForWAL()
if err := db.Users.c.items.EqualsKV(); err != nil { if err := db.Users.c.items.EqualsKV(); err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -3,32 +3,93 @@ package kvstore
const sqlSchema = ` const sqlSchema = `
BEGIN IMMEDIATE; BEGIN IMMEDIATE;
CREATE TABLE IF NOT EXISTS data(
SeqNum INTEGER NOT NULL PRIMARY KEY,
Deleted INTEGER NOT NULL DEFAULT 0,
Data BLOB NOT NULL
) WITHOUT ROWID;
CREATE INDEX IF NOT EXISTS data_deleted_index ON data(Deleted,SeqNum);
CREATE TABLE IF NOT EXISTS log(
SeqNum INTEGER NOT NULL PRIMARY KEY,
CreatedAt INTEGER NOT NULL,
Collection TEXT NOT NULL,
ID INTEGER NOT NULL,
Store INTEGER NOT NULL
) WITHOUT ROWID;
CREATE INDEX IF NOT EXISTS log_created_at_index ON log(CreatedAt);
CREATE TABLE IF NOT EXISTS kv( CREATE TABLE IF NOT EXISTS kv(
Collection TEXT NOT NULL, Collection TEXT NOT NULL,
ID INTEGER NOT NULL, ID INTEGER NOT NULL,
Data BLOB, SeqNum INTEGER NOT NULL,
PRIMARY KEY (Collection, ID) PRIMARY KEY (Collection, ID)
) WITHOUT ROWID; ) WITHOUT ROWID;
CREATE TABLE IF NOT EXISTS maxSeqNum( CREATE VIEW IF NOT EXISTS kvdata AS
ID INTEGER NOT NULL PRIMARY KEY, SELECT
Value INTEGER NOT NULL kv.Collection,
) WITHOUT ROWID; kv.ID,
data.Data
FROM kv
JOIN data ON kv.SeqNum=data.SeqNum;
INSERT INTO maxSeqNum(ID, Value) VALUES (1, 0) ON CONFLICT DO NOTHING; CREATE VIEW IF NOT EXISTS logdata AS
SELECT
log.SeqNum,
log.Collection,
log.ID,
log.Store,
data.data
FROM log
LEFT JOIN data on log.SeqNum=data.SeqNum;
COMMIT;` COMMIT;`
const sqlKVUpsert = ` // ----------------------------------------------------------------------------
INSERT INTO kv
(Collection,ID,Data) VALUES (?,?,?) const sqlInsertData = `INSERT INTO data(SeqNum,Data) VALUES(?,?)`
ON CONFLICT(Collection,ID) DO UPDATE SET
Data=excluded.Data const sqlInsertKV = `INSERT INTO kv(Collection,ID,SeqNum) VALUES (?,?,?)
ON CONFLICT(Collection,ID) DO UPDATE SET SeqNum=excluded.SeqNum
WHERE ID=excluded.ID`
// ----------------------------------------------------------------------------
const sqlDeleteKV = `DELETE FROM kv WHERE Collection=? AND ID=?`
const sqlDeleteData = `UPDATE data SET Deleted=1
WHERE SeqNum=(
SELECT SeqNum FROM kv WHERE Collection=? AND ID=?)`
// ----------------------------------------------------------------------------
const sqlInsertLog = `INSERT INTO log(SeqNum,CreatedAt,Collection,ID,Store)
VALUES(?,?,?,?,?)`
// ----------------------------------------------------------------------------
const sqlKVIterate = `SELECT ID,Data FROM kvdata WHERE Collection=?`
const sqlLogIterate = `
SELECT SeqNum,Collection,ID,Store,Data
FROM logdata
WHERE SeqNum > ?
ORDER BY SeqNum ASC`
const sqlMaxSeqNumGet = `SELECT COALESCE(MAX(SeqNum),0) FROM log`
const sqlCleanQuery = `
DELETE FROM
log
WHERE WHERE
ID=excluded.ID` CreatedAt < ? AND
SeqNum < (SELECT MAX(SeqNum) FROM log;
const sqlKVDelete = `DELETE FROM kv WHERE Collection=? AND ID=?` DELETE FROM
const sqlKVIterate = `SELECT ID,Data FROM kv WHERE Collection=?` data
WHERE
const sqlMaxSeqNumGet = `SELECT Value FROM maxSeqNum WHERE ID=1` Deleted != 0 AND
const sqlMaxSeqNumSet = `UPDATE maxSeqNum SET Value=? WHERE ID=1` SeqNum < (SELECT MIN(SeqNum) FROM log;`

9
kvstore/globals.go Normal file
View File

@ -0,0 +1,9 @@
package kvstore
import "time"
var (
connTimeout = 16 * time.Second
heartbeatInterval = 4 * time.Second
pollInterval = 500 * time.Millisecond
)

View File

@ -1,12 +1,10 @@
package wal package kvstore
import ( import "encoding/binary"
"encoding/binary"
)
const recHeaderSize = 22 const recHeaderSize = 22
func encodeRecordHeader(rec Record, buf []byte) { func encodeRecordHeader(rec record, buf []byte) {
// SeqNum (8) // SeqNum (8)
// ID (8) // ID (8)
// DataLen (4) // DataLen (4)
@ -30,7 +28,7 @@ func encodeRecordHeader(rec Record, buf []byte) {
buf[0] = byte(len(rec.Collection)) buf[0] = byte(len(rec.Collection))
} }
func decodeRecHeader(header []byte) (rec Record, colLen, dataLen int) { func decodeRecHeader(header []byte) (rec record, colLen, dataLen int) {
buf := header buf := header
rec.SeqNum = binary.LittleEndian.Uint64(buf[:8]) rec.SeqNum = binary.LittleEndian.Uint64(buf[:8])

View File

@ -75,7 +75,7 @@ func TestShipping(t *testing.T) {
run("simple", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) { run("simple", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) {
M := 10 M := 10
N := 100 N := 1000
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
@ -85,7 +85,7 @@ func TestShipping(t *testing.T) {
go func() { go func() {
defer wg.Done() defer wg.Done()
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
time.Sleep(10 * time.Millisecond) time.Sleep(time.Millisecond)
prim.randAction() prim.randAction()
} }
}() }()
@ -113,9 +113,15 @@ func TestShipping(t *testing.T) {
nw.CloseClient() nw.CloseClient()
wg.Wait() wg.Wait()
prim.equalsKV("a", sec) if err := prim.equalsKV("a", sec); err != nil {
prim.equalsKV("b", sec) t.Fatal(err)
prim.equalsKV("c", sec) }
if err := prim.equalsKV("b", sec); err != nil {
t.Fatal(err)
}
if err := prim.equalsKV("c", sec); err != nil {
t.Fatal(err)
}
}) })
run("net failures", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) { run("net failures", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) {
@ -172,12 +178,18 @@ func TestShipping(t *testing.T) {
} }
}() }()
sec.waitForSeqNum(uint64(M * N)) sec.waitForSeqNum(prim.MaxSeqNum())
wg.Wait() wg.Wait()
prim.equalsKV("a", sec) if err := prim.equalsKV("a", sec); err != nil {
prim.equalsKV("b", sec) t.Fatal(err)
prim.equalsKV("c", sec) }
if err := prim.equalsKV("b", sec); err != nil {
t.Fatal(err)
}
if err := prim.equalsKV("c", sec); err != nil {
t.Fatal(err)
}
}) })
} }

View File

@ -2,71 +2,69 @@ package kvstore
import ( import (
"database/sql" "database/sql"
"fmt"
"net"
"sync" "sync"
"time" "time"
"git.crumpington.com/private/mdb/kvstore/wal"
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
type KV struct { type KV struct {
primary bool primary bool
lockPath string dbPath string
dataPath string
walPath string
w *wal.Writer
f *wal.Follower
db *sql.DB db *sql.DB
stop chan struct{} maxSeqNumStmt *sql.Stmt
done sync.WaitGroup logIterateStmt *sql.Stmt
w *writer
onStore func(string, uint64, []byte) onStore func(string, uint64, []byte)
onDelete func(string, uint64) onDelete func(string, uint64)
closeLock sync.Mutex closeLock sync.Mutex
shippingLock sync.Mutex recvLock sync.Mutex
} }
func (kv *KV) init() { func newKV(
dir string,
primary bool,
onStore func(string, uint64, []byte),
onDelete func(string, uint64),
) *KV {
kv := &KV{
dbPath: dbPath(dir),
primary: primary,
onStore: onStore,
onDelete: onDelete,
}
opts := `?_journal=WAL` opts := `?_journal=WAL`
db, err := sql.Open("sqlite3", kv.dataPath+opts) db, err := sql.Open("sqlite3", kv.dbPath+opts)
must(err) must(err)
_, err = db.Exec(sqlSchema) _, err = db.Exec(sqlSchema)
must(err) must(err)
if kv.primary { kv.maxSeqNumStmt, err = db.Prepare(sqlMaxSeqNumGet)
kv.w = wal.NewWriterPrimary(kv.walPath) must(err)
} else { kv.logIterateStmt, err = db.Prepare(sqlLogIterate)
kv.w = wal.NewWriterSecondary(kv.walPath) must(err)
}
_, err = db.Exec(sqlSchema)
must(err)
kv.f = wal.NewFollower(kv.walPath)
kv.db = db kv.db = db
kv.stop = make(chan struct{})
kv.commit()
}
func (kv *KV) start() { if kv.primary {
// Spawn follower in background to write data from WAL to data. kv.w = newWriter(kv.db)
kv.done.Add(1) kv.w.Start(kv.MaxSeqNum())
go kv.background() }
return kv
} }
func NewPrimary(dir string) *KV { func NewPrimary(dir string) *KV {
kv := &KV{ return newKV(dir, true, nil, nil)
primary: true,
dataPath: dataPath(dir),
walPath: walPath(dir),
}
kv.init()
kv.start()
return kv
} }
func NewSecondary( func NewSecondary(
@ -74,20 +72,7 @@ func NewSecondary(
onStore func(collection string, id uint64, data []byte), onStore func(collection string, id uint64, data []byte),
onDelete func(collection string, id uint64), onDelete func(collection string, id uint64),
) *KV { ) *KV {
kv := &KV{ return newKV(dir, false, onStore, onDelete)
primary: false,
dataPath: dataPath(dir),
walPath: walPath(dir),
}
kv.init()
kv.onStore = onStore
kv.onDelete = onDelete
kv.start()
return kv
} }
func (kv *KV) Primary() bool { func (kv *KV) Primary() bool {
@ -95,14 +80,10 @@ func (kv *KV) Primary() bool {
} }
func (kv *KV) MaxSeqNum() (seqNum uint64) { func (kv *KV) MaxSeqNum() (seqNum uint64) {
kv.db.QueryRow(sqlMaxSeqNumGet).Scan(&seqNum) must(kv.maxSeqNumStmt.QueryRow().Scan(&seqNum))
return seqNum return seqNum
} }
func (kv *KV) WALMaxSeqNum() uint64 {
return kv.w.MaxSeqNum()
}
func (kv *KV) Iterate(collection string, each func(id uint64, data []byte)) { func (kv *KV) Iterate(collection string, each func(id uint64, data []byte)) {
rows, err := kv.db.Query(sqlKVIterate, collection) rows, err := kv.db.Query(sqlKVIterate, collection)
must(err) must(err)
@ -123,139 +104,25 @@ func (kv *KV) Close() {
kv.closeLock.Lock() kv.closeLock.Lock()
defer kv.closeLock.Unlock() defer kv.closeLock.Unlock()
if kv.w == nil { if kv.w != nil {
return kv.w.Stop()
} }
kv.stop <- struct{}{} if kv.db != nil {
kv.done.Wait() kv.db.Close()
kv.db = nil
kv.w.Close() }
kv.f.Close()
kv.db.Close()
kv.w = nil
kv.f = nil
kv.db = nil
} }
func (kv *KV) Store(collection string, id uint64, data []byte) { func (kv *KV) Store(collection string, id uint64, data []byte) {
if !kv.primary {
panic("Store called on secondary.")
}
kv.w.Store(collection, id, data) kv.w.Store(collection, id, data)
} }
func (kv *KV) Delete(collection string, id uint64) { func (kv *KV) Delete(collection string, id uint64) {
if !kv.primary {
panic("Delete called on secondary.")
}
kv.w.Delete(collection, id) kv.w.Delete(collection, id)
} }
func (kv *KV) SyncSend(conn net.Conn) { func (kv *KV) CleanBefore(seconds int64) {
if !kv.primary { _, err := kv.db.Exec(sqlCleanQuery, time.Now().Unix()-seconds)
panic("SyncSend called on secondary.") must(err)
}
kv.f.SendWAL(conn)
}
func (kv *KV) SyncRecv(conn net.Conn) {
if kv.primary {
panic("SyncRecv called on primary.")
}
if !kv.shippingLock.TryLock() {
return
}
defer kv.shippingLock.Unlock()
kv.w.RecvWAL(conn)
}
func (kv *KV) background() {
defer kv.done.Done()
commitTicker := time.NewTicker(commitInterval)
defer commitTicker.Stop()
cleanTicker := time.NewTicker(cleanInterval)
defer cleanTicker.Stop()
for {
select {
case <-commitTicker.C:
kv.commit()
case <-cleanTicker.C:
kv.w.DeleteBefore(cleanBeforeSecs)
case <-kv.stop:
return
}
}
}
func (kv *KV) commit() {
maxSeqNum := kv.MaxSeqNum()
if maxSeqNum == kv.f.MaxSeqNum() {
return
}
tx, err := kv.db.Begin()
must(err)
upsert, err := tx.Prepare(sqlKVUpsert)
must(err)
delete, err := tx.Prepare(sqlKVDelete)
must(err)
var (
doUpsert func(wal.Record) error
doDelete func(wal.Record) error
)
if kv.primary {
doUpsert = func(rec wal.Record) (err error) {
_, err = upsert.Exec(rec.Collection, rec.ID, rec.Data)
return err
}
doDelete = func(rec wal.Record) (err error) {
_, err = delete.Exec(rec.Collection, rec.ID)
return err
}
} else {
doUpsert = func(rec wal.Record) (err error) {
kv.onStore(rec.Collection, rec.ID, rec.Data)
_, err = upsert.Exec(rec.Collection, rec.ID, rec.Data)
return err
}
doDelete = func(rec wal.Record) (err error) {
kv.onDelete(rec.Collection, rec.ID)
_, err = delete.Exec(rec.Collection, rec.ID)
return err
}
}
err = kv.f.Replay(maxSeqNum, func(rec wal.Record) error {
if rec.SeqNum != maxSeqNum+1 {
return fmt.Errorf("expected sequence number %d but got %d", maxSeqNum+1, rec.SeqNum)
}
if rec.Store {
err = doUpsert(rec)
} else {
err = doDelete(rec)
}
maxSeqNum = rec.SeqNum
return err
})
must(err)
_, err = tx.Exec(sqlMaxSeqNumSet, maxSeqNum)
must(err)
must(tx.Commit())
} }

View File

@ -2,8 +2,10 @@ package kvstore
import ( import (
"fmt" "fmt"
"log"
"math/rand" "math/rand"
"os" "os"
"reflect"
"testing" "testing"
"time" "time"
) )
@ -13,8 +15,8 @@ import (
func (kv *KV) waitForSeqNum(x uint64) { func (kv *KV) waitForSeqNum(x uint64) {
for { for {
seqNum := kv.MaxSeqNum() seqNum := kv.MaxSeqNum()
//log.Printf("%d/%d", seqNum, x) log.Printf("%d/%d", seqNum, x)
if seqNum == x { if seqNum >= x {
return return
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)
@ -48,6 +50,24 @@ func (kv *KV) equals(collection string, expected map[uint64]string) error {
} }
func (kv *KV) equalsKV(collection string, rhs *KV) error { func (kv *KV) equalsKV(collection string, rhs *KV) error {
l1 := []record{}
kv.replay(0, func(rec record) error {
l1 = append(l1, rec)
return nil
})
idx := -1
err := rhs.replay(0, func(rec record) error {
idx++
if !reflect.DeepEqual(rec, l1[idx]) {
return fmt.Errorf("Records not equal: %d %v %v", idx, rec, l1[idx])
}
return nil
})
if err != nil {
return err
}
return kv.equals(collection, rhs.dump(collection)) return kv.equals(collection, rhs.dump(collection))
} }

93
kvstore/sync-recv.go Normal file
View File

@ -0,0 +1,93 @@
package kvstore
import (
"encoding/binary"
"log"
"net"
"time"
)
func (kv *KV) SyncRecv(conn net.Conn) {
defer conn.Close()
if kv.primary {
panic("SyncRecv called on primary.")
}
if !kv.recvLock.TryLock() {
return
}
defer kv.recvLock.Unlock()
// It's important that we stop when this routine exits so that
// all queued writes are committed to the database before syncing
// has a chance to restart.
//kv.startWriteLoop()
//defer kv.stopWriteLoop()
w := newWriter(kv.db)
w.Start(kv.MaxSeqNum())
defer w.Stop()
headerBuf := make([]byte, recHeaderSize)
buf := make([]byte, 8)
afterSeqNum := kv.MaxSeqNum()
expectedSeqNum := afterSeqNum + 1
// Send fromID to the conn.
conn.SetWriteDeadline(time.Now().Add(connTimeout))
binary.LittleEndian.PutUint64(buf, afterSeqNum)
if _, err := conn.Write(buf); err != nil {
log.Printf("RecvWAL failed to send after sequence number: %v", err)
return
}
conn.SetWriteDeadline(time.Time{})
for {
conn.SetReadDeadline(time.Now().Add(connTimeout))
if _, err := conn.Read(headerBuf); err != nil {
log.Printf("RecvWAL failed to read header: %v", err)
return
}
rec, colLen, dataLen := decodeRecHeader(headerBuf)
// Heartbeat.
if rec.SeqNum == 0 {
continue
}
if rec.SeqNum != expectedSeqNum {
log.Printf("Expected sequence number %d but got %d.",
expectedSeqNum, rec.SeqNum)
return
}
expectedSeqNum++
if cap(buf) < colLen {
buf = make([]byte, colLen)
}
buf = buf[:colLen]
if _, err := conn.Read(buf); err != nil {
log.Printf("RecvWAL failed to read collection name: %v", err)
return
}
rec.Collection = string(buf)
if rec.Store {
rec.Data = make([]byte, dataLen)
if _, err := conn.Read(rec.Data); err != nil {
log.Printf("RecvWAL failed to read data: %v", err)
return
}
w.StoreAsync(rec.Collection, rec.ID, rec.Data)
kv.onStore(rec.Collection, rec.ID, rec.Data)
} else {
w.DeleteAsync(rec.Collection, rec.ID)
kv.onDelete(rec.Collection, rec.ID)
}
}
}

View File

@ -1,93 +1,38 @@
package wal package kvstore
import ( import (
"database/sql"
"encoding/binary" "encoding/binary"
"log" "log"
"net" "net"
"time" "time"
) )
type Record struct { func (kv *KV) SyncSend(conn net.Conn) {
SeqNum uint64
Collection string
ID uint64
Store bool
Data []byte
}
type Follower struct {
db *sql.DB
selectStmt *sql.Stmt
}
func NewFollower(walPath string) *Follower {
db := initWAL(walPath)
selectStmt, err := db.Prepare(sqlWALFollowQuery)
must(err)
return &Follower{
db: db,
selectStmt: selectStmt,
}
}
func (f *Follower) Close() {
f.db.Close()
}
func (f *Follower) MaxSeqNum() (n uint64) {
must(f.db.QueryRow(sqlWALMaxSeqNum).Scan(&n))
return
}
func (f *Follower) Replay(afterSeqNum uint64, each func(rec Record) error) error {
rec := Record{}
rows, err := f.selectStmt.Query(afterSeqNum)
must(err)
defer rows.Close()
for rows.Next() {
must(rows.Scan(
&rec.SeqNum,
&rec.Collection,
&rec.ID,
&rec.Store,
&rec.Data))
if err = each(rec); err != nil {
return err
}
}
return nil
}
func (f *Follower) SendWAL(conn net.Conn) {
defer conn.Close() defer conn.Close()
var ( var (
buf = make([]byte, 8) seqNumBuf = make([]byte, 8)
headerBuf = make([]byte, recHeaderSize) headerBuf = make([]byte, recHeaderSize)
empty = make([]byte, recHeaderSize) empty = make([]byte, recHeaderSize)
tStart time.Time tStart time.Time
err error err error
) )
// Read the fromID from the conn. // Read afterSeqNum from the conn.
conn.SetReadDeadline(time.Now().Add(16 * time.Second)) conn.SetReadDeadline(time.Now().Add(connTimeout))
if _, err := conn.Read(buf[:8]); err != nil { if _, err := conn.Read(seqNumBuf[:8]); err != nil {
log.Printf("SendWAL failed to read from ID: %v", err) log.Printf("SyncSend failed to read afterSeqNum: %v", err)
return return
} }
afterSeqNum := binary.LittleEndian.Uint64(buf[:8]) afterSeqNum := binary.LittleEndian.Uint64(seqNumBuf[:8])
POLL: POLL:
conn.SetWriteDeadline(time.Now().Add(connTimeout)) conn.SetWriteDeadline(time.Now().Add(connTimeout))
tStart = time.Now() tStart = time.Now()
for time.Since(tStart) < heartbeatInterval { for time.Since(tStart) < heartbeatInterval {
if f.MaxSeqNum() > afterSeqNum { if kv.MaxSeqNum() > afterSeqNum {
goto REPLAY goto REPLAY
} }
time.Sleep(pollInterval) time.Sleep(pollInterval)
@ -106,7 +51,7 @@ HEARTBEAT:
REPLAY: REPLAY:
err = f.Replay(afterSeqNum, func(rec Record) error { err = kv.replay(afterSeqNum, func(rec record) error {
conn.SetWriteDeadline(time.Now().Add(connTimeout)) conn.SetWriteDeadline(time.Now().Add(connTimeout))
afterSeqNum = rec.SeqNum afterSeqNum = rec.SeqNum
@ -139,3 +84,23 @@ REPLAY:
goto POLL goto POLL
} }
func (kv *KV) replay(afterSeqNum uint64, each func(rec record) error) error {
rec := record{}
rows, err := kv.logIterateStmt.Query(afterSeqNum)
must(err)
defer rows.Close()
for rows.Next() {
must(rows.Scan(
&rec.SeqNum,
&rec.Collection,
&rec.ID,
&rec.Store,
&rec.Data))
if err = each(rec); err != nil {
return err
}
}
return nil
}

19
kvstore/types.go Normal file
View File

@ -0,0 +1,19 @@
package kvstore
import "sync"
type modJob struct {
Collection string
ID uint64
Store bool
Data []byte
Ready *sync.WaitGroup
}
type record struct {
SeqNum uint64
Collection string
ID uint64
Store bool
Data []byte
}

View File

@ -2,13 +2,6 @@ package kvstore
import ( import (
"path/filepath" "path/filepath"
"time"
)
const (
commitInterval = 250 * time.Millisecond // How often to commit from WAL.
cleanInterval = time.Minute // How often to clean WAL.
cleanBeforeSecs = 86400 * 7 // Clean WAL entries older than.
) )
func must(err error) { func must(err error) {
@ -17,10 +10,6 @@ func must(err error) {
} }
} }
func dataPath(dir string) string { func dbPath(dir string) string {
return filepath.Join(dir, "data") return filepath.Join(dir, "db")
}
func walPath(dir string) string {
return filepath.Join(dir, "wal")
} }

View File

@ -1,39 +0,0 @@
package wal
const sqlSchema = `
BEGIN IMMEDIATE;
CREATE TABLE IF NOT EXISTS wal(
SeqNum INTEGER NOT NULL PRIMARY KEY,
CreatedAt INTEGER NOT NULL,
Collection TEXT NOT NULL,
ID INTEGER NOT NULL,
Store INTEGER NOT NULL,
Data BLOB
) WITHOUT ROWID;
CREATE INDEX IF NOT EXISTS wal_created_at_index ON wal(CreatedAt);
COMMIT;
`
const sqlWALMaxSeqNum = `
SELECT COALESCE(MAX(SeqNum), 0) FROM wal;
`
const sqlWALInsert = `
INSERT INTO wal(
SeqNum,CreatedAt,Collection,ID,Store,Data
) VALUES (?,?,?,?,?,?)`
const sqlWALFollowQuery = `
SELECT
SeqNum,Collection,ID,Store,Data
FROM
wal
WHERE
SeqNum > ?
ORDER BY SeqNum ASC`
const sqlWALDeleteQuery = `
DELETE FROM wal WHERE CreatedAt < ? AND SeqNum < (SELECT MAX(SeqNum) FROM wal)`

View File

@ -1,123 +0,0 @@
package wal
import (
"errors"
"os"
"testing"
"time"
)
// ----------------------------------------------------------------------------
func (f *Follower) getReplay(afterSeqNum uint64) (l []Record) {
f.Replay(afterSeqNum, func(rec Record) error {
l = append(l, rec)
return nil
})
return l
}
func (f *Follower) waitForSeqNum(n uint64) {
for {
maxSeqNum := f.MaxSeqNum()
//log.Printf("%d/%d", maxSeqNum, n)
if maxSeqNum == n {
return
}
time.Sleep(100 * time.Millisecond)
}
}
// ----------------------------------------------------------------------------
func TestFollower(t *testing.T) {
run := func(name string, inner func(t *testing.T, walPath string, w *Writer, f *Follower)) {
t.Run(name, func(t *testing.T) {
walPath := randPath() + ".wal"
defer os.RemoveAll(walPath)
w := newWriter(walPath, true)
defer w.Close()
f := NewFollower(walPath)
defer f.Close()
inner(t, walPath, w, f)
})
}
run("simple", func(t *testing.T, walPath string, w *Writer, f *Follower) {
w.Store("a", 1, _b("Hello"))
w.Delete("b", 1)
w.Store("a", 2, _b("World"))
w.Store("a", 1, _b("Good bye"))
expected := []Record{
{SeqNum: 1, Collection: "a", ID: 1, Store: true, Data: _b("Hello")},
{SeqNum: 2, Collection: "b", ID: 1},
{SeqNum: 3, Collection: "a", ID: 2, Store: true, Data: _b("World")},
{SeqNum: 4, Collection: "a", ID: 1, Store: true, Data: _b("Good bye")},
}
recs := f.getReplay(0)
if err := recsEqual(recs, expected); err != nil {
t.Fatal(err)
}
for i := 1; i < 4; i++ {
recs = f.getReplay(uint64(i))
if err := recsEqual(recs, expected[i:]); err != nil {
t.Fatal(err)
}
}
})
run("write async", func(t *testing.T, walPath string, w *Writer, f *Follower) {
w.storeAsync("a", 1, _b("hello1"))
w.storeAsync("a", 2, _b("hello2"))
w.deleteAsync("a", 1)
w.storeAsync("a", 3, _b("hello3"))
w.storeAsync("b", 1, _b("b1"))
f.waitForSeqNum(5)
expected := []Record{
{SeqNum: 1, Collection: "a", ID: 1, Store: true, Data: _b("hello1")},
{SeqNum: 2, Collection: "a", ID: 2, Store: true, Data: _b("hello2")},
{SeqNum: 3, Collection: "a", ID: 1, Store: false},
{SeqNum: 4, Collection: "a", ID: 3, Store: true, Data: _b("hello3")},
{SeqNum: 5, Collection: "b", ID: 1, Store: true, Data: _b("b1")},
}
recs := f.getReplay(0)
if err := recsEqual(recs, expected); err != nil {
t.Fatal(err)
}
for i := 1; i < 4; i++ {
recs = f.getReplay(uint64(i))
if err := recsEqual(recs, expected[i:]); err != nil {
t.Fatal(err)
}
}
})
run("replay error", func(t *testing.T, walPath string, w *Writer, f *Follower) {
expectedErr := errors.New("My error")
w.Store("a", 1, _b("Hello"))
w.Delete("b", 1)
w.Store("a", 2, _b("World"))
w.Store("a", 1, _b("Good bye"))
err := f.Replay(0, func(rec Record) error {
if rec.Collection == "b" {
return expectedErr
}
return nil
})
if err != expectedErr {
t.Fatal(err)
}
})
}

View File

@ -1,9 +0,0 @@
package wal
import "time"
var (
connTimeout = 16 * time.Second // For sending / receiving WAL.
heartbeatInterval = 2 * time.Second // Used in Follower.SendLog
pollInterval = 500 * time.Millisecond // Used in Follower.SendLog
)

View File

@ -1,20 +0,0 @@
package wal
import (
"database/sql"
"sync"
)
var initLock sync.Mutex
func initWAL(walPath string) *sql.DB {
initLock.Lock()
defer initLock.Unlock()
db, err := sql.Open("sqlite3", walPath+"?_journal=WAL")
must(err)
_, err = db.Exec(sqlSchema)
must(err)
return db
}

View File

@ -1,208 +0,0 @@
package wal
import (
"fmt"
"math/rand"
"os"
"testing"
"time"
"git.crumpington.com/private/mdb/testconn"
)
func TestShipping(t *testing.T) {
run := func(name string, inner func(
t *testing.T,
wWALPath string,
fWALPath string,
w *Writer,
nw *testconn.Network,
)) {
t.Run(name, func(t *testing.T) {
wWALPath := randPath() + ".wal"
fWALPath := randPath() + ".wal"
w := NewWriterPrimary(wWALPath)
defer w.Close()
nw := testconn.NewNetwork()
defer nw.CloseClient()
defer nw.CloseServer()
defer os.RemoveAll(wWALPath)
defer os.RemoveAll(fWALPath)
inner(t, wWALPath, fWALPath, w, nw)
})
}
run("simple", func(t *testing.T, wWALPath, fWALPath string, w *Writer, nw *testconn.Network) {
// Write 100 entries in background.
go func() {
for i := 0; i < 100; i++ {
time.Sleep(10 * time.Millisecond)
w.Store("x", uint64(i+10), _b(fmt.Sprintf("data %d", i)))
}
}()
// Run a sender in the background.
go func() {
f := NewFollower(wWALPath)
conn := nw.Accept()
f.SendWAL(conn)
}()
// Run the follower.
go func() {
w := NewWriterSecondary(fWALPath)
conn := nw.Dial()
w.RecvWAL(conn)
}()
time.Sleep(time.Second)
// Wait for follower to get 100 entries, then close connection.
f := NewFollower(fWALPath)
defer f.Close()
f.waitForSeqNum(100)
if err := walsEqual(wWALPath, fWALPath); err != nil {
t.Fatal(err)
}
})
run("net failures", func(t *testing.T, wWALPath, fWALPath string, w *Writer, nw *testconn.Network) {
defer nw.CloseClient()
defer nw.CloseServer()
N := 4000
sleepTime := time.Millisecond
go func() {
for i := 0; i < N; i++ {
time.Sleep(sleepTime)
if rand.Float64() < 0.9 {
w.Store(randString(), randID(), _b(randString()))
} else {
w.Delete(randString(), randID())
}
}
}()
// Run a sender in the background.
go func() {
sender := NewFollower(wWALPath)
f := NewFollower(fWALPath)
defer f.Close()
for f.MaxSeqNum() < uint64(N) {
if conn := nw.Accept(); conn != nil {
sender.SendWAL(conn)
}
}
}()
// Run the follower in the background.
go func() {
f := NewFollower(fWALPath)
defer f.Close()
w := NewWriterSecondary(fWALPath)
for f.MaxSeqNum() < uint64(N) {
if conn := nw.Dial(); conn != nil {
w.RecvWAL(conn)
}
}
}()
// Disconnect the network randomly.
go func() {
f := NewFollower(fWALPath)
defer f.Close()
for f.MaxSeqNum() < uint64(N) {
time.Sleep(time.Duration(rand.Intn(10 * int(sleepTime))))
if rand.Float64() < 0.5 {
nw.CloseClient()
} else {
nw.CloseServer()
}
}
}()
f := NewFollower(fWALPath)
defer f.Close()
// Wait for follower to get 100 entries, then close connection.
f.waitForSeqNum(uint64(N))
if err := walsEqual(wWALPath, fWALPath); err != nil {
t.Fatal(err)
}
})
run("secondary too far behind", func(t *testing.T, wWALPath, fWALPath string, w *Writer, nw *testconn.Network) {
// Write some entries to the primary.
// MaxSeqNum will be 10.
for i := 0; i < 10; i++ {
w.Store(randString(), randID(), _b(randString()))
}
// Delete everything.
w.DeleteBefore(-1)
// Run a sender in the background.
go func() {
f := NewFollower(wWALPath)
defer f.Close()
conn := nw.Accept()
f.SendWAL(conn)
}()
// Run the follower.
go func() {
w := NewWriterSecondary(fWALPath)
defer w.Close()
conn := nw.Dial()
w.RecvWAL(conn)
}()
time.Sleep(time.Second)
f := NewFollower(fWALPath)
defer f.Close()
if f.MaxSeqNum() != 0 {
t.Fatal(f.MaxSeqNum())
}
})
}
func TestShippingEncoding(t *testing.T) {
recs := []Record{
{SeqNum: 10, Collection: "x", ID: 44, Store: true, Data: _b("Hello")},
{SeqNum: 24, Collection: "abc", ID: 3, Store: true, Data: _b("x")},
{SeqNum: 81, Collection: "qrs", ID: 102, Store: false},
}
buf := make([]byte, recHeaderSize)
for _, rec := range recs {
encodeRecordHeader(rec, buf)
out, colLen, dataLen := decodeRecHeader(buf)
if out.SeqNum != rec.SeqNum {
t.Fatal(out, rec)
}
if out.ID != rec.ID {
t.Fatal(out, rec)
}
if out.Store != rec.Store {
t.Fatal(out, rec)
}
if colLen != len(rec.Collection) {
t.Fatal(out, rec)
}
if dataLen != len(rec.Data) {
t.Fatal(out, rec)
}
}
}

View File

@ -1,7 +0,0 @@
package wal
func must(err error) {
if err != nil {
panic(err)
}
}

View File

@ -1,66 +0,0 @@
package wal
import (
"crypto/rand"
"encoding/hex"
"fmt"
mrand "math/rand"
"os"
"path/filepath"
"reflect"
)
func _b(in string) []byte {
return []byte(in)
}
func randString() string {
buf := make([]byte, 1+mrand.Intn(20))
rand.Read(buf)
return hex.EncodeToString(buf)
}
func randID() uint64 {
return uint64(mrand.Uint32())
}
func randPath() string {
buf := make([]byte, 8)
rand.Read(buf)
return filepath.Join(os.TempDir(), hex.EncodeToString(buf))
}
func readWAL(walPath string) (l []Record) {
f := NewFollower(walPath)
defer f.Close()
f.Replay(0, func(rec Record) error {
l = append(l, rec)
return nil
})
return l
}
func walEqual(walPath string, expected []Record) error {
recs := readWAL(walPath)
return recsEqual(recs, expected)
}
func recsEqual(recs, expected []Record) error {
if len(recs) != len(expected) {
return fmt.Errorf("Expected %d records but found %d",
len(expected), len(recs))
}
for i, rec := range recs {
exp := expected[i]
if !reflect.DeepEqual(rec, exp) {
return fmt.Errorf("Mismatched records: %v != %v", rec, exp)
}
}
return nil
}
func walsEqual(path1, path2 string) error {
return recsEqual(readWAL(path1), readWAL(path2))
}

View File

@ -1,111 +0,0 @@
package wal
import (
"database/sql"
"sync"
"time"
)
func (w *Writer) start() {
w.lock.Lock()
defer w.lock.Unlock()
if w.running {
return
}
w.insertQ = make(chan insertJob, 1024)
var maxSeqNum uint64
row := w.db.QueryRow(sqlWALMaxSeqNum)
must(row.Scan(&maxSeqNum))
w.doneWG.Add(1)
go w.insertProc(maxSeqNum)
w.running = true
}
func (w *Writer) stop() {
w.lock.Lock()
defer w.lock.Unlock()
if !w.running {
return
}
close(w.insertQ)
w.doneWG.Wait()
w.running = false
}
func (w *Writer) insertProc(maxSeqNum uint64) {
defer w.doneWG.Done()
var (
job insertJob
tx *sql.Tx
insert *sql.Stmt
ok bool
err error
newSeqNum uint64
now int64
wgs = make([]*sync.WaitGroup, 10)
)
var ()
BEGIN:
newSeqNum = maxSeqNum
wgs = wgs[:0]
job, ok = <-w.insertQ
if !ok {
return
}
tx, err = w.db.Begin()
must(err)
insert, err = tx.Prepare(sqlWALInsert)
must(err)
now = time.Now().Unix()
LOOP:
newSeqNum++
_, err = insert.Exec(
newSeqNum,
now,
job.Collection,
job.ID,
job.Store,
job.Data)
must(err)
if job.Ready != nil {
wgs = append(wgs, job.Ready)
}
select {
case job, ok = <-w.insertQ:
if ok {
goto LOOP
}
default:
}
goto COMMIT
COMMIT:
must(tx.Commit())
maxSeqNum = newSeqNum
for i := range wgs {
wgs[i].Done()
}
goto BEGIN
}

View File

@ -1,210 +0,0 @@
package wal
import (
"database/sql"
"encoding/binary"
"log"
"net"
"sync"
"time"
_ "github.com/mattn/go-sqlite3"
)
type insertJob struct {
Collection string
ID uint64
Store bool
Data []byte
Ready *sync.WaitGroup
}
type Writer struct {
primary bool
db *sql.DB
insert *sql.Stmt
lock sync.Mutex
running bool
insertQ chan insertJob
doneWG sync.WaitGroup
recvLock sync.Mutex
}
func NewWriterPrimary(walPath string) *Writer {
return newWriter(walPath, true)
}
func NewWriterSecondary(walPath string) *Writer {
return newWriter(walPath, false)
}
func newWriter(walPath string, primary bool) *Writer {
db := initWAL(walPath)
insert, err := db.Prepare(sqlWALInsert)
must(err)
w := &Writer{
primary: primary,
db: db,
insert: insert,
}
if primary {
w.start()
}
return w
}
func (w *Writer) Close() {
if w.db == nil {
return
}
w.stop()
w.db.Close()
w.db = nil
}
func (w *Writer) Store(collection string, id uint64, data []byte) {
if !w.primary {
panic("Store called on secondary.")
}
job := insertJob{
Collection: collection,
ID: id,
Store: true,
Data: data,
Ready: &sync.WaitGroup{},
}
job.Ready.Add(1)
w.insertQ <- job
job.Ready.Wait()
}
func (w *Writer) Delete(collection string, id uint64) {
if !w.primary {
panic("Delete called on secondary.")
}
job := insertJob{
Collection: collection,
ID: id,
Store: false,
Ready: &sync.WaitGroup{},
}
job.Ready.Add(1)
w.insertQ <- job
job.Ready.Wait()
}
// Called single-threaded from RecvWAL.
func (w *Writer) storeAsync(collection string, id uint64, data []byte) {
w.insertQ <- insertJob{
Collection: collection,
ID: id,
Store: true,
Data: data,
}
}
// Called single-threaded from RecvWAL.
func (w *Writer) deleteAsync(collection string, id uint64) {
w.insertQ <- insertJob{
Collection: collection,
ID: id,
Store: false,
}
}
func (w *Writer) MaxSeqNum() (n uint64) {
w.db.QueryRow(sqlWALMaxSeqNum).Scan(&n)
return
}
func (w *Writer) RecvWAL(conn net.Conn) {
defer conn.Close()
if w.primary {
panic("RecvWAL called on primary.")
}
if !w.recvLock.TryLock() {
log.Printf("Multiple calls to RecvWAL. Dropping connection.")
return
}
defer w.recvLock.Unlock()
headerBuf := make([]byte, recHeaderSize)
buf := make([]byte, 8)
afterSeqNum := w.MaxSeqNum()
expectedSeqNum := afterSeqNum + 1
// Send fromID to the conn.
conn.SetWriteDeadline(time.Now().Add(connTimeout))
binary.LittleEndian.PutUint64(buf, afterSeqNum)
if _, err := conn.Write(buf); err != nil {
log.Printf("RecvWAL failed to send after sequence number: %v", err)
return
}
conn.SetWriteDeadline(time.Time{})
// Start processing inserts.
w.start()
defer w.stop()
for {
conn.SetReadDeadline(time.Now().Add(connTimeout))
if _, err := conn.Read(headerBuf); err != nil {
log.Printf("RecvWAL failed to read header: %v", err)
return
}
rec, colLen, dataLen := decodeRecHeader(headerBuf)
// Heartbeat.
if rec.SeqNum == 0 {
continue
}
if rec.SeqNum != expectedSeqNum {
log.Printf("Expected sequence number %d but got %d.",
expectedSeqNum, rec.SeqNum)
return
}
expectedSeqNum++
if cap(buf) < colLen {
buf = make([]byte, colLen)
}
buf = buf[:colLen]
if _, err := conn.Read(buf); err != nil {
log.Printf("RecvWAL failed to collection name: %v", err)
return
}
rec.Collection = string(buf)
if rec.Store {
rec.Data = make([]byte, dataLen)
if _, err := conn.Read(rec.Data); err != nil {
log.Printf("RecvWAL failed to data: %v", err)
return
}
}
if rec.Store {
w.storeAsync(rec.Collection, rec.ID, rec.Data)
} else {
w.deleteAsync(rec.Collection, rec.ID)
}
}
}
func (w *Writer) DeleteBefore(seconds int64) {
_, err := w.db.Exec(sqlWALDeleteQuery, time.Now().Unix()-seconds)
must(err)
}

View File

@ -1,208 +0,0 @@
package wal
import (
"bytes"
"fmt"
"os"
"strconv"
"sync"
"testing"
"time"
)
// ----------------------------------------------------------------------------
func (w *Writer) waitForSeqNum(n uint64) {
for {
if w.MaxSeqNum() == n {
return
}
time.Sleep(time.Millisecond)
}
}
// ----------------------------------------------------------------------------
func TestWriter(t *testing.T) {
run := func(name string, inner func(t *testing.T, walPath string, w *Writer)) {
t.Run(name, func(t *testing.T) {
walPath := randPath() + ".wal"
defer os.RemoveAll(walPath)
w := newWriter(walPath, true)
defer w.Close()
inner(t, walPath, w)
})
}
run("simple", func(t *testing.T, walPath string, w *Writer) {
w.Store("a", 1, _b("Hello"))
w.Delete("b", 1)
w.Store("a", 2, _b("World"))
w.Store("a", 1, _b("Good bye"))
err := walEqual(walPath, []Record{
{SeqNum: 1, Collection: "a", ID: 1, Store: true, Data: _b("Hello")},
{SeqNum: 2, Collection: "b", ID: 1},
{SeqNum: 3, Collection: "a", ID: 2, Store: true, Data: _b("World")},
{SeqNum: 4, Collection: "a", ID: 1, Store: true, Data: _b("Good bye")},
})
if err != nil {
t.Fatal(err)
}
})
run("write close write", func(t *testing.T, walPath string, w *Writer) {
w.Store("a", 1, _b("Hello"))
w.Close()
w = newWriter(walPath, true)
w.Delete("b", 1)
w.Close()
w = newWriter(walPath, true)
w.Store("a", 2, _b("World"))
w.Close()
w = newWriter(walPath, true)
w.Store("a", 1, _b("Good bye"))
err := walEqual(walPath, []Record{
{SeqNum: 1, Collection: "a", ID: 1, Store: true, Data: _b("Hello")},
{SeqNum: 2, Collection: "b", ID: 1},
{SeqNum: 3, Collection: "a", ID: 2, Store: true, Data: _b("World")},
{SeqNum: 4, Collection: "a", ID: 1, Store: true, Data: _b("Good bye")},
})
if err != nil {
t.Fatal(err)
}
})
run("write concurrent", func(t *testing.T, walPath string, w *Writer) {
N := 32
wg := sync.WaitGroup{}
expected := make([][]Record, N)
for i := 0; i < N; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
collection := fmt.Sprintf("%d", i)
for j := 0; j < 1024; j++ {
rec := Record{
Collection: collection,
ID: uint64(j + 1),
Store: true,
Data: _b(fmt.Sprintf("%d", j)),
}
w.Store(rec.Collection, rec.ID, rec.Data)
expected[i] = append(expected[i], rec)
}
}(i)
}
wg.Wait()
recs := readWAL(walPath)
found := make([][]Record, N)
for _, rec := range recs {
rec := rec
index, err := strconv.ParseInt(rec.Collection, 10, 64)
if err != nil {
t.Fatal(err)
}
found[index] = append(found[index], rec)
}
if len(found) != len(expected) {
t.Fatal(len(found), len(expected))
}
for i := range found {
fnd := found[i]
exp := expected[i]
if len(fnd) != len(exp) {
t.Fatal(i, len(fnd), len(exp))
}
for j := range fnd {
f := fnd[j]
e := exp[j]
ok := f.Collection == e.Collection &&
f.ID == e.ID &&
f.Store == e.Store &&
bytes.Equal(f.Data, e.Data)
if !ok {
t.Fatal(i, j, f, e)
}
}
}
})
run("store delete async", func(t *testing.T, walPath string, w *Writer) {
w.storeAsync("a", 1, _b("hello1"))
w.storeAsync("a", 2, _b("hello2"))
w.deleteAsync("a", 1)
w.storeAsync("a", 3, _b("hello3"))
w.storeAsync("b", 1, _b("b1"))
w.waitForSeqNum(5)
err := walEqual(walPath, []Record{
{SeqNum: 1, Collection: "a", ID: 1, Store: true, Data: _b("hello1")},
{SeqNum: 2, Collection: "a", ID: 2, Store: true, Data: _b("hello2")},
{SeqNum: 3, Collection: "a", ID: 1, Store: false},
{SeqNum: 4, Collection: "a", ID: 3, Store: true, Data: _b("hello3")},
{SeqNum: 5, Collection: "b", ID: 1, Store: true, Data: _b("b1")},
})
if err != nil {
t.Fatal(err)
}
})
run("store delete async with close", func(t *testing.T, walPath string, w *Writer) {
w.storeAsync("a", 1, _b("hello1"))
w.Close()
w = newWriter(walPath, true)
w.storeAsync("a", 2, _b("hello2"))
w.Close()
w = newWriter(walPath, true)
w.deleteAsync("a", 1)
w.Close()
w = newWriter(walPath, true)
w.storeAsync("a", 3, _b("hello3"))
w.Close()
w = newWriter(walPath, true)
w.storeAsync("b", 1, _b("b1"))
w.Close()
w = newWriter(walPath, true)
w.waitForSeqNum(5)
err := walEqual(walPath, []Record{
{SeqNum: 1, Collection: "a", ID: 1, Store: true, Data: _b("hello1")},
{SeqNum: 2, Collection: "a", ID: 2, Store: true, Data: _b("hello2")},
{SeqNum: 3, Collection: "a", ID: 1, Store: false},
{SeqNum: 4, Collection: "a", ID: 3, Store: true, Data: _b("hello3")},
{SeqNum: 5, Collection: "b", ID: 1, Store: true, Data: _b("b1")},
})
if err != nil {
t.Fatal(err)
}
})
// This is really just a benchmark.
run("store async many", func(t *testing.T, walPath string, w *Writer) {
N := 32768
for i := 0; i < N; i++ {
w.storeAsync("a", 1, _b("x"))
}
w.waitForSeqNum(uint64(N))
})
}

187
kvstore/writer.go Normal file
View File

@ -0,0 +1,187 @@
package kvstore
import (
"database/sql"
"sync"
"time"
)
type writer struct {
db *sql.DB
modQ chan modJob
stop chan struct{}
wg sync.WaitGroup
}
func newWriter(db *sql.DB) *writer {
return &writer{
db: db,
stop: make(chan struct{}, 1),
modQ: make(chan modJob, 1024),
}
}
func (w *writer) Start(maxSeqNum uint64) {
w.wg.Add(1)
go w.run(maxSeqNum)
}
func (w *writer) Stop() {
select {
case w.stop <- struct{}{}:
default:
}
w.wg.Wait()
}
func (w *writer) Store(collection string, id uint64, data []byte) {
job := modJob{
Collection: collection,
ID: id,
Store: true,
Data: data,
Ready: &sync.WaitGroup{},
}
job.Ready.Add(1)
w.modQ <- job
job.Ready.Wait()
}
func (w *writer) Delete(collection string, id uint64) {
job := modJob{
Collection: collection,
ID: id,
Store: false,
Ready: &sync.WaitGroup{},
}
job.Ready.Add(1)
w.modQ <- job
job.Ready.Wait()
}
func (w *writer) StoreAsync(collection string, id uint64, data []byte) {
w.modQ <- modJob{
Collection: collection,
ID: id,
Store: true,
Data: data,
}
}
func (w *writer) DeleteAsync(collection string, id uint64) {
w.modQ <- modJob{
Collection: collection,
ID: id,
Store: false,
}
}
func (w *writer) run(maxSeqNum uint64) {
defer w.wg.Done()
var (
job modJob
tx *sql.Tx
insertData *sql.Stmt
insertKV *sql.Stmt
deleteData *sql.Stmt
deleteKV *sql.Stmt
insertLog *sql.Stmt
err error
newSeqNum uint64
now int64
wgs = make([]*sync.WaitGroup, 10)
)
BEGIN:
insertData = nil
deleteData = nil
newSeqNum = maxSeqNum
wgs = wgs[:0]
select {
case job = <-w.modQ:
case <-w.stop:
return
}
tx, err = w.db.Begin()
must(err)
now = time.Now().Unix()
insertLog, err = tx.Prepare(sqlInsertLog)
must(err)
LOOP:
if job.Ready != nil {
wgs = append(wgs, job.Ready)
}
newSeqNum++
if job.Store {
goto STORE
} else {
goto DELETE
}
STORE:
if insertData == nil {
insertData, err = tx.Prepare(sqlInsertData)
must(err)
insertKV, err = tx.Prepare(sqlInsertKV)
must(err)
}
_, err = insertData.Exec(newSeqNum, job.Data)
must(err)
_, err = insertKV.Exec(job.Collection, job.ID, newSeqNum)
must(err)
_, err = insertLog.Exec(newSeqNum, now, job.Collection, job.ID, true)
must(err)
goto NEXT
DELETE:
if deleteData == nil {
deleteData, err = tx.Prepare(sqlDeleteData)
must(err)
deleteKV, err = tx.Prepare(sqlDeleteKV)
must(err)
}
_, err = deleteData.Exec(job.Collection, job.ID)
must(err)
_, err = deleteKV.Exec(job.Collection, job.ID)
must(err)
_, err = insertLog.Exec(newSeqNum, now, job.Collection, job.ID, false)
must(err)
goto NEXT
NEXT:
select {
case job = <-w.modQ:
goto LOOP
default:
}
goto COMMIT
COMMIT:
must(tx.Commit())
maxSeqNum = newSeqNum
for i := range wgs {
wgs[i].Done()
}
goto BEGIN
}

View File

@ -143,8 +143,8 @@ func (idx *MapIndex[K, T]) updateConflict(new *T) bool {
if idx.include != nil && !idx.include(new) { if idx.include != nil && !idx.include(new) {
return false return false
} }
val, ok := idx.mapGet(idx.getKey(new)) cur, ok := idx.mapGet(idx.getKey(new))
return ok && idx.getID(val) != idx.getID(new) return ok && idx.getID(cur) != idx.getID(new)
} }
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------

View File

@ -101,7 +101,7 @@ func TestShipping(t *testing.T) {
} }
}) })
run("flakey network", func(t *testing.T, db, db2 *DB, network *testconn.Network) { run("unstable network", func(t *testing.T, db, db2 *DB, network *testconn.Network) {
sleepTimeout := time.Millisecond sleepTimeout := time.Millisecond
updateWG := sync.WaitGroup{} updateWG := sync.WaitGroup{}
@ -131,9 +131,7 @@ func TestShipping(t *testing.T) {
for { for {
// Stop when no longer updating and WAL files match. // Stop when no longer updating and WAL files match.
if !updating.Get() { if !updating.Get() {
ws := db.WALStatus() if db.MaxSeqNum() == db2.MaxSeqNum() {
ws2 := db2.WALStatus()
if ws.MaxSeqNumWAL == ws2.MaxSeqNumWAL {
recving.Set(false) recving.Set(false)
return return
} }
@ -152,9 +150,7 @@ func TestShipping(t *testing.T) {
for { for {
// Stop when no longer updating and WAL files match. // Stop when no longer updating and WAL files match.
if !updating.Get() { if !updating.Get() {
ws := db.WALStatus() if db.MaxSeqNum() == db2.MaxSeqNum() {
ws2 := db2.WALStatus()
if ws.MaxSeqNumWAL == ws2.MaxSeqNumWAL {
sending.Set(false) sending.Set(false)
return return
} }

View File

@ -100,11 +100,13 @@ func OpenDB(root string, primary bool) *DB {
db.Users.emailBTree = NewBTreeIndex( db.Users.emailBTree = NewBTreeIndex(
db.Users.c, db.Users.c,
"email-bt",
func(lhs, rhs *User) bool { return lhs.Email < rhs.Email }, func(lhs, rhs *User) bool { return lhs.Email < rhs.Email },
nil) nil)
db.Users.nameBTree = NewBTreeIndex( db.Users.nameBTree = NewBTreeIndex(
db.Users.c, db.Users.c,
"name-bt",
func(lhs, rhs *User) bool { func(lhs, rhs *User) bool {
if lhs.Name != rhs.Name { if lhs.Name != rhs.Name {
return lhs.Name < rhs.Name return lhs.Name < rhs.Name
@ -121,6 +123,7 @@ func OpenDB(root string, primary bool) *DB {
db.Users.extIDBTree = NewBTreeIndex( db.Users.extIDBTree = NewBTreeIndex(
db.Users.c, db.Users.c,
"extid-bt",
func(lhs, rhs *User) bool { return lhs.ExtID < rhs.ExtID }, func(lhs, rhs *User) bool { return lhs.ExtID < rhs.ExtID },
func(u *User) bool { return u.ExtID != "" }) func(u *User) bool { return u.ExtID != "" })
@ -187,9 +190,7 @@ func (db *DB) Equals(rhs *DB) error {
// Wait for two databases to become synchronized. // Wait for two databases to become synchronized.
func (db *DB) WaitForSync(rhs *DB) { func (db *DB) WaitForSync(rhs *DB) {
for { for {
s1 := db.WALStatus() if db.MaxSeqNum() == rhs.MaxSeqNum() {
s2 := rhs.WALStatus()
if s1.MaxSeqNumKV == s1.MaxSeqNumWAL && s1.MaxSeqNumKV == s2.MaxSeqNumKV {
return return
} }
time.Sleep(100 * time.Millisecond) time.Sleep(100 * time.Millisecond)