Initial commit

jdl
2023-10-13 11:43:27 +02:00
commit 71eb6b0c7e
121 changed files with 11493 additions and 0 deletions

lib/wal/corrupt_test.go

@@ -0,0 +1,53 @@
package wal
import (
"io"
"git.crumpington.com/public/jldb/lib/errs"
"testing"
)
func TestCorruptWAL(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
wal, err := Create(tmpDir, 100, Config{
SegMinCount: 1024,
SegMaxAgeSec: 3600,
})
if err != nil {
t.Fatal(err)
}
defer wal.Close()
appendRandomRecords(t, wal, 100)
f := wal.seg.f
info, err := f.Stat()
if err != nil {
t.Fatal(err)
}
offset := info.Size() / 2
if _, err := f.WriteAt([]byte{1, 2, 3, 4, 5, 6, 7, 8}, offset); err != nil {
t.Fatal(err)
}
it, err := wal.Iterator(-1)
if err != nil {
t.Fatal(err)
}
defer it.Close()
for it.Next(0) {
rec := it.Record()
if _, err := io.ReadAll(rec.Reader); err != nil {
if errs.Corrupt.Is(err) {
return
}
t.Fatal(err)
}
}
if !errs.Corrupt.Is(it.Error()) {
t.Fatal(it.Error())
}
}

lib/wal/design.go

@@ -0,0 +1,28 @@
package wal
import (
"time"
)
type Info struct {
FirstSeqNum int64
LastSeqNum int64
LastTimestampMS int64
}
type Iterator interface {
// Next will return false if no record is available during the timeout
// period, or if an error is encountered. After Next returns false, the
// caller should check the return value of the Error function.
Next(timeout time.Duration) bool
// Call Record after Next returns true to get the next record.
Record() Record
// The caller must call Close on the iterator so clean-up can be performed.
Close()
// Call Error to see if there was an error during the previous call to Next
// if Next returned false.
Error() error
}
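// A typical consumer loop, shown here only as an illustrative sketch (the
// variable w and the error handling are placeholders, not part of this
// package), looks like:
//
//	it, err := w.Iterator(-1)
//	if err != nil {
//		return err
//	}
//	defer it.Close()
//	for it.Next(time.Second) {
//		rec := it.Record()
//		// consume rec.Reader here
//	}
//	return it.Error()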

lib/wal/gc_test.go

@@ -0,0 +1,94 @@
package wal
import (
"math/rand"
"sync"
"testing"
"time"
)
func TestDeleteBefore(t *testing.T) {
t.Parallel()
firstSeqNum := rand.Int63n(9288389)
tmpDir := t.TempDir()
wal, err := Create(tmpDir, firstSeqNum, Config{
SegMinCount: 10,
SegMaxAgeSec: 1,
})
if err != nil {
t.Fatal(err)
}
defer wal.Close()
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
err := writeRandomWithEOF(wal, 8*time.Second)
if err != nil {
panic(err)
}
}()
wg.Wait()
info := wal.Info()
if info.FirstSeqNum != firstSeqNum {
t.Fatal(info)
}
lastSeqNum := info.LastSeqNum
lastTimestampMS := info.LastTimestampMS
err = wal.DeleteBefore((info.LastTimestampMS/1000)-4, lastSeqNum+100)
if err != nil {
t.Fatal(err)
}
info = wal.Info()
if info.FirstSeqNum == firstSeqNum || info.LastSeqNum != lastSeqNum || info.LastTimestampMS != lastTimestampMS {
t.Fatal(info)
}
header := wal.header
if header.FirstSegmentID >= header.LastSegmentID {
t.Fatal(header)
}
}
func TestDeleteBeforeOnlyOneSegment(t *testing.T) {
t.Parallel()
firstSeqNum := rand.Int63n(9288389)
tmpDir := t.TempDir()
wal, err := Create(tmpDir, firstSeqNum, Config{
SegMinCount: 10,
SegMaxAgeSec: 10,
})
if err != nil {
t.Fatal(err)
}
defer wal.Close()
if err := writeRandomWithEOF(wal, time.Second); err != nil {
t.Fatal(err)
}
header := wal.header
if header.FirstSegmentID != header.LastSegmentID {
t.Fatal(header)
}
lastSeqNum := wal.Info().LastSeqNum
err = wal.DeleteBefore(time.Now().Unix()+1, lastSeqNum+100)
if err != nil {
t.Fatal(err)
}
header = wal.header
if header.FirstSegmentID != header.LastSegmentID {
t.Fatal(header)
}
}

lib/wal/generic_test.go

@@ -0,0 +1,391 @@
package wal
import (
"bytes"
"encoding/binary"
"errors"
"io"
"git.crumpington.com/public/jldb/lib/errs"
"math/rand"
"path/filepath"
"reflect"
"strings"
"testing"
"time"
)
type waLog interface {
Append(int64, io.Reader) (int64, int64, error)
appendRecord(Record) (int64, int64, error)
Iterator(int64) (Iterator, error)
Close() error
}
func TestGenericTestHarness_segment(t *testing.T) {
t.Parallel()
(&GenericTestHarness{
New: func(tmpDir string, firstSeqNum int64) (waLog, error) {
l, err := createSegment(filepath.Join(tmpDir, "x"), 1, firstSeqNum, 12345)
return l, err
},
}).Run(t)
}
func TestGenericTestHarness_wal(t *testing.T) {
t.Parallel()
(&GenericTestHarness{
New: func(tmpDir string, firstSeqNum int64) (waLog, error) {
l, err := Create(tmpDir, firstSeqNum, Config{
SegMinCount: 1,
SegMaxAgeSec: 1,
})
return l, err
},
}).Run(t)
}
// ----------------------------------------------------------------------------
type GenericTestHarness struct {
New func(tmpDir string, firstSeqNum int64) (waLog, error)
}
func (h *GenericTestHarness) Run(t *testing.T) {
val := reflect.ValueOf(h)
typ := val.Type()
for i := 0; i < typ.NumMethod(); i++ {
method := typ.Method(i)
if !strings.HasPrefix(method.Name, "Test") {
continue
}
t.Run(method.Name, func(t *testing.T) {
t.Parallel()
firstSeqNum := rand.Int63n(23423)
tmpDir := t.TempDir()
wal, err := h.New(tmpDir, firstSeqNum)
if err != nil {
t.Fatal(err)
}
defer wal.Close()
val.MethodByName(method.Name).Call([]reflect.Value{
reflect.ValueOf(t),
reflect.ValueOf(firstSeqNum),
reflect.ValueOf(wal),
})
})
}
}
// ----------------------------------------------------------------------------
func (h *GenericTestHarness) TestBasic(t *testing.T, firstSeqNum int64, wal waLog) {
expected := appendRandomRecords(t, wal, 123)
for i := 0; i < 123; i++ {
it, err := wal.Iterator(firstSeqNum + int64(i))
if err != nil {
t.Fatal(err)
}
checkIteratorMatches(t, it, expected[i:])
it.Close()
}
}
func (h *GenericTestHarness) TestAppendNotFound(t *testing.T, firstSeqNum int64, wal waLog) {
recs := appendRandomRecords(t, wal, 123)
lastSeqNum := recs[len(recs)-1].SeqNum
it, err := wal.Iterator(firstSeqNum)
if err != nil {
t.Fatal(err)
}
it.Close()
it, err = wal.Iterator(lastSeqNum + 1)
if err != nil {
t.Fatal(err)
}
it.Close()
if _, err = wal.Iterator(firstSeqNum - 1); !errs.NotFound.Is(err) {
t.Fatal(err)
}
if _, err = wal.Iterator(lastSeqNum + 2); !errs.NotFound.Is(err) {
t.Fatal(err)
}
}
func (h *GenericTestHarness) TestNextAfterClose(t *testing.T, firstSeqNum int64, wal waLog) {
appendRandomRecords(t, wal, 123)
it, err := wal.Iterator(firstSeqNum)
if err != nil {
t.Fatal(err)
}
defer it.Close()
if !it.Next(0) {
t.Fatal("Should be next")
}
if err := wal.Close(); err != nil {
t.Fatal(err)
}
if it.Next(0) {
t.Fatal("Shouldn't be next")
}
if !errs.Closed.Is(it.Error()) {
t.Fatal(it.Error())
}
}
func (h *GenericTestHarness) TestNextTimeout(t *testing.T, firstSeqNum int64, wal waLog) {
recs := appendRandomRecords(t, wal, 123)
it, err := wal.Iterator(firstSeqNum)
if err != nil {
t.Fatal(err)
}
defer it.Close()
for range recs {
if !it.Next(0) {
t.Fatal("Expected next")
}
}
if it.Next(time.Millisecond) {
t.Fatal("Unexpected next")
}
}
func (h *GenericTestHarness) TestNextNotify(t *testing.T, firstSeqNum int64, wal waLog) {
it, err := wal.Iterator(firstSeqNum)
if err != nil {
t.Fatal(err)
}
defer it.Close()
recsC := make(chan []RawRecord, 1)
go func() {
time.Sleep(time.Second)
recsC <- appendRandomRecords(t, wal, 1)
}()
if !it.Next(time.Hour) {
t.Fatal("expected next")
}
recs := <-recsC
rec := it.Record()
if rec.SeqNum != recs[0].SeqNum {
t.Fatal(rec)
}
}
func (h *GenericTestHarness) TestNextArchived(t *testing.T, firstSeqNum int64, wal waLog) {
type archiver interface {
Archive() error
}
arch, ok := wal.(archiver)
if !ok {
return
}
recs := appendRandomRecords(t, wal, 10)
it, err := wal.Iterator(firstSeqNum)
if err != nil {
t.Fatal(err)
}
defer it.Close()
if err := arch.Archive(); err != nil {
t.Fatal(err)
}
for i, expected := range recs {
if !it.Next(time.Millisecond) {
t.Fatal(i, "no next")
}
rec := it.Record()
if rec.SeqNum != expected.SeqNum {
t.Fatal(rec, expected)
}
}
if it.Next(time.Minute) {
t.Fatal("unexpected next")
}
if !errs.EOFArchived.Is(it.Error()) {
t.Fatal(it.Error())
}
}
func (h *GenericTestHarness) TestWriteReadConcurrent(t *testing.T, firstSeqNum int64, wal waLog) {
N := 1200
writeErr := make(chan error, 1)
dataSize := int64(4)
makeData := func(i int) []byte {
data := make([]byte, 4)
binary.LittleEndian.PutUint32(data, uint32(i))
return data
}
go func() {
for i := 0; i < N; i++ {
seqNum, _, err := wal.Append(dataSize, bytes.NewBuffer(makeData(i)))
if err != nil {
writeErr <- err
return
}
if seqNum != int64(i)+firstSeqNum {
writeErr <- errors.New("Incorrect seq num")
return
}
time.Sleep(time.Millisecond)
}
writeErr <- nil
}()
it, err := wal.Iterator(firstSeqNum)
if err != nil {
t.Fatal(err)
}
defer it.Close()
for i := 0; i < N; i++ {
if !it.Next(time.Minute) {
t.Fatal("expected next", i, it.Error(), it.Record())
}
expectedData := makeData(i)
rec := it.Record()
data, err := io.ReadAll(rec.Reader)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data, expectedData) {
t.Fatal(data, expectedData)
}
}
if err := <-writeErr; err != nil {
t.Fatal(err)
}
}
func (h *GenericTestHarness) TestAppendAfterClose(t *testing.T, firstSeqNum int64, wal waLog) {
if _, _, err := wal.Append(4, bytes.NewBuffer([]byte{1, 2, 3, 4})); err != nil {
t.Fatal(err)
}
wal.Close()
_, _, err := wal.Append(4, bytes.NewBuffer([]byte{1, 2, 3, 4}))
if !errs.Closed.Is(err) {
t.Fatal(err)
}
}
func (h *GenericTestHarness) TestIterateNegativeOne(t *testing.T, firstSeqNum int64, wal waLog) {
recs := appendRandomRecords(t, wal, 10)
it1, err := wal.Iterator(firstSeqNum)
if err != nil {
t.Fatal(err)
}
defer it1.Close()
it2, err := wal.Iterator(-1)
if err != nil {
t.Fatal(err)
}
defer it2.Close()
if !it1.Next(0) {
t.Fatal(0)
}
if !it2.Next(0) {
t.Fatal(0)
}
r1 := it1.Record()
r2 := it2.Record()
if r1.SeqNum != r2.SeqNum || r1.SeqNum != firstSeqNum || r1.SeqNum != recs[0].SeqNum {
t.Fatal(r1.SeqNum, r2.SeqNum, firstSeqNum, recs[0].SeqNum)
}
}
func (h *GenericTestHarness) TestIteratorAfterClose(t *testing.T, firstSeqNum int64, wal waLog) {
appendRandomRecords(t, wal, 10)
wal.Close()
if _, err := wal.Iterator(-1); !errs.Closed.Is(err) {
t.Fatal(err)
}
}
func (h *GenericTestHarness) TestIteratorNextWithError(t *testing.T, firstSeqNum int64, wal waLog) {
appendRandomRecords(t, wal, 10)
it, err := wal.Iterator(-1)
if err != nil {
t.Fatal(err)
}
wal.Close()
it.Next(0)
if !errs.Closed.Is(it.Error()) {
t.Fatal(it.Error())
}
it.Next(0)
if !errs.Closed.Is(it.Error()) {
t.Fatal(it.Error())
}
}
func (h *GenericTestHarness) TestIteratorConcurrentClose(t *testing.T, firstSeqNum int64, wal waLog) {
it, err := wal.Iterator(-1)
if err != nil {
t.Fatal(err)
}
go func() {
writeRandomWithEOF(wal, 3*time.Second)
wal.Close()
}()
for it.Next(time.Hour) {
// Skip.
}
// Error may be Closed or NotFound.
if !errs.Closed.Is(it.Error()) && !errs.NotFound.Is(it.Error()) {
t.Fatal(it.Error())
}
}

lib/wal/io.go

@@ -0,0 +1,125 @@
package wal
import (
"encoding/binary"
"errors"
"hash/crc32"
"io"
"git.crumpington.com/public/jldb/lib/errs"
)
func ioErrOrEOF(err error) error {
if err == nil {
return nil
}
if errors.Is(err, io.EOF) {
return err
}
return errs.IO.WithErr(err)
}
// ----------------------------------------------------------------------------
type readAtReader struct {
f io.ReaderAt
offset int64
}
func readerAtToReader(f io.ReaderAt, offset int64) io.Reader {
return &readAtReader{f: f, offset: offset}
}
func (r *readAtReader) Read(b []byte) (int, error) {
n, err := r.f.ReadAt(b, r.offset)
r.offset += int64(n)
return n, ioErrOrEOF(err)
}
// ----------------------------------------------------------------------------
type writeAtWriter struct {
w io.WriterAt
offset int64
}
func writerAtToWriter(w io.WriterAt, offset int64) io.Writer {
return &writeAtWriter{w: w, offset: offset}
}
func (w *writeAtWriter) Write(b []byte) (int, error) {
n, err := w.w.WriteAt(b, w.offset)
w.offset += int64(n)
return n, ioErrOrEOF(err)
}
// ----------------------------------------------------------------------------
type crcWriter struct {
w io.Writer
crc uint32
}
func newCRCWriter(w io.Writer) *crcWriter {
return &crcWriter{w: w}
}
func (w *crcWriter) Write(b []byte) (int, error) {
n, err := w.w.Write(b)
w.crc = crc32.Update(w.crc, crc32.IEEETable, b[:n])
return n, ioErrOrEOF(err)
}
func (w *crcWriter) CRC() uint32 {
return w.crc
}
// ----------------------------------------------------------------------------
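// dataReader wraps the data portion of a record: it yields exactly dataSize
// bytes from r while computing a running CRC32, and once the data is
// exhausted it reads the trailing 4-byte checksum and returns errs.Corrupt
// if it does not match.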
type dataReader struct {
r io.Reader
remaining int64
crc uint32
}
func newDataReader(r io.Reader, dataSize int64) *dataReader {
return &dataReader{r: r, remaining: dataSize}
}
func (r *dataReader) Read(b []byte) (int, error) {
if r.remaining == 0 {
return 0, io.EOF
}
if int64(len(b)) > r.remaining {
b = b[:r.remaining]
}
n, err := r.r.Read(b)
r.crc = crc32.Update(r.crc, crc32.IEEETable, b[:n])
r.remaining -= int64(n)
if r.remaining == 0 {
if err := r.checkCRC(); err != nil {
return n, err
}
}
if err != nil && !errors.Is(err, io.EOF) {
return n, errs.IO.WithErr(err)
}
return n, nil
}
func (r *dataReader) checkCRC() error {
buf := make([]byte, 4)
if _, err := io.ReadFull(r.r, buf); err != nil {
return errs.Corrupt.WithErr(err)
}
crc := binary.LittleEndian.Uint32(buf)
if crc != r.crc {
return errs.Corrupt.WithMsg("crc mismatch")
}
return nil
}

lib/wal/notify.go

@@ -0,0 +1,79 @@
package wal
import "sync"
type segmentState struct {
Closed bool
Archived bool
FirstSeqNum int64
LastSeqNum int64
}
func newSegmentState(closed bool, header segmentHeader) segmentState {
return segmentState{
Closed: closed,
Archived: header.ArchivedAt != 0,
FirstSeqNum: header.FirstSeqNum,
LastSeqNum: header.LastSeqNum,
}
}
type notifyMux struct {
lock sync.Mutex
nextID int64
recvrs map[int64]chan segmentState
}
type stateRecvr struct {
// Each recvr will always get the most recent segmentState on C.
// When the segment is closed, a state with Closed set to true is sent.
C chan segmentState
Close func()
}
func newNotifyMux() *notifyMux {
return &notifyMux{
recvrs: map[int64]chan segmentState{},
}
}
func (m *notifyMux) NewRecvr(header segmentHeader) stateRecvr {
state := newSegmentState(false, header)
m.lock.Lock()
defer m.lock.Unlock()
m.nextID++
recvrID := m.nextID
recvr := stateRecvr{
C: make(chan segmentState, 1),
Close: func() {
m.lock.Lock()
defer m.lock.Unlock()
delete(m.recvrs, recvrID)
},
}
recvr.C <- state
m.recvrs[recvrID] = recvr.C
return recvr
}
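// Notify delivers the latest segment state to every registered receiver
// without blocking: if a receiver's 1-buffered channel still holds an
// unread state, that stale value is dropped and replaced so the channel
// always carries the most recent state.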
func (m *notifyMux) Notify(closed bool, header segmentHeader) {
state := newSegmentState(closed, header)
m.lock.Lock()
defer m.lock.Unlock()
for _, c := range m.recvrs {
select {
case c <- state:
case <-c:
c <- state
}
}
}

lib/wal/record.go

@@ -0,0 +1,90 @@
package wal
import (
"encoding/binary"
"hash/crc32"
"io"
"git.crumpington.com/public/jldb/lib/errs"
)
const recordHeaderSize = 28
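// A serialized record, as written by writeTo below, is laid out as:
//
//	[0:8)    SeqNum       little-endian int64
//	[8:16)   TimestampMS  little-endian int64
//	[16:24)  DataSize     little-endian int64
//	[24:28)  CRC32 (IEEE) of the 24 header bytes above
//	[28:28+DataSize)  record data
//	[...+4)           CRC32 (IEEE) of the record data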
type Record struct {
SeqNum int64
TimestampMS int64
DataSize int64
Reader io.Reader
}
func (rec Record) writeHeaderTo(w io.Writer) (int, error) {
buf := make([]byte, recordHeaderSize)
binary.LittleEndian.PutUint64(buf[0:], uint64(rec.SeqNum))
binary.LittleEndian.PutUint64(buf[8:], uint64(rec.TimestampMS))
binary.LittleEndian.PutUint64(buf[16:], uint64(rec.DataSize))
crc := crc32.ChecksumIEEE(buf[:recordHeaderSize-4])
binary.LittleEndian.PutUint32(buf[24:], crc)
n, err := w.Write(buf)
if err != nil {
err = errs.IO.WithErr(err)
}
return n, err
}
func (rec *Record) readHeaderFrom(r io.Reader) error {
buf := make([]byte, recordHeaderSize)
if _, err := io.ReadFull(r, buf); err != nil {
return errs.IO.WithErr(err)
}
crc := crc32.ChecksumIEEE(buf[:recordHeaderSize-4])
stored := binary.LittleEndian.Uint32(buf[recordHeaderSize-4:])
if crc != stored {
return errs.Corrupt.WithMsg("checksum mismatch")
}
rec.SeqNum = int64(binary.LittleEndian.Uint64(buf[0:]))
rec.TimestampMS = int64(binary.LittleEndian.Uint64(buf[8:]))
rec.DataSize = int64(binary.LittleEndian.Uint64(buf[16:]))
return nil
}
func (rec Record) serializedSize() int64 {
return recordHeaderSize + rec.DataSize + 4 // 4 for data CRC32.
}
func (rec Record) writeTo(w io.Writer) (int64, error) {
nn, err := rec.writeHeaderTo(w)
if err != nil {
return int64(nn), err
}
n := int64(nn)
// Write the data.
crcW := newCRCWriter(w)
n2, err := io.CopyN(crcW, rec.Reader, rec.DataSize)
n += n2
if err != nil {
return n, errs.IO.WithErr(err)
}
// Write the data crc value.
err = binary.Write(w, binary.LittleEndian, crcW.CRC())
if err != nil {
return n, errs.IO.WithErr(err)
}
n += 4
return n, nil
}
func (rec *Record) readFrom(r io.Reader) error {
if err := rec.readHeaderFrom(r); err != nil {
return err
}
rec.Reader = newDataReader(r, rec.DataSize)
return nil
}

lib/wal/record_test.go

@@ -0,0 +1,171 @@
package wal
import (
"bytes"
"io"
"git.crumpington.com/public/jldb/lib/errs"
"git.crumpington.com/public/jldb/lib/testutil"
"math/rand"
"testing"
)
func NewRecordForTesting() Record {
data := randData()
return Record{
SeqNum: rand.Int63(),
TimestampMS: rand.Int63(),
DataSize: int64(len(data)),
Reader: bytes.NewBuffer(data),
}
}
func AssertRecordHeadersEqual(t *testing.T, r1, r2 Record) {
t.Helper()
eq := r1.SeqNum == r2.SeqNum &&
r1.TimestampMS == r2.TimestampMS &&
r1.DataSize == r2.DataSize
if !eq {
t.Fatal(r1, r2)
}
}
func TestRecordWriteHeaderToReadHeaderFrom(t *testing.T) {
t.Parallel()
rec1 := NewRecordForTesting()
b := &bytes.Buffer{}
n, err := rec1.writeHeaderTo(b)
if err != nil {
t.Fatal(err)
}
if n != recordHeaderSize {
t.Fatal(n)
}
rec2 := Record{}
if err := rec2.readHeaderFrom(b); err != nil {
t.Fatal(err)
}
AssertRecordHeadersEqual(t, rec1, rec2)
}
func TestRecordWriteHeaderToEOF(t *testing.T) {
t.Parallel()
rec := NewRecordForTesting()
for limit := 1; limit < recordHeaderSize; limit++ {
buf := &bytes.Buffer{}
w := testutil.NewLimitWriter(buf, limit)
n, err := rec.writeHeaderTo(w)
if !errs.IO.Is(err) {
t.Fatal(limit, n, err)
}
}
}
func TestRecordReadHeaderFromError(t *testing.T) {
t.Parallel()
rec := NewRecordForTesting()
for limit := 1; limit < recordHeaderSize; limit++ {
b := &bytes.Buffer{}
if _, err := rec.writeHeaderTo(b); err != nil {
t.Fatal(err)
}
r := io.LimitReader(b, int64(limit))
if err := rec.readFrom(r); !errs.IO.Is(err) {
t.Fatal(err)
}
}
}
func TestRecordReadHeaderFromCorrupt(t *testing.T) {
t.Parallel()
rec := NewRecordForTesting()
b := &bytes.Buffer{}
for i := 0; i < recordHeaderSize; i++ {
if _, err := rec.writeHeaderTo(b); err != nil {
t.Fatal(err)
}
b.Bytes()[i]++
if err := rec.readHeaderFrom(b); !errs.Corrupt.Is(err) {
t.Fatal(err)
}
}
}
func TestRecordWriteToReadFrom(t *testing.T) {
t.Parallel()
r1 := NewRecordForTesting()
data := randData()
r1.Reader = bytes.NewBuffer(bytes.Clone(data))
r1.DataSize = int64(len(data))
r2 := Record{}
b := &bytes.Buffer{}
if _, err := r1.writeTo(b); err != nil {
t.Fatal(err)
}
if err := r2.readFrom(b); err != nil {
t.Fatal(err)
}
AssertRecordHeadersEqual(t, r1, r2)
data2, err := io.ReadAll(r2.Reader)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data, data2) {
t.Fatal(data, data2)
}
}
func TestRecordReadFromCorrupt(t *testing.T) {
t.Parallel()
data := randData()
r1 := NewRecordForTesting()
for i := 0; i < int(r1.serializedSize()); i++ {
r1.Reader = bytes.NewBuffer(data)
r1.DataSize = int64(len(data))
buf := &bytes.Buffer{}
r1.writeTo(buf)
buf.Bytes()[i]++
r2 := Record{}
if err := r2.readFrom(buf); err != nil {
if !errs.Corrupt.Is(err) {
t.Fatal(i, err)
}
continue // OK.
}
if _, err := io.ReadAll(r2.Reader); !errs.Corrupt.Is(err) {
t.Fatal(err)
}
}
}
func TestRecordWriteToError(t *testing.T) {
t.Parallel()
data := randData()
r1 := NewRecordForTesting()
r1.Reader = bytes.NewBuffer(data)
r1.DataSize = int64(len(data))
for i := 0; i < int(r1.serializedSize()); i++ {
w := testutil.NewLimitWriter(&bytes.Buffer{}, i)
r1.Reader = bytes.NewBuffer(data)
if _, err := r1.writeTo(w); !errs.IO.Is(err) {
t.Fatal(err)
}
}
}

lib/wal/segment-header.go

@@ -0,0 +1,44 @@
package wal
import "encoding/binary"
type segmentHeader struct {
CreatedAt int64
ArchivedAt int64
FirstSeqNum int64
LastSeqNum int64 // FirstSeqNum - 1 if empty.
LastTimestampMS int64 // 0 if empty.
InsertAt int64
}
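// WriteTo and ReadFrom serialize the header as six consecutive
// little-endian int64 values (48 bytes), in the field order listed above.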
func (h segmentHeader) WriteTo(b []byte) {
vals := []int64{
h.CreatedAt,
h.ArchivedAt,
h.FirstSeqNum,
h.LastSeqNum,
h.LastTimestampMS,
h.InsertAt,
}
for _, val := range vals {
binary.LittleEndian.PutUint64(b[0:8], uint64(val))
b = b[8:]
}
}
func (h *segmentHeader) ReadFrom(b []byte) {
ptrs := []*int64{
&h.CreatedAt,
&h.ArchivedAt,
&h.FirstSeqNum,
&h.LastSeqNum,
&h.LastTimestampMS,
&h.InsertAt,
}
for _, ptr := range ptrs {
*ptr = int64(binary.LittleEndian.Uint64(b[0:8]))
b = b[8:]
}
}

lib/wal/segment-iterator.go

@@ -0,0 +1,165 @@
package wal
import (
"git.crumpington.com/public/jldb/lib/atomicheader"
"git.crumpington.com/public/jldb/lib/errs"
"os"
"time"
)
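// segmentIterator reads records from its own read-only file handle and
// learns about newly appended records, archival, and segment close through
// the segmentState values delivered on recvr.C.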
type segmentIterator struct {
f *os.File
recvr stateRecvr
state segmentState
offset int64
err error
rec Record
ticker *time.Ticker // Ticker if timeout has been set.
tickerC <-chan time.Time // Ticker channel if timeout has been set.
}
func newSegmentIterator(
f *os.File,
fromSeqNum int64,
recvr stateRecvr,
) (
Iterator,
error,
) {
it := &segmentIterator{
f: f,
recvr: recvr,
state: <-recvr.C,
}
if err := it.seekToSeqNum(fromSeqNum); err != nil {
it.Close()
return nil, err
}
it.rec.SeqNum = fromSeqNum - 1
it.ticker = time.NewTicker(time.Second)
it.tickerC = it.ticker.C
return it, nil
}
func (it *segmentIterator) seekToSeqNum(fromSeqNum int64) error {
state := it.state
// Is the requested sequence number out-of-range?
if fromSeqNum < state.FirstSeqNum || fromSeqNum > state.LastSeqNum+1 {
return errs.NotFound.WithMsg("sequence number not in segment")
}
// Seek to start.
it.offset = atomicheader.ReservedBytes
// Seek to first seq num - we're already there.
if fromSeqNum == it.state.FirstSeqNum {
return nil
}
for {
if err := it.readRecord(); err != nil {
return err
}
it.offset += it.rec.serializedSize()
if it.rec.SeqNum == fromSeqNum-1 {
return nil
}
}
}
func (it *segmentIterator) Close() {
it.f.Close()
it.recvr.Close()
}
// Next returns true if there's a record available to read via it.Record().
//
// If Next returns false, the caller should check the error value with
// it.Error().
func (it *segmentIterator) Next(timeout time.Duration) bool {
if it.err != nil {
return false
}
// Get new state if available.
select {
case it.state = <-it.recvr.C:
default:
}
if it.state.Closed {
it.err = errs.Closed
return false
}
if it.rec.SeqNum < it.state.LastSeqNum {
if it.err = it.readRecord(); it.err != nil {
return false
}
it.offset += it.rec.serializedSize()
return true
}
if it.state.Archived {
it.err = errs.EOFArchived
return false
}
if timeout <= 0 {
return false // Nothing to return.
}
// Wait for new record, or timeout.
it.ticker.Reset(timeout)
// Get new state if available.
select {
case it.state = <-it.recvr.C:
// OK
case <-it.tickerC:
return false // Timeout, no error.
}
if it.state.Closed {
it.err = errs.Closed
return false
}
if it.rec.SeqNum < it.state.LastSeqNum {
if it.err = it.readRecord(); it.err != nil {
return false
}
it.offset += it.rec.serializedSize()
return true
}
if it.state.Archived {
it.err = errs.EOFArchived
return false
}
return false
}
func (it *segmentIterator) Record() Record {
return it.rec
}
func (it *segmentIterator) Error() error {
return it.err
}
func (it *segmentIterator) readRecord() error {
return it.rec.readFrom(readerAtToReader(it.f, it.offset))
}

lib/wal/segment.go

@@ -0,0 +1,250 @@
package wal
import (
"bufio"
"io"
"git.crumpington.com/public/jldb/lib/atomicheader"
"git.crumpington.com/public/jldb/lib/errs"
"os"
"sync"
"time"
)
type segment struct {
ID int64
lock sync.Mutex
closed bool
header segmentHeader
headWriter *atomicheader.Handler
f *os.File
notifyMux *notifyMux
// For non-archived segments.
w *bufio.Writer
}
func createSegment(path string, id, firstSeqNum, timestampMS int64) (*segment, error) {
f, err := os.Create(path)
if err != nil {
return nil, errs.IO.WithErr(err)
}
defer f.Close()
if err := atomicheader.Init(f); err != nil {
return nil, err
}
handler, err := atomicheader.Open(f)
if err != nil {
return nil, err
}
header := segmentHeader{
CreatedAt: time.Now().Unix(),
FirstSeqNum: firstSeqNum,
LastSeqNum: firstSeqNum - 1,
LastTimestampMS: timestampMS,
InsertAt: atomicheader.ReservedBytes,
}
err = handler.Write(func(page []byte) error {
header.WriteTo(page)
return nil
})
if err != nil {
return nil, err
}
return openSegment(path, id)
}
func openSegment(path string, id int64) (*segment, error) {
f, err := os.OpenFile(path, os.O_RDWR, 0600)
if err != nil {
return nil, errs.IO.WithErr(err)
}
handler, err := atomicheader.Open(f)
if err != nil {
f.Close()
return nil, err
}
var header segmentHeader
err = handler.Read(func(page []byte) error {
header.ReadFrom(page)
return nil
})
if err != nil {
f.Close()
return nil, err
}
if _, err := f.Seek(header.InsertAt, io.SeekStart); err != nil {
f.Close()
return nil, errs.IO.WithErr(err)
}
seg := &segment{
ID: id,
header: header,
headWriter: handler,
f: f,
notifyMux: newNotifyMux(),
}
if header.ArchivedAt == 0 {
seg.w = bufio.NewWriterSize(f, 1024*1024)
}
return seg, nil
}
// Append appends the data from r to the log atomically. If an error is
// returned, the caller should check for errs.Fatal. If a fatal error occurs,
// the segment should no longer be used.
func (seg *segment) Append(dataSize int64, r io.Reader) (int64, int64, error) {
return seg.appendRecord(Record{
SeqNum: -1,
TimestampMS: time.Now().UnixMilli(),
DataSize: dataSize,
Reader: r,
})
}
func (seg *segment) Header() segmentHeader {
seg.lock.Lock()
defer seg.lock.Unlock()
return seg.header
}
// appendRecord appends a record in an atomic fashion. Do not use the segment
// after a fatal error.
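//
// The record is written at the current InsertAt offset and flushed before
// the segment header (LastSeqNum, LastTimestampMS, InsertAt) is rewritten
// through the atomicheader handler, so a crash mid-append leaves the old
// header pointing at the last fully written record.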
func (seg *segment) appendRecord(rec Record) (int64, int64, error) {
seg.lock.Lock()
defer seg.lock.Unlock()
header := seg.header // Copy.
if seg.closed {
return 0, 0, errs.Closed
}
if header.ArchivedAt != 0 {
return 0, 0, errs.Archived
}
if rec.SeqNum == -1 {
rec.SeqNum = header.LastSeqNum + 1
} else if rec.SeqNum != header.LastSeqNum+1 {
return 0, 0, errs.Unexpected.WithMsg(
"expected sequence number %d but got %d",
header.LastSeqNum+1,
rec.SeqNum)
}
seg.w.Reset(writerAtToWriter(seg.f, header.InsertAt))
n, err := rec.writeTo(seg.w)
if err != nil {
return 0, 0, err
}
if err := seg.w.Flush(); err != nil {
return 0, 0, ioErrOrEOF(err)
}
// Write new header to sync.
header.LastSeqNum = rec.SeqNum
header.LastTimestampMS = rec.TimestampMS
header.InsertAt += n
err = seg.headWriter.Write(func(page []byte) error {
header.WriteTo(page)
return nil
})
if err != nil {
return 0, 0, err
}
seg.header = header
seg.notifyMux.Notify(false, header)
return rec.SeqNum, rec.TimestampMS, nil
}
// ----------------------------------------------------------------------------
func (seg *segment) Archive() error {
seg.lock.Lock()
defer seg.lock.Unlock()
header := seg.header // Copy
if header.ArchivedAt != 0 {
return nil
}
header.ArchivedAt = time.Now().Unix()
err := seg.headWriter.Write(func(page []byte) error {
header.WriteTo(page)
return nil
})
if err != nil {
return err
}
seg.w = nil // We won't be writing any more.
seg.header = header
seg.notifyMux.Notify(false, header)
return nil
}
// ----------------------------------------------------------------------------
func (seg *segment) Iterator(fromSeqNum int64) (Iterator, error) {
seg.lock.Lock()
defer seg.lock.Unlock()
if seg.closed {
return nil, errs.Closed
}
f, err := os.Open(seg.f.Name())
if err != nil {
return nil, errs.IO.WithErr(err)
}
header := seg.header
if fromSeqNum == -1 {
fromSeqNum = header.FirstSeqNum
}
return newSegmentIterator(
f,
fromSeqNum,
seg.notifyMux.NewRecvr(header))
}
// ----------------------------------------------------------------------------
func (seg *segment) Close() error {
seg.lock.Lock()
defer seg.lock.Unlock()
if seg.closed {
return nil
}
seg.closed = true
header := seg.header
seg.notifyMux.Notify(true, header)
seg.f.Close()
return nil
}

lib/wal/segment_test.go

@@ -0,0 +1,145 @@
package wal
import (
"bytes"
crand "crypto/rand"
"io"
"git.crumpington.com/public/jldb/lib/atomicheader"
"git.crumpington.com/public/jldb/lib/errs"
"path/filepath"
"testing"
"time"
)
func newSegmentForTesting(t *testing.T) *segment {
tmpDir := t.TempDir()
seg, err := createSegment(filepath.Join(tmpDir, "x"), 1, 100, 200)
if err != nil {
t.Fatal(err)
}
return seg
}
func TestNewSegmentDirNotFound(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
p := filepath.Join(tmpDir, "notFound", "1245")
if _, err := createSegment(p, 1, 1234, 5678); !errs.IO.Is(err) {
t.Fatal(err)
}
}
func TestOpenSegmentNotFound(t *testing.T) {
t.Parallel()
tmpDir := t.TempDir()
p := filepath.Join(tmpDir, "notFound")
if _, err := openSegment(p, 1); !errs.IO.Is(err) {
t.Fatal(err)
}
}
func TestOpenSegmentTruncatedFile(t *testing.T) {
t.Parallel()
seg := newSegmentForTesting(t)
path := seg.f.Name()
if err := seg.f.Truncate(4); err != nil {
t.Fatal(err)
}
seg.Close()
if _, err := openSegment(path, 1); !errs.IO.Is(err) {
t.Fatal(err)
}
}
func TestOpenSegmentCorruptHeader(t *testing.T) {
t.Parallel()
seg := newSegmentForTesting(t)
path := seg.f.Name()
buf := make([]byte, atomicheader.ReservedBytes)
crand.Read(buf)
if _, err := seg.f.Seek(0, io.SeekStart); err != nil {
t.Fatal(err)
}
if _, err := seg.f.Write(buf); err != nil {
t.Fatal(err)
}
seg.Close()
if _, err := openSegment(path, 1); !errs.Corrupt.Is(err) {
t.Fatal(err)
}
}
func TestOpenSegmentCorruptHeader2(t *testing.T) {
t.Parallel()
seg := newSegmentForTesting(t)
path := seg.f.Name()
buf := make([]byte, 1024) // 2 pages.
crand.Read(buf)
if _, err := seg.f.Seek(1024, io.SeekStart); err != nil {
t.Fatal(err)
}
if _, err := seg.f.Write(buf); err != nil {
t.Fatal(err)
}
seg.Close()
if _, err := openSegment(path, 1); !errs.Corrupt.Is(err) {
t.Fatal(err)
}
}
func TestSegmentArchiveTwice(t *testing.T) {
t.Parallel()
seg := newSegmentForTesting(t)
for i := 0; i < 2; i++ {
if err := seg.Archive(); err != nil {
t.Fatal(err)
}
}
}
func TestSegmentAppendArchived(t *testing.T) {
t.Parallel()
seg := newSegmentForTesting(t)
appendRandomRecords(t, seg, 8)
if err := seg.Archive(); err != nil {
t.Fatal(err)
}
_, _, err := seg.Append(4, bytes.NewBuffer([]byte{1, 2, 3, 4}))
if !errs.Archived.Is(err) {
t.Fatal(err)
}
}
func TestSegmentAppendRecordInvalidSeqNum(t *testing.T) {
t.Parallel()
seg := newSegmentForTesting(t)
appendRandomRecords(t, seg, 8) // 108 is next.
_, _, err := seg.appendRecord(Record{
SeqNum: 110,
TimestampMS: time.Now().UnixMilli(),
DataSize: 100,
})
if !errs.Unexpected.Is(err) {
t.Fatal(err)
}
}

lib/wal/test-util_test.go

@@ -0,0 +1,232 @@
package wal
import (
"bytes"
crand "crypto/rand"
"encoding/base32"
"encoding/binary"
"hash/crc32"
"io"
"math/rand"
"os"
"reflect"
"testing"
"time"
)
// ----------------------------------------------------------------------------
func randString() string {
size := 8 + rand.Intn(92)
buf := make([]byte, size)
if _, err := crand.Read(buf); err != nil {
panic(err)
}
return base32.StdEncoding.EncodeToString(buf)
}
// ----------------------------------------------------------------------------
type RawRecord struct {
Record
Data []byte
DataCRC uint32
}
func (rr *RawRecord) ReadFrom(t *testing.T, f *os.File, offset int64) {
t.Helper()
buf := make([]byte, recordHeaderSize)
if _, err := f.ReadAt(buf, offset); err != nil {
t.Fatal(err)
}
if err := rr.Record.readHeaderFrom(readerAtToReader(f, offset)); err != nil {
t.Fatal(err)
}
rr.Data = make([]byte, rr.DataSize+4) // For data and CRC32.
if _, err := f.ReadAt(rr.Data, offset+recordHeaderSize); err != nil {
t.Fatal(err)
}
storedCRC := binary.LittleEndian.Uint32(rr.Data[rr.DataSize:])
computedCRC := crc32.ChecksumIEEE(rr.Data[:rr.DataSize])
if storedCRC != computedCRC {
t.Fatal(storedCRC, computedCRC)
}
rr.Data = rr.Data[:rr.DataSize]
}
// ----------------------------------------------------------------------------
func appendRandomRecords(t *testing.T, w waLog, count int64) []RawRecord {
t.Helper()
recs := make([]RawRecord, count)
for i := range recs {
rec := RawRecord{
Data: []byte(randString()),
}
rec.DataSize = int64(len(rec.Data))
seqNum, _, err := w.Append(int64(len(rec.Data)), bytes.NewBuffer(rec.Data))
if err != nil {
t.Fatal(err)
}
rec.SeqNum = seqNum
recs[i] = rec
}
// Check that sequence numbers are sequential.
seqNum := recs[0].SeqNum
for _, rec := range recs {
if rec.SeqNum != seqNum {
t.Fatal(seqNum, rec)
}
seqNum++
}
return recs
}
func checkIteratorMatches(t *testing.T, it Iterator, recs []RawRecord) {
for i, expected := range recs {
if !it.Next(time.Millisecond) {
t.Fatal(i, "no next")
}
rec := it.Record()
if rec.SeqNum != expected.SeqNum {
t.Fatal(i, rec.SeqNum, expected.SeqNum)
}
if rec.DataSize != expected.DataSize {
t.Fatal(i, rec.DataSize, expected.DataSize)
}
if rec.TimestampMS == 0 {
t.Fatal(rec.TimestampMS)
}
data := make([]byte, rec.DataSize)
if _, err := io.ReadFull(rec.Reader, data); err != nil {
t.Fatal(err)
}
if !bytes.Equal(data, expected.Data) {
t.Fatalf("%d %s != %s", i, data, expected.Data)
}
}
if it.Error() != nil {
t.Fatal(it.Error())
}
// Check that iterator is empty.
if it.Next(0) {
t.Fatal("extra", it.Record())
}
}
func randData() []byte {
data := make([]byte, 1+rand.Intn(128))
crand.Read(data)
return data
}
func writeRandomWithEOF(w waLog, dt time.Duration) error {
tStart := time.Now()
for time.Since(tStart) < dt {
data := randData()
_, _, err := w.Append(int64(len(data)), bytes.NewBuffer(data))
if err != nil {
return err
}
time.Sleep(time.Millisecond)
}
_, _, err := w.Append(3, bytes.NewBuffer([]byte("EOF")))
return err
}
func waitForEOF(t *testing.T, w *WAL) {
t.Helper()
h := w.seg.Header()
it, err := w.Iterator(h.FirstSeqNum)
if err != nil {
t.Fatal(err)
}
defer it.Close()
for it.Next(time.Hour) {
rec := it.Record()
buf := make([]byte, rec.DataSize)
if _, err := io.ReadFull(rec.Reader, buf); err != nil {
t.Fatal(err)
}
if bytes.Equal(buf, []byte("EOF")) {
return
}
}
t.Fatal("waitForEOF", it.Error())
}
func checkWALsEqual(t *testing.T, w1, w2 *WAL) {
t.Helper()
info1 := w1.Info()
info2 := w2.Info()
if !reflect.DeepEqual(info1, info2) {
t.Fatal(info1, info2)
}
it1, err := w1.Iterator(info1.FirstSeqNum)
if err != nil {
t.Fatal(err)
}
defer it1.Close()
it2, err := w2.Iterator(info2.FirstSeqNum)
if err != nil {
t.Fatal(err)
}
defer it2.Close()
for {
ok1 := it1.Next(time.Second)
ok2 := it2.Next(time.Second)
if ok1 != ok2 {
t.Fatal(ok1, ok2)
}
if !ok1 {
return
}
rec1 := it1.Record()
rec2 := it2.Record()
data1, err := io.ReadAll(rec1.Reader)
if err != nil {
t.Fatal(err)
}
data2, err := io.ReadAll(rec2.Reader)
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(data1, data2) {
t.Fatal(data1, data2)
}
}
}

lib/wal/wal-header.go

@@ -0,0 +1,25 @@
package wal
import "encoding/binary"
type walHeader struct {
FirstSegmentID int64
LastSegmentID int64
}
func (h walHeader) WriteTo(b []byte) {
vals := []int64{h.FirstSegmentID, h.LastSegmentID}
for _, val := range vals {
binary.LittleEndian.PutUint64(b[0:8], uint64(val))
b = b[8:]
}
}
func (h *walHeader) ReadFrom(b []byte) {
ptrs := []*int64{&h.FirstSegmentID, &h.LastSegmentID}
for _, ptr := range ptrs {
*ptr = int64(binary.LittleEndian.Uint64(b[0:8]))
b = b[8:]
}
}

lib/wal/wal-iterator.go

@@ -0,0 +1,88 @@
package wal
import (
"git.crumpington.com/public/jldb/lib/errs"
"time"
)
type walIterator struct {
// getSeg should return a segment given its ID, or return nil.
getSeg func(id int64) (*segment, error)
seg *segment // Our current segment.
it Iterator // Our current segment iterator.
seqNum int64
err error
}
func newWALIterator(
getSeg func(id int64) (*segment, error),
seg *segment,
fromSeqNum int64,
) (
*walIterator,
error,
) {
segIter, err := seg.Iterator(fromSeqNum)
if err != nil {
return nil, err
}
return &walIterator{
getSeg: getSeg,
seg: seg,
it: segIter,
seqNum: fromSeqNum,
}, nil
}
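// Next advances to the next record, transparently crossing segment
// boundaries: when the current segment's iterator reports errs.EOFArchived,
// it is closed and iteration continues in the segment with the next ID.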
func (it *walIterator) Next(timeout time.Duration) bool {
if it.err != nil {
return false
}
if it.it.Next(timeout) {
it.seqNum++
return true
}
it.err = it.it.Error()
if !errs.EOFArchived.Is(it.err) {
return false
}
it.it.Close()
id := it.seg.ID + 1
it.seg, it.err = it.getSeg(id)
if it.err != nil {
return false
}
if it.seg == nil {
it.err = errs.NotFound // Could be not-found, or closed.
return false
}
it.it, it.err = it.seg.Iterator(it.seqNum)
if it.err != nil {
return false
}
return it.Next(timeout)
}
func (it *walIterator) Record() Record {
return it.it.Record()
}
func (it *walIterator) Error() error {
return it.err
}
func (it *walIterator) Close() {
if it.it != nil {
it.it.Close()
}
it.it = nil
}

lib/wal/wal-recv.go

@@ -0,0 +1,60 @@
package wal
import (
"encoding/binary"
"io"
"git.crumpington.com/public/jldb/lib/errs"
"net"
"time"
)
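// Recv applies records streamed by a remote WAL's Send to this WAL. It
// first sends the next sequence number it expects (LastSeqNum+1), then
// reads type-prefixed messages until the connection fails or the sender
// reports an error.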
func (wal *WAL) Recv(conn net.Conn, timeout time.Duration) error {
defer conn.Close()
var (
rec Record
msgType = make([]byte, 1)
)
// Send sequence number.
seqNum := wal.Info().LastSeqNum + 1
conn.SetWriteDeadline(time.Now().Add(timeout))
if err := binary.Write(conn, binary.LittleEndian, seqNum); err != nil {
return errs.IO.WithErr(err)
}
conn.SetWriteDeadline(time.Time{})
for {
conn.SetReadDeadline(time.Now().Add(timeout))
if _, err := io.ReadFull(conn, msgType); err != nil {
return errs.IO.WithErr(err)
}
switch msgType[0] {
case msgTypeHeartbeat:
// Nothing to do.
case msgTypeError:
e := &errs.Error{}
if err := e.Read(conn); err != nil {
return err
}
return e
case msgTypeRec:
if err := rec.readFrom(conn); err != nil {
return err
}
if _, _, err := wal.appendRecord(rec); err != nil {
return err
}
default:
return errs.Unexpected.WithMsg("Unknown message type: %d", msgType[0])
}
}
}

lib/wal/wal-send.go

@@ -0,0 +1,73 @@
package wal
import (
"encoding/binary"
"git.crumpington.com/public/jldb/lib/errs"
"net"
"time"
)
const (
msgTypeRec = 8
msgTypeHeartbeat = 16
msgTypeError = 32
)
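// Send streams records to a remote WAL's Recv over conn. The receiver first
// sends the sequence number it wants to start from; Send then writes
// messages consisting of a one-byte type followed by either a serialized
// record (msgTypeRec), nothing (msgTypeHeartbeat, sent when no record
// arrives within timeout/8), or a serialized errs.Error (msgTypeError).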
func (wal *WAL) Send(conn net.Conn, timeout time.Duration) error {
defer conn.Close()
var (
seqNum int64
heartbeatTimeout = timeout / 8
)
conn.SetReadDeadline(time.Now().Add(timeout))
if err := binary.Read(conn, binary.LittleEndian, &seqNum); err != nil {
return errs.IO.WithErr(err)
}
conn.SetReadDeadline(time.Time{})
it, err := wal.Iterator(seqNum)
if err != nil {
return err
}
defer it.Close()
for {
if it.Next(heartbeatTimeout) {
rec := it.Record()
conn.SetWriteDeadline(time.Now().Add(timeout))
if _, err := conn.Write([]byte{msgTypeRec}); err != nil {
return errs.IO.WithErr(err)
}
if _, err := rec.writeTo(conn); err != nil {
return err
}
continue
}
if it.Error() != nil {
conn.SetWriteDeadline(time.Now().Add(timeout))
if _, err := conn.Write([]byte{msgTypeError}); err != nil {
return errs.IO.WithErr(err)
}
err, ok := it.Error().(*errs.Error)
if !ok {
err = errs.Unexpected.WithErr(it.Error())
}
err.Write(conn)
// w.Flush()
return err
}
conn.SetWriteDeadline(time.Now().Add(timeout))
if _, err := conn.Write([]byte{msgTypeHeartbeat}); err != nil {
return errs.IO.WithErr(err)
}
}
}


@@ -0,0 +1,271 @@
package wal
import (
"git.crumpington.com/public/jldb/lib/errs"
"git.crumpington.com/public/jldb/lib/testutil"
"log"
"math/rand"
"reflect"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestSendRecvHarness(t *testing.T) {
t.Parallel()
(&SendRecvTestHarness{}).Run(t)
}
type SendRecvTestHarness struct{}
func (h *SendRecvTestHarness) Run(t *testing.T) {
val := reflect.ValueOf(h)
typ := val.Type()
for i := 0; i < typ.NumMethod(); i++ {
method := typ.Method(i)
if !strings.HasPrefix(method.Name, "Test") {
continue
}
t.Run(method.Name, func(t *testing.T) {
t.Parallel()
pDir := t.TempDir()
sDir := t.TempDir()
config := Config{
SegMinCount: 8,
SegMaxAgeSec: 1,
}
pWAL, err := Create(pDir, 1, config)
if err != nil {
t.Fatal(err)
}
defer pWAL.Close()
sWAL, err := Create(sDir, 1, config)
if err != nil {
t.Fatal(err)
}
defer sWAL.Close()
nw := testutil.NewNetwork()
defer func() {
nw.CloseServer()
nw.CloseClient()
}()
val.MethodByName(method.Name).Call([]reflect.Value{
reflect.ValueOf(t),
reflect.ValueOf(pWAL),
reflect.ValueOf(sWAL),
reflect.ValueOf(nw),
})
})
}
}
func (h *SendRecvTestHarness) TestSimple(
t *testing.T,
pWAL *WAL,
sWAL *WAL,
nw *testutil.Network,
) {
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
if err := writeRandomWithEOF(pWAL, 5*time.Second); err != nil {
panic(err)
}
}()
// Send in the background.
wg.Add(1)
go func() {
defer wg.Done()
conn := nw.Accept()
if err := pWAL.Send(conn, 8*time.Second); err != nil {
log.Printf("Send error: %v", err)
}
}()
// Recv in the background.
wg.Add(1)
go func() {
defer wg.Done()
conn := nw.Dial()
if err := sWAL.Recv(conn, 8*time.Second); err != nil {
log.Printf("Recv error: %v", err)
}
}()
waitForEOF(t, sWAL)
nw.CloseServer()
nw.CloseClient()
wg.Wait()
checkWALsEqual(t, pWAL, sWAL)
}
func (h *SendRecvTestHarness) TestWriteThenRead(
t *testing.T,
pWAL *WAL,
sWAL *WAL,
nw *testutil.Network,
) {
wg := sync.WaitGroup{}
if err := writeRandomWithEOF(pWAL, 2*time.Second); err != nil {
t.Fatal(err)
}
// Send in the background.
wg.Add(1)
go func() {
defer wg.Done()
conn := nw.Accept()
if err := pWAL.Send(conn, 8*time.Second); err != nil {
log.Printf("Send error: %v", err)
}
}()
// Recv in the background.
wg.Add(1)
go func() {
defer wg.Done()
conn := nw.Dial()
if err := sWAL.Recv(conn, 8*time.Second); err != nil {
log.Printf("Recv error: %v", err)
}
}()
waitForEOF(t, sWAL)
nw.CloseServer()
nw.CloseClient()
wg.Wait()
checkWALsEqual(t, pWAL, sWAL)
}
func (h *SendRecvTestHarness) TestNetworkFailures(
t *testing.T,
pWAL *WAL,
sWAL *WAL,
nw *testutil.Network,
) {
recvDone := &atomic.Bool{}
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
writeRandomWithEOF(pWAL, 10*time.Second)
}()
// Send in the background.
wg.Add(1)
go func() {
defer wg.Done()
for {
if recvDone.Load() {
return
}
if conn := nw.Accept(); conn != nil {
pWAL.Send(conn, 8*time.Second)
}
}
}()
// Recv in the background.
wg.Add(1)
go func() {
defer wg.Done()
for !recvDone.Load() {
if conn := nw.Dial(); conn != nil {
sWAL.Recv(conn, 8*time.Second)
}
}
}()
wg.Add(1)
failureCount := 0
go func() {
defer wg.Done()
for {
if recvDone.Load() {
return
}
time.Sleep(time.Millisecond * time.Duration(rand.Intn(100)))
failureCount++
if rand.Float64() < 0.5 {
nw.CloseClient()
} else {
nw.CloseServer()
}
}
}()
waitForEOF(t, sWAL)
recvDone.Store(true)
wg.Wait()
log.Printf("%d network failures.", failureCount)
if failureCount < 10 {
t.Fatal("Expected more failures.")
}
checkWALsEqual(t, pWAL, sWAL)
}
func (h *SendRecvTestHarness) TestSenderClose(
t *testing.T,
pWAL *WAL,
sWAL *WAL,
nw *testutil.Network,
) {
wg := sync.WaitGroup{}
wg.Add(1)
go func() {
defer wg.Done()
if err := writeRandomWithEOF(pWAL, 5*time.Second); !errs.Closed.Is(err) {
panic(err)
}
}()
// Close primary after some time.
wg.Add(1)
go func() {
defer wg.Done()
time.Sleep(time.Second)
pWAL.Close()
}()
// Send in the background.
wg.Add(1)
go func() {
defer wg.Done()
conn := nw.Accept()
if err := pWAL.Send(conn, 8*time.Second); err != nil {
log.Printf("Send error: %v", err)
}
}()
conn := nw.Dial()
if err := sWAL.Recv(conn, 8*time.Second); !errs.Closed.Is(err) {
t.Fatal(err)
}
nw.CloseServer()
nw.CloseClient()
wg.Wait()
}

lib/wal/wal.go

@@ -0,0 +1,321 @@
package wal
import (
"io"
"git.crumpington.com/public/jldb/lib/atomicheader"
"git.crumpington.com/public/jldb/lib/errs"
"os"
"path/filepath"
"strconv"
"sync"
"time"
)
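// Config controls segment rollover: grow() starts a new segment (archiving
// the current one) only when the active segment is already archived, or
// when it holds more than SegMinCount records and is at least SegMaxAgeSec
// seconds old.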
type Config struct {
SegMinCount int64
SegMaxAgeSec int64
}
type WAL struct {
rootDir string
conf Config
lock sync.Mutex // Protects the fields below.
closed bool
header walHeader
headerWriter *atomicheader.Handler
f *os.File // WAL header.
segments map[int64]*segment // Used by the iterator.
seg *segment // Current segment.
}
func Create(rootDir string, firstSeqNum int64, conf Config) (*WAL, error) {
w := &WAL{rootDir: rootDir, conf: conf}
seg, err := createSegment(w.segmentPath(1), 1, firstSeqNum, 0)
if err != nil {
return nil, err
}
defer seg.Close()
f, err := os.Create(w.headerPath())
if err != nil {
return nil, errs.IO.WithErr(err)
}
defer f.Close()
if err := atomicheader.Init(f); err != nil {
return nil, err
}
handler, err := atomicheader.Open(f)
if err != nil {
return nil, err
}
header := walHeader{
FirstSegmentID: 1,
LastSegmentID: 1,
}
err = handler.Write(func(page []byte) error {
header.WriteTo(page)
return nil
})
if err != nil {
return nil, err
}
return Open(rootDir, conf)
}
func Open(rootDir string, conf Config) (*WAL, error) {
w := &WAL{rootDir: rootDir, conf: conf}
f, err := os.OpenFile(w.headerPath(), os.O_RDWR, 0600)
if err != nil {
return nil, errs.IO.WithErr(err)
}
handler, err := atomicheader.Open(f)
if err != nil {
f.Close()
return nil, err
}
var header walHeader
err = handler.Read(func(page []byte) error {
header.ReadFrom(page)
return nil
})
if err != nil {
f.Close()
return nil, err
}
w.header = header
w.headerWriter = handler
w.f = f
w.segments = map[int64]*segment{}
for segID := header.FirstSegmentID; segID < header.LastSegmentID+1; segID++ {
segID := segID
seg, err := openSegment(w.segmentPath(segID), segID)
if err != nil {
w.Close()
return nil, err
}
w.segments[segID] = seg
}
w.seg = w.segments[header.LastSegmentID]
if err := w.grow(); err != nil {
w.Close()
return nil, err
}
return w, nil
}
func (w *WAL) Close() error {
w.lock.Lock()
defer w.lock.Unlock()
if w.closed {
return nil
}
w.closed = true
for _, seg := range w.segments {
seg.Close()
delete(w.segments, seg.ID)
}
w.f.Close()
return nil
}
func (w *WAL) Info() (info Info) {
w.lock.Lock()
defer w.lock.Unlock()
h := w.header
info.FirstSeqNum = w.segments[h.FirstSegmentID].Header().FirstSeqNum
lastHeader := w.segments[h.LastSegmentID].Header()
info.LastSeqNum = lastHeader.LastSeqNum
info.LastTimestampMS = lastHeader.LastTimestampMS
return
}
func (w *WAL) Append(dataSize int64, r io.Reader) (int64, int64, error) {
return w.appendRecord(Record{
SeqNum: -1,
TimestampMS: time.Now().UnixMilli(),
DataSize: dataSize,
Reader: r,
})
}
func (w *WAL) appendRecord(rec Record) (int64, int64, error) {
w.lock.Lock()
defer w.lock.Unlock()
if w.closed {
return 0, 0, errs.Closed
}
if err := w.grow(); err != nil {
return 0, 0, err
}
return w.seg.appendRecord(rec)
}
func (w *WAL) Iterator(fromSeqNum int64) (Iterator, error) {
w.lock.Lock()
defer w.lock.Unlock()
if w.closed {
return nil, errs.Closed
}
header := w.header
var seg *segment
getSeg := func(id int64) (*segment, error) {
w.lock.Lock()
defer w.lock.Unlock()
if w.closed {
return nil, errs.Closed
}
return w.segments[id], nil
}
if fromSeqNum == -1 {
seg = w.segments[header.FirstSegmentID]
return newWALIterator(getSeg, seg, fromSeqNum)
}
// Seek to the appropriate segment.
seg = w.segments[header.FirstSegmentID]
for seg != nil {
h := seg.Header()
if fromSeqNum >= h.FirstSeqNum && fromSeqNum <= h.LastSeqNum+1 {
return newWALIterator(getSeg, seg, fromSeqNum)
}
seg = w.segments[seg.ID+1]
}
return nil, errs.NotFound
}
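// DeleteBefore removes whole segments from the front of the WAL. A segment
// is deleted only if it is not the active segment, was archived at or
// before timestamp, and its last sequence number is below keepSeqNum.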
func (w *WAL) DeleteBefore(timestamp, keepSeqNum int64) error {
for {
seg, err := w.removeSeg(timestamp, keepSeqNum)
if err != nil || seg == nil {
return err
}
id := seg.ID
os.RemoveAll(w.segmentPath(id))
seg.Close()
}
}
func (w *WAL) removeSeg(timestamp, keepSeqNum int64) (*segment, error) {
w.lock.Lock()
defer w.lock.Unlock()
header := w.header
if header.FirstSegmentID == header.LastSegmentID {
return nil, nil // Nothing to delete now.
}
id := header.FirstSegmentID
seg := w.segments[id]
if seg == nil {
return nil, errs.Unexpected.WithMsg("segment %d not found", id)
}
segHeader := seg.Header()
if seg == w.seg || segHeader.ArchivedAt > timestamp {
return nil, nil // Nothing to delete now.
}
if segHeader.LastSeqNum >= keepSeqNum {
return nil, nil
}
header.FirstSegmentID = id + 1
err := w.headerWriter.Write(func(page []byte) error {
header.WriteTo(page)
return nil
})
if err != nil {
return nil, err
}
w.header = header
delete(w.segments, id)
return seg, nil
}
func (w *WAL) grow() error {
segHeader := w.seg.Header()
if segHeader.ArchivedAt == 0 {
if (segHeader.LastSeqNum - segHeader.FirstSeqNum) < w.conf.SegMinCount {
return nil
}
if time.Now().Unix()-segHeader.CreatedAt < w.conf.SegMaxAgeSec {
return nil
}
}
newSegID := w.seg.ID + 1
firstSeqNum := segHeader.LastSeqNum + 1
timestampMS := segHeader.LastTimestampMS
newSeg, err := createSegment(w.segmentPath(newSegID), newSegID, firstSeqNum, timestampMS)
if err != nil {
return err
}
walHeader := w.header
walHeader.LastSegmentID = newSegID
err = w.headerWriter.Write(func(page []byte) error {
walHeader.WriteTo(page)
return nil
})
if err != nil {
newSeg.Close()
return err
}
if err := w.seg.Archive(); err != nil {
newSeg.Close()
return err
}
w.seg = newSeg
w.segments[newSeg.ID] = newSeg
w.header = walHeader
return nil
}
func (w *WAL) headerPath() string {
return filepath.Join(w.rootDir, "header")
}
func (w *WAL) segmentPath(segID int64) string {
return filepath.Join(w.rootDir, "seg."+strconv.FormatInt(segID, 10))
}