From 3695fd5018e13370db249b90e6ff7c86d885b48c Mon Sep 17 00:00:00 2001 From: jdl Date: Mon, 25 Jul 2022 22:47:26 +0200 Subject: [PATCH] wip: shipping refactor --- wal/follower.go | 85 +++++++++++++++++++++ wal/follower_test.go | 8 +- wal/shipping.go | 2 +- wal/shipping_test.go | 24 +++--- wal/writer-background.go | 111 +++++++++++++++++++++++++++ wal/writer.go | 157 ++++++++++++++++++++++----------------- wal/writer_test.go | 18 ++--- 7 files changed, 315 insertions(+), 90 deletions(-) create mode 100644 wal/writer-background.go diff --git a/wal/follower.go b/wal/follower.go index 53af3c1..98ca580 100644 --- a/wal/follower.go +++ b/wal/follower.go @@ -2,6 +2,10 @@ package wal import ( "database/sql" + "encoding/binary" + "log" + "net" + "time" ) type Record struct { @@ -57,3 +61,84 @@ func (f *Follower) Replay(afterSeqNum uint64, each func(rec Record) error) error } return nil } + +func (f *Follower) SendWAL(conn net.Conn) { + defer conn.Close() + + var ( + buf = make([]byte, 8) + headerBuf = make([]byte, recHeaderSize) + empty = make([]byte, recHeaderSize) + timeout = 16 * time.Second + heartbeatInterval = time.Second * 2 + pollInterval = 200 * time.Millisecond + tStart time.Time + err error + ) + + // Read the fromID from the conn. + conn.SetReadDeadline(time.Now().Add(16 * time.Second)) + if _, err := conn.Read(buf[:8]); err != nil { + log.Printf("SendWAL failed to read from ID: %v", err) + return + } + + afterSeqNum := binary.LittleEndian.Uint64(buf[:8]) + +POLL: + + conn.SetWriteDeadline(time.Now().Add(timeout)) + tStart = time.Now() + for time.Since(tStart) < heartbeatInterval { + if f.MaxSeqNum() > afterSeqNum { + goto REPLAY + } + time.Sleep(pollInterval) + } + goto HEARTBEAT + +HEARTBEAT: + + conn.SetWriteDeadline(time.Now().Add(timeout)) + if _, err := conn.Write(empty); err != nil { + log.Printf("SendWAL failed to send heartbeat: %v", err) + return + } + + goto POLL + +REPLAY: + + err = f.Replay(afterSeqNum, func(rec Record) error { + conn.SetWriteDeadline(time.Now().Add(timeout)) + + afterSeqNum = rec.SeqNum + encodeRecordHeader(rec, headerBuf) + if _, err := conn.Write(headerBuf); err != nil { + log.Printf("SendWAL failed to send header %v", err) + return err + } + + if _, err := conn.Write([]byte(rec.Collection)); err != nil { + log.Printf("SendWAL failed to send collection name %v", err) + return err + } + + if !rec.Store { + return nil + } + + if _, err := conn.Write(rec.Data); err != nil { + log.Printf("SendWAL failed to send data %v", err) + return err + } + + return nil + }) + + if err != nil { + return + } + + goto POLL +} diff --git a/wal/follower_test.go b/wal/follower_test.go index c44e632..d5b223a 100644 --- a/wal/follower_test.go +++ b/wal/follower_test.go @@ -19,10 +19,12 @@ func (f *Follower) getReplay(afterSeqNum uint64) (l []Record) { func (f *Follower) waitForSeqNum(n uint64) { for { - if f.MaxSeqNum() == n { + maxSeqNum := f.MaxSeqNum() + //log.Printf("%d/%d", maxSeqNum, n) + if maxSeqNum == n { return } - time.Sleep(time.Millisecond) + time.Sleep(100 * time.Millisecond) } } @@ -34,7 +36,7 @@ func TestFollower(t *testing.T) { t.Run(name, func(t *testing.T) { walPath := randPath() + ".wal" defer os.RemoveAll(walPath) - w := NewWriter(walPath) + w := NewWriter(walPath, true) defer w.Close() f := NewFollower(walPath) defer f.Close() diff --git a/wal/shipping.go b/wal/shipping.go index c4196ca..bd7882e 100644 --- a/wal/shipping.go +++ b/wal/shipping.go @@ -117,7 +117,7 @@ func RecvWAL(walPath string, conn net.Conn) { headerBuf := make([]byte, recHeaderSize) buf := make([]byte, 8) - w := NewWriter(walPath) + w := NewWriter(walPath, true) defer w.Close() afterSeqNum := w.MaxSeqNum() diff --git a/wal/shipping_test.go b/wal/shipping_test.go index 6270c5d..ee8d86d 100644 --- a/wal/shipping_test.go +++ b/wal/shipping_test.go @@ -3,13 +3,14 @@ package wal import ( "fmt" "math/rand" - "mdb/testconn" "os" "testing" "time" + + "git.crumpington.com/private/mdb/testconn" ) -func TestShipp(t *testing.T) { +func TestShip(t *testing.T) { run := func(name string, inner func( t *testing.T, wWALPath string, @@ -20,7 +21,7 @@ func TestShipp(t *testing.T) { t.Run(name, func(t *testing.T) { wWALPath := randPath() + ".wal" fWALPath := randPath() + ".wal" - w := NewWriter(wWALPath) + w := NewWriter(wWALPath, true) defer w.Close() nw := testconn.NewNetwork() @@ -45,8 +46,9 @@ func TestShipp(t *testing.T) { // Run a sender in the background. go func() { + f := NewFollower(wWALPath) conn := nw.Accept() - SendWAL(wWALPath, conn) + f.SendWAL(conn) }() // Run the follower. @@ -68,9 +70,10 @@ func TestShipp(t *testing.T) { }) run("net failures", func(t *testing.T, wWALPath, fWALPath string, w *Writer, nw *testconn.Network) { - N := 10000 + N := 2000 sleepTime := time.Millisecond go func() { + time.Sleep(4 * time.Second) for i := 0; i < N; i++ { time.Sleep(sleepTime) if rand.Float64() < 0.9 { @@ -83,16 +86,19 @@ func TestShipp(t *testing.T) { // Run a sender in the background. go func() { + sender := NewFollower(wWALPath) f := NewFollower(fWALPath) + for f.MaxSeqNum() < uint64(N) { conn := nw.Accept() - SendWAL(wWALPath, conn) + sender.SendWAL(conn) } }() // Run the follower in the background. go func() { f := NewFollower(fWALPath) + for f.MaxSeqNum() < uint64(N) { conn := nw.Dial() RecvWAL(fWALPath, conn) @@ -103,7 +109,7 @@ func TestShipp(t *testing.T) { go func() { f := NewFollower(fWALPath) for f.MaxSeqNum() < uint64(N) { - time.Sleep(time.Duration(rand.Intn(2 * int(sleepTime)))) + time.Sleep(time.Duration(rand.Intn(10 * int(sleepTime)))) if rand.Float64() < 0.5 { nw.CloseClient() } else { @@ -113,10 +119,10 @@ func TestShipp(t *testing.T) { }() time.Sleep(time.Second) - - // Wait for follower to get 100 entries, then close connection. f := NewFollower(fWALPath) defer f.Close() + + // Wait for follower to get 100 entries, then close connection. f.waitForSeqNum(uint64(N)) if err := walsEqual(wWALPath, fWALPath); err != nil { diff --git a/wal/writer-background.go b/wal/writer-background.go new file mode 100644 index 0000000..0165767 --- /dev/null +++ b/wal/writer-background.go @@ -0,0 +1,111 @@ +package wal + +import ( + "database/sql" + "sync" + "time" +) + +func (w *Writer) start() { + w.lock.Lock() + defer w.lock.Unlock() + + if w.running { + return + } + + w.insertQ = make(chan insertJob, 1024) + + var maxSeqNum uint64 + row := w.db.QueryRow(sqlWALMaxSeqNum) + must(row.Scan(&maxSeqNum)) + + w.doneWG.Add(1) + go w.insertProc(maxSeqNum) + w.running = true +} + +func (w *Writer) stop() { + w.lock.Lock() + defer w.lock.Unlock() + + if !w.running { + return + } + + close(w.insertQ) + w.doneWG.Wait() + w.running = false +} + +func (w *Writer) insertProc(maxSeqNum uint64) { + defer w.doneWG.Done() + + var ( + job insertJob + tx *sql.Tx + insert *sql.Stmt + ok bool + err error + newSeqNum uint64 + now int64 + wgs = make([]*sync.WaitGroup, 10) + ) + + var () + +BEGIN: + + newSeqNum = maxSeqNum + wgs = wgs[:0] + + job, ok = <-w.insertQ + if !ok { + return + } + + tx, err = w.db.Begin() + must(err) + + insert, err = tx.Prepare(sqlWALInsert) + must(err) + + now = time.Now().Unix() + +LOOP: + + newSeqNum++ + _, err = insert.Exec( + newSeqNum, + now, + job.Collection, + job.ID, + job.Store, + job.Data) + + must(err) + if job.Ready != nil { + wgs = append(wgs, job.Ready) + } + + select { + case job, ok = <-w.insertQ: + if ok { + goto LOOP + } + default: + } + + goto COMMIT + +COMMIT: + + must(tx.Commit()) + + maxSeqNum = newSeqNum + for i := range wgs { + wgs[i].Done() + } + + goto BEGIN +} diff --git a/wal/writer.go b/wal/writer.go index 4462cc2..0c51361 100644 --- a/wal/writer.go +++ b/wal/writer.go @@ -2,6 +2,9 @@ package wal import ( "database/sql" + "encoding/binary" + "log" + "net" "sync" "time" @@ -17,29 +20,33 @@ type insertJob struct { } type Writer struct { + primary bool db *sql.DB insert *sql.Stmt + + lock sync.Mutex + running bool insertQ chan insertJob doneWG sync.WaitGroup + + recvLock sync.Mutex } -func NewWriter(walPath string) *Writer { +func NewWriter(walPath string, primary bool) *Writer { db := initWAL(walPath) insert, err := db.Prepare(sqlWALInsert) must(err) w := &Writer{ + primary: primary, db: db, insert: insert, - insertQ: make(chan insertJob, 1024), } - var maxSeqNum uint64 - row := db.QueryRow(sqlWALMaxSeqNum) - must(row.Scan(&maxSeqNum)) + if primary { + w.start() + } - w.doneWG.Add(1) - go w.insertProc(maxSeqNum) return w } @@ -48,13 +55,16 @@ func (w *Writer) Close() { return } - close(w.insertQ) - w.doneWG.Wait() + w.stop() w.db.Close() w.db = nil } func (w *Writer) Store(collection string, id uint64, data []byte) { + if !w.primary { + //panic("Store called on secondary.") + } + job := insertJob{ Collection: collection, ID: id, @@ -68,6 +78,10 @@ func (w *Writer) Store(collection string, id uint64, data []byte) { } func (w *Writer) Delete(collection string, id uint64) { + if !w.primary { + //panic("Delete called on secondary.") + } + job := insertJob{ Collection: collection, ID: id, @@ -103,74 +117,81 @@ func (w *Writer) MaxSeqNum() (n uint64) { return } -func (w *Writer) insertProc(maxSeqNum uint64) { - defer w.doneWG.Done() +func (w *Writer) RecvWAL(conn net.Conn) { + defer conn.Close() - var ( - job insertJob - tx *sql.Tx - insert *sql.Stmt - ok bool - err error - newSeqNum uint64 - now int64 - wgs = make([]*sync.WaitGroup, 10) - ) + if w.primary { + //panic("RecvWAL called on primary.") + } - var () - -BEGIN: - - newSeqNum = maxSeqNum - wgs = wgs[:0] - - job, ok = <-w.insertQ - if !ok { + if !w.recvLock.TryLock() { + log.Printf("Multiple calls to RecvWAL. Dropping connection.") return } + defer w.recvLock.Unlock() - tx, err = w.db.Begin() - must(err) + headerBuf := make([]byte, recHeaderSize) + buf := make([]byte, 8) - insert, err = tx.Prepare(sqlWALInsert) - must(err) + afterSeqNum := w.MaxSeqNum() + expectedSeqNum := afterSeqNum + 1 - now = time.Now().Unix() - -LOOP: - - newSeqNum++ - _, err = insert.Exec( - newSeqNum, - now, - job.Collection, - job.ID, - job.Store, - job.Data) - - must(err) - if job.Ready != nil { - wgs = append(wgs, job.Ready) + // Send fromID to the conn. + conn.SetWriteDeadline(time.Now().Add(time.Minute)) + binary.LittleEndian.PutUint64(buf, afterSeqNum) + if _, err := conn.Write(buf); err != nil { + log.Printf("RecvWAL failed to send after sequence number: %v", err) + return } + conn.SetWriteDeadline(time.Time{}) - select { - case job, ok = <-w.insertQ: - if ok { - goto LOOP + // Start processing inserts. + w.start() + defer w.stop() + + for { + conn.SetReadDeadline(time.Now().Add(time.Minute)) + if _, err := conn.Read(headerBuf); err != nil { + log.Printf("RecvWAL failed to read header: %v", err) + return + } + rec, colLen, dataLen := decodeRecHeader(headerBuf) + + // Heartbeat. + if rec.SeqNum == 0 { + continue + } + + if rec.SeqNum != expectedSeqNum { + log.Printf("Expected sequence number %d but got %d.", + expectedSeqNum, rec.SeqNum) + return + } + expectedSeqNum++ + + if cap(buf) < colLen { + buf = make([]byte, colLen) + } + buf = buf[:colLen] + + if _, err := conn.Read(buf); err != nil { + log.Printf("RecvWAL failed to collection name: %v", err) + return + } + + rec.Collection = string(buf) + + if rec.Store { + rec.Data = make([]byte, dataLen) + if _, err := conn.Read(rec.Data); err != nil { + log.Printf("RecvWAL failed to data: %v", err) + return + } + } + if rec.Store { + w.storeAsync(rec.Collection, rec.ID, rec.Data) + } else { + w.deleteAsync(rec.Collection, rec.ID) } - default: } - - goto COMMIT - -COMMIT: - - must(tx.Commit()) - - maxSeqNum = newSeqNum - for i := range wgs { - wgs[i].Done() - } - - goto BEGIN } diff --git a/wal/writer_test.go b/wal/writer_test.go index c4e74e3..8d9d2b7 100644 --- a/wal/writer_test.go +++ b/wal/writer_test.go @@ -29,7 +29,7 @@ func TestWriter(t *testing.T) { t.Run(name, func(t *testing.T) { walPath := randPath() + ".wal" defer os.RemoveAll(walPath) - w := NewWriter(walPath) + w := NewWriter(walPath, true) defer w.Close() inner(t, walPath, w) }) @@ -57,15 +57,15 @@ func TestWriter(t *testing.T) { w.Store("a", 1, _b("Hello")) w.Close() - w = NewWriter(walPath) + w = NewWriter(walPath, true) w.Delete("b", 1) w.Close() - w = NewWriter(walPath) + w = NewWriter(walPath, true) w.Store("a", 2, _b("World")) w.Close() - w = NewWriter(walPath) + w = NewWriter(walPath, true) w.Store("a", 1, _b("Good bye")) err := walEqual(walPath, []Record{ @@ -168,19 +168,19 @@ func TestWriter(t *testing.T) { run("store delete async with close", func(t *testing.T, walPath string, w *Writer) { w.storeAsync("a", 1, _b("hello1")) w.Close() - w = NewWriter(walPath) + w = NewWriter(walPath, true) w.storeAsync("a", 2, _b("hello2")) w.Close() - w = NewWriter(walPath) + w = NewWriter(walPath, true) w.deleteAsync("a", 1) w.Close() - w = NewWriter(walPath) + w = NewWriter(walPath, true) w.storeAsync("a", 3, _b("hello3")) w.Close() - w = NewWriter(walPath) + w = NewWriter(walPath, true) w.storeAsync("b", 1, _b("b1")) w.Close() - w = NewWriter(walPath) + w = NewWriter(walPath, true) w.waitForSeqNum(5)