master
jdl 2022-07-26 09:35:30 +02:00
parent 9c8b2c1e41
commit f1d34b0370
8 changed files with 630 additions and 0 deletions

34
kvstore/db-sql.go Normal file
View File

@ -0,0 +1,34 @@
package kvstore
const sqlSchema = `
BEGIN IMMEDIATE;
CREATE TABLE IF NOT EXISTS kv(
Collection TEXT NOT NULL,
ID INTEGER NOT NULL,
Data BLOB,
PRIMARY KEY (Collection, ID)
) WITHOUT ROWID;
CREATE TABLE IF NOT EXISTS maxSeqNum(
ID INTEGER NOT NULL PRIMARY KEY,
Value INTEGER NOT NULL
) WITHOUT ROWID;
INSERT INTO maxSeqNum(ID, Value) VALUES (1, 0) ON CONFLICT DO NOTHING;
COMMIT;`
const sqlKVUpsert = `
INSERT INTO kv
(Collection,ID,Data) VALUES (?,?,?)
ON CONFLICT(Collection,ID) DO UPDATE SET
Data=excluded.Data
WHERE
ID=excluded.ID`
const sqlKVDelete = `DELETE FROM kv WHERE Collection=? AND ID=?`
const sqlKVIterate = `SELECT ID,Data FROM kv WHERE Collection=?`
const sqlMaxSeqNumGet = `SELECT Value FROM maxSeqNum WHERE ID=1`
const sqlMaxSeqNumSet = `UPDATE maxSeqNum SET Value=? WHERE ID=1`

5
kvstore/dep-graph.sh Executable file
View File

@ -0,0 +1,5 @@
#!/bin/bash
godepgraph -s . > .deps.dot && xdot .deps.dot
rm .deps.dot

13
kvstore/main_test.go Normal file
View File

@ -0,0 +1,13 @@
package kvstore
import (
"math/rand"
"os"
"testing"
"time"
)
func TestMain(m *testing.M) {
rand.Seed(time.Now().UnixNano())
os.Exit(m.Run())
}

183
kvstore/shipping_test.go Normal file
View File

@ -0,0 +1,183 @@
package kvstore
import (
"math/rand"
"os"
"sync"
"testing"
"time"
"git.crumpington.com/private/mdb/testconn"
)
// ----------------------------------------------------------------------------
// Stores info from secondary callbacks.
type callbacks struct {
lock sync.Mutex
m map[string]map[uint64]string
}
func (sc *callbacks) onStore(c string, id uint64, data []byte) {
sc.lock.Lock()
defer sc.lock.Unlock()
if _, ok := sc.m[c]; !ok {
sc.m[c] = map[uint64]string{}
}
sc.m[c][id] = string(data)
}
func (sc *callbacks) onDelete(c string, id uint64) {
sc.lock.Lock()
defer sc.lock.Unlock()
if _, ok := sc.m[c]; !ok {
return
}
delete(sc.m[c], id)
}
// ----------------------------------------------------------------------------
func TestShipping(t *testing.T) {
run := func(name string, inner func(
t *testing.T,
pDir string,
sDir string,
primary *KV,
secondary *KV,
cbs *callbacks,
nw *testconn.Network,
)) {
t.Run(name, func(t *testing.T) {
pDir, _ := os.MkdirTemp("", "")
defer os.RemoveAll(pDir)
sDir, _ := os.MkdirTemp("", "")
defer os.RemoveAll(sDir)
nw := testconn.NewNetwork()
defer func() {
nw.CloseServer()
nw.CloseClient()
}()
cbs := &callbacks{
m: map[string]map[uint64]string{},
}
prim := NewPrimary(pDir)
defer prim.Close()
sec := NewSecondary(sDir, cbs.onStore, cbs.onDelete)
defer sec.Close()
inner(t, pDir, sDir, prim, sec, cbs, nw)
})
}
run("simple", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) {
M := 10
N := 100
wg := sync.WaitGroup{}
// Store M*N values in the background.
for i := 0; i < M; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for i := 0; i < N; i++ {
time.Sleep(10 * time.Millisecond)
prim.randAction()
}
}()
}
// Send in the background.
wg.Add(1)
go func() {
defer wg.Done()
conn := nw.Accept()
prim.SyncSend(conn)
}()
// Recv in the background.
wg.Add(1)
go func() {
defer wg.Done()
conn := nw.Dial()
sec.SyncRecv(conn)
}()
sec.waitForSeqNum(uint64(M * N))
nw.CloseServer()
nw.CloseClient()
wg.Wait()
prim.equalsKV("a", sec)
prim.equalsKV("b", sec)
prim.equalsKV("c", sec)
})
run("net failures", func(t *testing.T, pDir, sDir string, prim, sec *KV, cbs *callbacks, nw *testconn.Network) {
M := 10
N := 1000
sleepTime := time.Millisecond
wg := sync.WaitGroup{}
// Store M*N values in the background.
for i := 0; i < M; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for j := 0; j < N; j++ {
time.Sleep(sleepTime)
prim.randAction()
}
}()
}
// Send in the background.
wg.Add(1)
go func() {
defer wg.Done()
for sec.MaxSeqNum() < uint64(M*N) {
if conn := nw.Accept(); conn != nil {
prim.SyncSend(conn)
}
}
}()
// Recv in the background.
wg.Add(1)
go func() {
defer wg.Done()
for sec.MaxSeqNum() < uint64(M*N) {
if conn := nw.Dial(); conn != nil {
sec.SyncRecv(conn)
}
}
}()
wg.Add(1)
go func() {
defer wg.Done()
for sec.MaxSeqNum() < uint64(M*N) {
time.Sleep(time.Duration(rand.Intn(10 * int(sleepTime))))
if rand.Float64() < 0.5 {
nw.CloseClient()
} else {
nw.CloseServer()
}
}
}()
sec.waitForSeqNum(uint64(M * N))
wg.Wait()
prim.equalsKV("a", sec)
prim.equalsKV("b", sec)
prim.equalsKV("c", sec)
})
}

