wip: shipping refactor

master
jdl 2022-07-25 22:47:26 +02:00
parent 1814597129
commit 3695fd5018
7 changed files with 315 additions and 90 deletions

View File

@ -2,6 +2,10 @@ package wal
import ( import (
"database/sql" "database/sql"
"encoding/binary"
"log"
"net"
"time"
) )
type Record struct { type Record struct {
@ -57,3 +61,84 @@ func (f *Follower) Replay(afterSeqNum uint64, each func(rec Record) error) error
} }
return nil return nil
} }
// SendWAL streams WAL records from f to the peer on conn. The connection is
// always closed before SendWAL returns.
//
// Wire protocol, as implemented below:
//
//  1. The peer first sends 8 bytes holding a little-endian uint64. That value
//     is passed to Replay as afterSeqNum and is advanced to the sequence
//     number of each record handed to the send callback.
//  2. For every replayed record the encoded header is written, followed by
//     the collection name and, only when rec.Store is set, the record data.
//  3. If MaxSeqNum has not moved past afterSeqNum for heartbeatInterval, a
//     header-sized block of zero bytes is written as a heartbeat.
//
// SendWAL loops until a read, a write, or Replay fails. I/O failures are
// logged here; an error returned by Replay itself ends the call silently
// because the send callback has already logged its own write failures.
func (f *Follower) SendWAL(conn net.Conn) {
	defer conn.Close()

	const (
		timeout           = 16 * time.Second
		heartbeatInterval = 2 * time.Second
		pollInterval      = 200 * time.Millisecond
	)

	var (
		seqBuf    = make([]byte, 8)
		headerBuf = make([]byte, recHeaderSize)
		heartbeat = make([]byte, recHeaderSize) // all zeros
	)

	// Read the starting sequence number. A single Read may legally return
	// fewer than 8 bytes, so keep reading until the buffer is full; decoding
	// a partially filled buffer would silently yield a wrong sequence number.
	conn.SetReadDeadline(time.Now().Add(timeout))
	for filled := 0; filled < len(seqBuf); {
		n, err := conn.Read(seqBuf[filled:])
		filled += n
		if err != nil && filled < len(seqBuf) {
			log.Printf("SendWAL failed to read from ID: %v", err)
			return
		}
	}
	afterSeqNum := binary.LittleEndian.Uint64(seqBuf)

	// send writes one record to the peer: header, collection name, and the
	// data for store records. The write deadline is refreshed per record.
	send := func(rec Record) error {
		conn.SetWriteDeadline(time.Now().Add(timeout))
		afterSeqNum = rec.SeqNum
		encodeRecordHeader(rec, headerBuf)
		if _, err := conn.Write(headerBuf); err != nil {
			log.Printf("SendWAL failed to send header %v", err)
			return err
		}
		if _, err := conn.Write([]byte(rec.Collection)); err != nil {
			log.Printf("SendWAL failed to send collection name %v", err)
			return err
		}
		if !rec.Store {
			return nil
		}
		if _, err := conn.Write(rec.Data); err != nil {
			log.Printf("SendWAL failed to send data %v", err)
			return err
		}
		return nil
	}

	// lastSend marks the start of the current quiet period; it is reset after
	// every replay and every heartbeat.
	lastSend := time.Now()
	for {
		// New records available: ship them right away.
		if f.MaxSeqNum() > afterSeqNum {
			if err := f.Replay(afterSeqNum, send); err != nil {
				return
			}
			lastSend = time.Now()
			continue
		}

		// Nothing new yet: keep polling until a heartbeat is due.
		if time.Since(lastSend) < heartbeatInterval {
			time.Sleep(pollInterval)
			continue
		}

		conn.SetWriteDeadline(time.Now().Add(timeout))
		if _, err := conn.Write(heartbeat); err != nil {
			log.Printf("SendWAL failed to send heartbeat: %v", err)
			return
		}
		lastSend = time.Now()
	}
}

View File

