package wal

import (
	"bytes"
	"io"
	"math/rand"
	"testing"

	"git.crumpington.com/public/jldb/lib/errs"
	"git.crumpington.com/public/jldb/lib/testutil"
)

// NewRecordForTesting returns a Record with random SeqNum and TimestampMS
// whose Reader yields freshly generated random data of length DataSize.
func NewRecordForTesting() Record {
	data := randData()
	return Record{
		SeqNum:      rand.Int63(),
		TimestampMS: rand.Int63(),
		DataSize:    int64(len(data)),
		Reader:      bytes.NewBuffer(data),
	}
}

// AssertRecordHeadersEqual fails the test unless r1 and r2 agree on SeqNum,
// TimestampMS and DataSize. Record payloads (Reader) are not compared.
func AssertRecordHeadersEqual(t *testing.T, r1, r2 Record) {
	t.Helper()
	if r1.SeqNum != r2.SeqNum ||
		r1.TimestampMS != r2.TimestampMS ||
		r1.DataSize != r2.DataSize {
		t.Fatalf("record headers differ: (seq=%d ts=%d size=%d) != (seq=%d ts=%d size=%d)",
			r1.SeqNum, r1.TimestampMS, r1.DataSize,
			r2.SeqNum, r2.TimestampMS, r2.DataSize)
	}
}

// TestRecordWriteHeaderToReadHeaderFrom checks that a header round-trips
// through writeHeaderTo / readHeaderFrom and occupies recordHeaderSize bytes.
func TestRecordWriteHeaderToReadHeaderFrom(t *testing.T) {
	t.Parallel()

	rec1 := NewRecordForTesting()

	b := &bytes.Buffer{}
	n, err := rec1.writeHeaderTo(b)
	if err != nil {
		t.Fatal(err)
	}
	if n != recordHeaderSize {
		t.Fatal(n)
	}

	rec2 := Record{}
	if err := rec2.readHeaderFrom(b); err != nil {
		t.Fatal(err)
	}

	AssertRecordHeadersEqual(t, rec1, rec2)
}

// TestRecordWriteHeaderToEOF checks that writeHeaderTo reports an errs.IO
// error whenever the destination accepts fewer than recordHeaderSize bytes.
func TestRecordWriteHeaderToEOF(t *testing.T) {
	t.Parallel()

	rec := NewRecordForTesting()

	for limit := 1; limit < recordHeaderSize; limit++ {
		buf := &bytes.Buffer{}
		w := testutil.NewLimitWriter(buf, limit)
		n, err := rec.writeHeaderTo(w)
		if !errs.IO.Is(err) {
			t.Fatal(limit, n, err)
		}
	}
}

// TestRecordReadHeaderFromError checks that readHeaderFrom reports an errs.IO
// error when the source is truncated before a full header is available.
func TestRecordReadHeaderFromError(t *testing.T) {
	t.Parallel()

	rec := NewRecordForTesting()

	for limit := 1; limit < recordHeaderSize; limit++ {
		b := &bytes.Buffer{}
		if _, err := rec.writeHeaderTo(b); err != nil {
			t.Fatal(err)
		}
		r := io.LimitReader(b, int64(limit))

		// Decode into a fresh record so a failed parse cannot alter the
		// fixture used to produce the next iteration's header.
		var got Record
		if err := got.readHeaderFrom(r); !errs.IO.Is(err) {
			t.Fatal(limit, err)
		}
	}
}

// TestRecordReadHeaderFromCorrupt flips each header byte in turn and checks
// that readHeaderFrom reports an errs.Corrupt error.
func TestRecordReadHeaderFromCorrupt(t *testing.T) {
	t.Parallel()

	rec := NewRecordForTesting()

	for i := 0; i < recordHeaderSize; i++ {
		// A fresh buffer per iteration guarantees b.Bytes()[i] indexes this
		// iteration's header, regardless of how much a failed read consumed.
		b := &bytes.Buffer{}
		if _, err := rec.writeHeaderTo(b); err != nil {
			t.Fatal(err)
		}
		b.Bytes()[i]++

		var got Record
		if err := got.readHeaderFrom(b); !errs.Corrupt.Is(err) {
			t.Fatal(i, err)
		}
	}
}

// TestRecordWriteToReadFrom checks that a full record (header and payload)
// round-trips through writeTo / readFrom.
func TestRecordWriteToReadFrom(t *testing.T) {
	t.Parallel()

	r1 := NewRecordForTesting()
	data := randData()
	r1.Reader = bytes.NewBuffer(bytes.Clone(data))
	r1.DataSize = int64(len(data))

	r2 := Record{}

	b := &bytes.Buffer{}
	if _, err := r1.writeTo(b); err != nil {
		t.Fatal(err)
	}

	if err := r2.readFrom(b); err != nil {
		t.Fatal(err)
	}

	AssertRecordHeadersEqual(t, r1, r2)

	data2, err := io.ReadAll(r2.Reader)
	if err != nil {
		t.Fatal(err)
	}

	if !bytes.Equal(data, data2) {
		t.Fatal(data, data2)
	}
}

// TestRecordReadFromCorrupt flips each byte of a serialized record in turn.
// The corruption must be reported as errs.Corrupt either by readFrom itself
// or, if the header survived, while reading the payload from r2.Reader.
func TestRecordReadFromCorrupt(t *testing.T) {
	t.Parallel()

	data := randData()
	r1 := NewRecordForTesting()
	r1.DataSize = int64(len(data))

	// Computed once, after DataSize matches the payload actually written.
	size := int(r1.serializedSize())

	for i := 0; i < size; i++ {
		// writeTo drains Reader, so each iteration needs a fresh one.
		r1.Reader = bytes.NewBuffer(data)

		buf := &bytes.Buffer{}
		if _, err := r1.writeTo(buf); err != nil {
			t.Fatal(i, err)
		}
		buf.Bytes()[i]++

		r2 := Record{}
		if err := r2.readFrom(buf); err != nil {
			if !errs.Corrupt.Is(err) {
				t.Fatal(i, err)
			}
			continue // OK.
		}

		if _, err := io.ReadAll(r2.Reader); !errs.Corrupt.Is(err) {
			t.Fatal(i, err)
		}
	}
}

// TestRecordWriteToError checks that writeTo reports an errs.IO error for
// every destination that accepts fewer bytes than the serialized record.
func TestRecordWriteToError(t *testing.T) {
	t.Parallel()

	data := randData()
	r1 := NewRecordForTesting()
	r1.DataSize = int64(len(data))

	size := int(r1.serializedSize())

	for i := 0; i < size; i++ {
		w := testutil.NewLimitWriter(&bytes.Buffer{}, i)
		// writeTo drains Reader, so each iteration needs a fresh one.
		r1.Reader = bytes.NewBuffer(data)
		if _, err := r1.writeTo(w); !errs.IO.Is(err) {
			t.Fatal(i, err)
		}
	}
}