245
kvstore/store.go Normal file
View File

@ -0,0 +1,245 @@
package kvstore
import (
"database/sql"
"fmt"
"net"
"os"
"sync"
"time"
"git.crumpington.com/private/mdb/wal"
"golang.org/x/sys/unix"
_ "github.com/mattn/go-sqlite3"
)
type KV struct {
primary bool
lockPath string
dataPath string
walPath string
w *wal.Writer
f *wal.Follower
lock *os.File
db *sql.DB
stop chan struct{}
done sync.WaitGroup
onStore func(string, uint64, []byte)
onDelete func(string, uint64)
closeLock sync.Mutex
shippingLock sync.Mutex
}
func (kv *KV) init() {
// Acquire the lock.
lock, err := os.OpenFile(kv.lockPath, os.O_RDWR|os.O_CREATE, 0600)
must(err)
must(unix.Flock(int(lock.Fd()), unix.LOCK_EX))
kv.lock = lock
opts := `?_journal=WAL`
db, err := sql.Open("sqlite3", kv.dataPath+opts)
must(err)
_, err = db.Exec(sqlSchema)
must(err)
if kv.primary {
kv.w = wal.NewWriterPrimary(kv.walPath)
} else {
kv.w = wal.NewWriterSecondary(kv.walPath)
}
kv.f = wal.NewFollower(kv.walPath)
kv.db = db
kv.stop = make(chan struct{})
kv.commit()
}
func (kv *KV) start() {
// Spawn follower in background to write data from WAL to data.
kv.done.Add(1)
go kv.background()
}
func NewPrimary(dir string) *KV {
kv := &KV{
primary: true,
lockPath: lockPath(dir),
dataPath: dataPath(dir),
walPath: walPath(dir),
}
kv.init()
kv.start()
return kv
}
func NewSecondary(
dir string,
onStore func(collection string, id uint64, data []byte),
onDelete func(collection string, id uint64),
) *KV {
kv := &KV{
primary: false,
lockPath: lockPath(dir),
dataPath: dataPath(dir),
walPath: walPath(dir),
}
kv.init()
kv.onStore = onStore
kv.onDelete = onDelete
kv.start()
return kv
}
func (kv *KV) Primary() bool {
return kv.primary
}
func (kv *KV) MaxSeqNum() (seqNum uint64) {
kv.db.QueryRow(sqlMaxSeqNumGet).Scan(&seqNum)
return seqNum
}
func (kv *KV) WALMaxSeqNum() uint64 {
return kv.w.MaxSeqNum()
}
func (kv *KV) Iterate(collection string, each func(id uint64, data []byte)) {
rows, err := kv.db.Query(sqlKVIterate, collection)
must(err)
defer rows.Close()
var (
id uint64
data []byte
)
for rows.Next() {
must(rows.Scan(&id, &data))
each(id, data)
}
}
func (kv *KV) Close() {
kv.closeLock.Lock()
defer kv.closeLock.Unlock()
if kv.w == nil {
return
}
kv.stop <- struct{}{}
kv.done.Wait()
kv.w.Close()
kv.f.Close()
kv.db.Close()
kv.lock.Close()
kv.w = nil
kv.f = nil
kv.db = nil
kv.lock = nil
}
func (kv *KV) Store(collection string, id uint64, data []byte) {
if !kv.primary {
panic("Store called on secondary.")
}
kv.w.Store(collection, id, data)
}
func (kv *KV) Delete(collection string, id uint64) {
if !kv.primary {
panic("Delete called on secondary.")
}
kv.w.Delete(collection, id)
}
func (kv *KV) SyncSend(conn net.Conn) {
if !kv.primary {
panic("SyncSend called on secondary.")
}
kv.f.SendWAL(conn)
}
func (kv *KV) SyncRecv(conn net.Conn) {
if kv.primary {
panic("SyncRecv called on primary.")
}
if !kv.shippingLock.TryLock() {
return
}
defer kv.shippingLock.Unlock()
kv.w.RecvWAL(conn)
}
func (kv *KV) background() {
defer kv.done.Done()
commitTicker := time.NewTicker(commitInterval)
defer commitTicker.Stop()
cleanTicker := time.NewTicker(cleanInterval)
defer cleanTicker.Stop()
for {
select {
case <-commitTicker.C:
kv.commit()
case <-cleanTicker.C:
kv.w.DeleteBefore(cleanBeforeSecs)
case <-kv.stop:
return
}
}
}
func (kv *KV) commit() {
maxSeqNum := kv.MaxSeqNum()
if maxSeqNum == kv.f.MaxSeqNum() {
return
}
tx, err := kv.db.Begin()
must(err)
upsert, err := tx.Prepare(sqlKVUpsert)
must(err)
delete, err := tx.Prepare(sqlKVDelete)
must(err)
err = kv.f.Replay(maxSeqNum, func(rec wal.Record) error {
if rec.SeqNum != maxSeqNum+1 {
return fmt.Errorf("expected sequence number %d but got %d", maxSeqNum+1, rec.SeqNum)
}
if rec.Store {
_, err = upsert.Exec(rec.Collection, rec.ID, rec.Data)
} else {
_, err = delete.Exec(rec.Collection, rec.ID)
}
maxSeqNum = rec.SeqNum
return err
})
must(err)
_, err = tx.Exec(sqlMaxSeqNumSet, maxSeqNum)
must(err)
must(tx.Commit())
}