@ -19,10 +19,12 @@ func (f *Follower) getReplay(afterSeqNum uint64) (l []Record) {
func (f *Follower) waitForSeqNum(n uint64) { func (f *Follower) waitForSeqNum(n uint64) {
for { for {
if f.MaxSeqNum() == n { maxSeqNum := f.MaxSeqNum()
//log.Printf("%d/%d", maxSeqNum, n)
if maxSeqNum == n {
return return
} }
time.Sleep(time.Millisecond) time.Sleep(100 * time.Millisecond)
} }
} }
@ -34,7 +36,7 @@ func TestFollower(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
walPath := randPath() + ".wal" walPath := randPath() + ".wal"
defer os.RemoveAll(walPath) defer os.RemoveAll(walPath)
w := NewWriter(walPath) w := NewWriter(walPath, true)
defer w.Close() defer w.Close()
f := NewFollower(walPath) f := NewFollower(walPath)
defer f.Close() defer f.Close()

View File

@ -117,7 +117,7 @@ func RecvWAL(walPath string, conn net.Conn) {
headerBuf := make([]byte, recHeaderSize) headerBuf := make([]byte, recHeaderSize)
buf := make([]byte, 8) buf := make([]byte, 8)
w := NewWriter(walPath) w := NewWriter(walPath, true)
defer w.Close() defer w.Close()
afterSeqNum := w.MaxSeqNum() afterSeqNum := w.MaxSeqNum()

View File

@ -3,13 +3,14 @@ package wal
import ( import (
"fmt" "fmt"
"math/rand" "math/rand"
"mdb/testconn"
"os" "os"
"testing" "testing"
"time" "time"
"git.crumpington.com/private/mdb/testconn"
) )
func TestShipp(t *testing.T) { func TestShip(t *testing.T) {
run := func(name string, inner func( run := func(name string, inner func(
t *testing.T, t *testing.T,
wWALPath string, wWALPath string,
@ -20,7 +21,7 @@ func TestShipp(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
wWALPath := randPath() + ".wal" wWALPath := randPath() + ".wal"
fWALPath := randPath() + ".wal" fWALPath := randPath() + ".wal"
w := NewWriter(wWALPath) w := NewWriter(wWALPath, true)
defer w.Close() defer w.Close()
nw := testconn.NewNetwork() nw := testconn.NewNetwork()
@ -45,8 +46,9 @@ func TestShipp(t *testing.T) {
// Run a sender in the background. // Run a sender in the background.
go func() { go func() {
f := NewFollower(wWALPath)
conn := nw.Accept() conn := nw.Accept()
SendWAL(wWALPath, conn) f.SendWAL(conn)
}() }()
// Run the follower. // Run the follower.
@ -68,9 +70,10 @@ func TestShipp(t *testing.T) {
}) })
run("net failures", func(t *testing.T, wWALPath, fWALPath string, w *Writer, nw *testconn.Network) { run("net failures", func(t *testing.T, wWALPath, fWALPath string, w *Writer, nw *testconn.Network) {
N := 10000 N := 2000
sleepTime := time.Millisecond sleepTime := time.Millisecond
go func() { go func() {
time.Sleep(4 * time.Second)
for i := 0; i < N; i++ { for i := 0; i < N; i++ {
time.Sleep(sleepTime) time.Sleep(sleepTime)
if rand.Float64() < 0.9 { if rand.Float64() < 0.9 {
@ -83,16 +86,19 @@ func TestShipp(t *testing.T) {
// Run a sender in the background. // Run a sender in the background.
go func() { go func() {
sender := NewFollower(wWALPath)
f := NewFollower(fWALPath) f := NewFollower(fWALPath)
for f.MaxSeqNum() < uint64(N) { for f.MaxSeqNum() < uint64(N) {
conn := nw.Accept() conn := nw.Accept()
SendWAL(wWALPath, conn) sender.SendWAL(conn)
} }
}() }()
// Run the follower in the background. // Run the follower in the background.
go func() { go func() {
f := NewFollower(fWALPath) f := NewFollower(fWALPath)
for f.MaxSeqNum() < uint64(N) { for f.MaxSeqNum() < uint64(N) {
conn := nw.Dial() conn := nw.Dial()
RecvWAL(fWALPath, conn) RecvWAL(fWALPath, conn)
@ -103,7 +109,7 @@ func TestShipp(t *testing.T) {
go func() { go func() {
f := NewFollower(fWALPath) f := NewFollower(fWALPath)
for f.MaxSeqNum() < uint64(N) { for f.MaxSeqNum() < uint64(N) {
time.Sleep(time.Duration(rand.Intn(2 * int(sleepTime)))) time.Sleep(time.Duration(rand.Intn(10 * int(sleepTime))))
if rand.Float64() < 0.5 { if rand.Float64() < 0.5 {
nw.CloseClient() nw.CloseClient()
} else { } else {
@ -113,10 +119,10 @@ func TestShipp(t *testing.T) {
}() }()
time.Sleep(time.Second) time.Sleep(time.Second)
// Wait for follower to get 100 entries, then close connection.
f := NewFollower(fWALPath) f := NewFollower(fWALPath)
defer f.Close() defer f.Close()
// Wait for follower to get 100 entries, then close connection.
f.waitForSeqNum(uint64(N)) f.waitForSeqNum(uint64(N))
if err := walsEqual(wWALPath, fWALPath); err != nil { if err := walsEqual(wWALPath, fWALPath); err != nil {

111
wal/writer-background.go Normal file
View File

@ -0,0 +1,111 @@
package wal
import (
"database/sql"
"sync"
"time"
)
// start launches the background insert goroutine unless it is already
// running. While holding w.lock it creates a fresh insert queue, reads the
// current maximum sequence number from the database and passes it to
// insertProc, registers the goroutine with doneWG, and marks the writer as
// running.
func (w *Writer) start() {
	w.lock.Lock()
	defer w.lock.Unlock()

	if !w.running {
		w.insertQ = make(chan insertJob, 1024)

		var lastSeqNum uint64
		must(w.db.QueryRow(sqlWALMaxSeqNum).Scan(&lastSeqNum))

		w.doneWG.Add(1)
		go w.insertProc(lastSeqNum)

		w.running = true
	}
}
// stop shuts down the background insert goroutine if one is running. While
// holding w.lock it closes the insert queue, waits on doneWG for the
// goroutine to finish, and then clears the running flag. Calling stop on a
// writer that is not running does nothing.
func (w *Writer) stop() {
	w.lock.Lock()
	defer w.lock.Unlock()

	if w.running {
		close(w.insertQ)
		w.doneWG.Wait()
		w.running = false
	}
}
// insertProc is the background goroutine that drains w.insertQ and writes
// the queued jobs to the database. maxSeqNum is the highest sequence number
// already assigned; each inserted job gets the next number.
//
// Jobs are written in batches: the goroutine blocks for one job, opens a
// transaction, then keeps adding jobs for as long as more are immediately
// available on the queue. After the transaction commits, Done is called on
// the Ready wait group of every job in the batch that supplied one.
//
// The goroutine exits once w.insertQ is closed and drained, signalling
// w.doneWG on the way out. Any database error is passed to must.
func (w *Writer) insertProc(maxSeqNum uint64) {
	defer w.doneWG.Done()

	var (
		tx     *sql.Tx
		insert *sql.Stmt
		err    error

		// Wait groups to release after the current batch commits. The
		// backing array is reused across batches.
		wgs = make([]*sync.WaitGroup, 0, 10)
	)

	// Each iteration blocks for the first job of a new batch.
	for job := range w.insertQ {
		wgs = wgs[:0]
		newSeqNum := maxSeqNum

		tx, err = w.db.Begin()
		must(err)
		insert, err = tx.Prepare(sqlWALInsert)
		must(err)

		// Every record in a batch shares one timestamp.
		now := time.Now().Unix()

	batch:
		for {
			newSeqNum++
			_, err = insert.Exec(
				newSeqNum,
				now,
				job.Collection,
				job.ID,
				job.Store,
				job.Data)
			must(err)

			if job.Ready != nil {
				wgs = append(wgs, job.Ready)
			}

			// Extend the batch only with jobs that are already queued;
			// never block while a transaction is open.
			select {
			case next, ok := <-w.insertQ:
				if !ok {
					// Queue closed: commit what we have. The outer range
					// loop then terminates.
					break batch
				}
				job = next
			default:
				break batch
			}
		}

		must(tx.Commit())
		// Advance the sequence number only once the batch is committed.
		maxSeqNum = newSeqNum

		for _, wg := range wgs {
			wg.Done()
		}
	}
}

View File

@ -2,6 +2,9 @@ package wal
import ( import (
"database/sql" "database/sql"
"encoding/binary"
"log"
"net"
"sync" "sync"
"time" "time"
@ -17,29 +20,33 @@ type insertJob struct {
} }
type Writer struct { type Writer struct {
primary bool
db *sql.DB db *sql.DB
insert *sql.Stmt insert *sql.Stmt
lock sync.Mutex
running bool
insertQ chan insertJob insertQ chan insertJob
doneWG sync.WaitGroup doneWG sync.WaitGroup
recvLock sync.Mutex
} }
func NewWriter(walPath string) *Writer { func NewWriter(walPath string, primary bool) *Writer {
db := initWAL(walPath) db := initWAL(walPath)
insert, err := db.Prepare(sqlWALInsert) insert, err := db.Prepare(sqlWALInsert)
must(err) must(err)
w := &Writer{ w := &Writer{
primary: primary,
db: db, db: db,
insert: insert, insert: insert,
insertQ: make(chan insertJob, 1024),
} }
var maxSeqNum uint64 if primary {
row := db.QueryRow(sqlWALMaxSeqNum) w.start()
must(row.Scan(&maxSeqNum)) }
w.doneWG.Add(1)
go w.insertProc(maxSeqNum)
return w return w
} }
@ -48,13 +55,16 @@ func (w *Writer) Close() {
return return
} }
close(w.insertQ) w.stop()
w.doneWG.Wait()
w.db.Close() w.db.Close()
w.db = nil w.db = nil
} }
func (w *Writer) Store(collection string, id uint64, data []byte) { func (w *Writer) Store(collection string, id uint64, data []byte) {
if !w.primary {
//panic("Store called on secondary.")
}
job := insertJob{ job := insertJob{
Collection: collection, Collection: collection,
ID: id, ID: id,
@ -68,6 +78,10 @@ func (w *Writer) Store(collection string, id uint64, data []byte) {
} }
func (w *Writer) Delete(collection string, id uint64) { func (w *Writer) Delete(collection string, id uint64) {
if !w.primary {
//panic("Delete called on secondary.")
}
job := insertJob{ job := insertJob{
Collection: collection, Collection: collection,
ID: id, ID: id,
@ -103,74 +117,81 @@ func (w *Writer) MaxSeqNum() (n uint64) {
return return
} }
func (w *Writer) insertProc(maxSeqNum uint64) { func (w *Writer) RecvWAL(conn net.Conn) {
defer w.doneWG.Done() defer conn.Close()
var ( if w.primary {
job insertJob //panic("RecvWAL called on primary.")
tx *sql.Tx }
insert *sql.Stmt
ok bool
err error
newSeqNum uint64
now int64
wgs = make([]*sync.WaitGroup, 10)
)
var () if !w.recvLock.TryLock() {
log.Printf("Multiple calls to RecvWAL. Dropping connection.")
BEGIN:
newSeqNum = maxSeqNum
wgs = wgs[:0]
job, ok = <-w.insertQ
if !ok {
return return
} }
defer w.recvLock.Unlock()
tx, err = w.db.Begin() headerBuf := make([]byte, recHeaderSize)
must(err) buf := make([]byte, 8)
insert, err = tx.Prepare(sqlWALInsert) afterSeqNum := w.MaxSeqNum()
must(err) expectedSeqNum := afterSeqNum + 1
now = time.Now().Unix() // Send fromID to the conn.
conn.SetWriteDeadline(time.Now().Add(time.Minute))
LOOP: binary.LittleEndian.PutUint64(buf, afterSeqNum)
if _, err := conn.Write(buf); err != nil {
newSeqNum++ log.Printf("RecvWAL failed to send after sequence number: %v", err)
_, err = insert.Exec( return
newSeqNum,
now,
job.Collection,
job.ID,
job.Store,
job.Data)
must(err)
if job.Ready != nil {
wgs = append(wgs, job.Ready)
} }
conn.SetWriteDeadline(time.Time{})
select { // Start processing inserts.
case job, ok = <-w.insertQ: w.start()
if ok { defer w.stop()
goto LOOP
for {
conn.SetReadDeadline(time.Now().Add(time.Minute))
if _, err := conn.Read(headerBuf); err != nil {
log.Printf("RecvWAL failed to read header: %v", err)
return
}
rec, colLen, dataLen := decodeRecHeader(headerBuf)
// Heartbeat.
if rec.SeqNum == 0 {
continue
}
if rec.SeqNum != expectedSeqNum {
log.Printf("Expected sequence number %d but got %d.",
expectedSeqNum, rec.SeqNum)
return
}
expectedSeqNum++
if cap(buf) < colLen {
buf = make([]byte, colLen)
}
buf = buf[:colLen]
if _, err := conn.Read(buf); err != nil {
log.Printf("RecvWAL failed to collection name: %v", err)
return
}
rec.Collection = string(buf)
if rec.Store {
rec.Data = make([]byte, dataLen)
if _, err := conn.Read(rec.Data); err != nil {
log.Printf("RecvWAL failed to data: %v", err)
return
}
}
if rec.Store {
w.storeAsync(rec.Collection, rec.ID, rec.Data)
} else {
w.deleteAsync(rec.Collection, rec.ID)
} }
default:
} }
goto COMMIT
COMMIT:
must(tx.Commit())
maxSeqNum = newSeqNum
for i := range wgs {
wgs[i].Done()
}
goto BEGIN
} }

View File

@ -29,7 +29,7 @@ func TestWriter(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
walPath := randPath() + ".wal" walPath := randPath() + ".wal"
defer os.RemoveAll(walPath) defer os.RemoveAll(walPath)
w := NewWriter(walPath) w := NewWriter(walPath, true)
defer w.Close() defer w.Close()
inner(t, walPath, w) inner(t, walPath, w)
}) })
@ -57,15 +57,15 @@ func TestWriter(t *testing.T) {
w.Store("a", 1, _b("Hello")) w.Store("a", 1, _b("Hello"))
w.Close() w.Close()
w = NewWriter(walPath) w = NewWriter(walPath, true)
w.Delete("b", 1) w.Delete("b", 1)
w.Close() w.Close()
w = NewWriter(walPath) w = NewWriter(walPath, true)
w.Store("a", 2, _b("World")) w.Store("a", 2, _b("World"))
w.Close() w.Close()
w = NewWriter(walPath) w = NewWriter(walPath, true)
w.Store("a", 1, _b("Good bye")) w.Store("a", 1, _b("Good bye"))
err := walEqual(walPath, []Record{ err := walEqual(walPath, []Record{
@ -168,19 +168,19 @@ func TestWriter(t *testing.T) {
run("store delete async with close", func(t *testing.T, walPath string, w *Writer) { run("store delete async with close", func(t *testing.T, walPath string, w *Writer) {
w.storeAsync("a", 1, _b("hello1")) w.storeAsync("a", 1, _b("hello1"))
w.Close() w.Close()
w = NewWriter(walPath) w = NewWriter(walPath, true)
w.storeAsync("a", 2, _b("hello2")) w.storeAsync("a", 2, _b("hello2"))
w.Close() w.Close()
w = NewWriter(walPath) w = NewWriter(walPath, true)
w.deleteAsync("a", 1) w.deleteAsync("a", 1)
w.Close() w.Close()
w = NewWriter(walPath) w = NewWriter(walPath, true)
w.storeAsync("a", 3, _b("hello3")) w.storeAsync("a", 3, _b("hello3"))
w.Close() w.Close()
w = NewWriter(walPath) w = NewWriter(walPath, true)
w.storeAsync("b", 1, _b("b1")) w.storeAsync("b", 1, _b("b1"))
w.Close() w.Close()
w = NewWriter(walPath) w = NewWriter(walPath, true)
w.waitForSeqNum(5) w.waitForSeqNum(5)