99
kvstore/store_test.go Normal file
View File

@ -0,0 +1,99 @@
package kvstore
import (
"fmt"
"math/rand"
"os"
"testing"
"time"
)
// ----------------------------------------------------------------------------
func (kv *KV) waitForSeqNum(x uint64) {
for {
seqNum := kv.MaxSeqNum()
//log.Printf("%d/%d", seqNum, x)
if seqNum == x {
return
}
time.Sleep(100 * time.Millisecond)
}
}
func (kv *KV) dump(collection string) map[uint64]string {
m := map[uint64]string{}
kv.Iterate(collection, func(id uint64, data []byte) {
m[id] = string(data)
})
return m
}
func (kv *KV) equals(collection string, expected map[uint64]string) error {
m := kv.dump(collection)
if len(m) != len(expected) {
return fmt.Errorf("Expected %d values but found %d", len(expected), len(m))
}
for key, exp := range expected {
val, ok := m[key]
if !ok {
return fmt.Errorf("Value for %d not found.", key)
}
if val != exp {
return fmt.Errorf("%d: Expected %s but found %s.", key, exp, val)
}
}
return nil
}
func (kv *KV) equalsKV(collection string, rhs *KV) error {
return kv.equals(collection, rhs.dump(collection))
}
// Collection one of ("a", "b", "c").
// ID one of [1,10]
var (
randCollections = []string{"a", "b", "c"}
randIDs = []uint64{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
)
func (kv *KV) randAction() {
c := randCollections[rand.Intn(len(randCollections))]
id := randIDs[rand.Intn(len(randIDs))]
// Mostly stores.
if rand.Float64() < 0.9 {
kv.Store(c, id, randBytes())
} else {
kv.Delete(c, id)
}
}
// ----------------------------------------------------------------------------
func TestKV(t *testing.T) {
run := func(name string, inner func(t *testing.T, kv *KV)) {
dir, _ := os.MkdirTemp("", "")
defer os.RemoveAll(dir)
kv := NewPrimary(dir)
defer kv.Close()
inner(t, kv)
}
run("simple", func(t *testing.T, kv *KV) {
kv.Store("a", 1, _b("Hello"))
kv.Store("a", 2, _b("World"))
kv.waitForSeqNum(2)
err := kv.equals("a", map[uint64]string{
1: "Hello",
2: "World",
})
if err != nil {
t.Fatal(err)
}
})
}

30
kvstore/util.go Normal file
View File

@ -0,0 +1,30 @@
package kvstore
import (
"path/filepath"
"time"
)
const (
commitInterval = 250 * time.Millisecond // How often to commit from WAL.
cleanInterval = time.Minute // How often to clean WAL.
cleanBeforeSecs = 86400 * 7 // Clean WAL entries older than.
)
func must(err error) {
if err != nil {
panic(err)
}
}
func dataPath(dir string) string {
return filepath.Join(dir, "data")
}
func walPath(dir string) string {
return filepath.Join(dir, "wal")
}
func lockPath(dir string) string {
return filepath.Join(dir, "lock")
}

21
kvstore/util_test.go Normal file
View File

@ -0,0 +1,21 @@
package kvstore
import (
"crypto/rand"
"encoding/hex"
mrand "math/rand"
)
func _b(in string) []byte {
return []byte(in)
}
func randString() string {
buf := make([]byte, 1+mrand.Intn(20))
rand.Read(buf)
return hex.EncodeToString(buf)
}
func randBytes() []byte {
return _b(randString())
}