package wal

import (
	"encoding/binary"
	"errors"
	"hash/crc32"
	"io"

	"git.crumpington.com/public/jldb/lib/errs"
)

// ioErrOrEOF returns nil for a nil error, returns err unchanged when it is
// (or wraps) io.EOF, and wraps every other error as an errs.IO error.
func ioErrOrEOF(err error) error {
	if err == nil {
		return nil
	}
	if errors.Is(err, io.EOF) {
		return err
	}
	return errs.IO.WithErr(err)
}

// ----------------------------------------------------------------------------

// readAtReader adapts an io.ReaderAt into a sequential io.Reader that starts
// at offset and advances by the number of bytes returned from each read.
type readAtReader struct {
	f      io.ReaderAt
	offset int64
}

var _ io.Reader = (*readAtReader)(nil)

// readerAtToReader returns an io.Reader that reads f sequentially from offset.
func readerAtToReader(f io.ReaderAt, offset int64) io.Reader {
	return &readAtReader{f: f, offset: offset}
}

// Read reads into b at the current offset. Per the io.ReaderAt contract the
// underlying call may return n == len(b) together with io.EOF at the end of
// the source; that pairing is passed through, which io.Reader permits.
func (r *readAtReader) Read(b []byte) (int, error) {
	n, err := r.f.ReadAt(b, r.offset)
	r.offset += int64(n)
	return n, ioErrOrEOF(err)
}

// ----------------------------------------------------------------------------

// writeAtWriter adapts an io.WriterAt into a sequential io.Writer that starts
// at offset and advances by the number of bytes written by each call.
type writeAtWriter struct {
	w      io.WriterAt
	offset int64
}

var _ io.Writer = (*writeAtWriter)(nil)

// writerAtToWriter returns an io.Writer that writes w sequentially from offset.
func writerAtToWriter(w io.WriterAt, offset int64) io.Writer {
	return &writeAtWriter{w: w, offset: offset}
}

// Write writes b at the current offset and advances the offset by the number
// of bytes actually written.
func (w *writeAtWriter) Write(b []byte) (int, error) {
	n, err := w.w.WriteAt(b, w.offset)
	w.offset += int64(n)
	return n, ioErrOrEOF(err)
}

// ----------------------------------------------------------------------------

// crcWriter forwards writes to w while accumulating an IEEE CRC-32 over the
// bytes that were actually written.
type crcWriter struct {
	w   io.Writer
	crc uint32
}

var _ io.Writer = (*crcWriter)(nil)

// newCRCWriter returns a crcWriter wrapping w with a zero initial CRC.
func newCRCWriter(w io.Writer) *crcWriter {
	return &crcWriter{w: w}
}

// Write forwards b to the wrapped writer. Only the first n bytes (those the
// wrapped writer accepted) are folded into the CRC.
func (w *crcWriter) Write(b []byte) (int, error) {
	n, err := w.w.Write(b)
	w.crc = crc32.Update(w.crc, crc32.IEEETable, b[:n])
	return n, ioErrOrEOF(err)
}

// CRC returns the IEEE CRC-32 of all bytes written so far.
func (w *crcWriter) CRC() uint32 {
	return w.crc
}

// ----------------------------------------------------------------------------

// dataReader reads exactly `remaining` payload bytes from r while computing
// an IEEE CRC-32 over them. Once the last payload byte has been read, it reads
// the 4-byte little-endian CRC that immediately follows in r and verifies it.
// After the payload is exhausted, Read returns io.EOF.
type dataReader struct {
	r         io.Reader
	remaining int64
	crc       uint32
}

var _ io.Reader = (*dataReader)(nil)

// newDataReader returns a dataReader for a payload of dataSize bytes, not
// counting the trailing 4-byte CRC.
//
// NOTE(review): when dataSize == 0 the first Read returns io.EOF without ever
// reading or verifying the trailing CRC, and a negative dataSize makes Read
// panic when slicing b. Callers presumably validate dataSize — confirm.
func newDataReader(r io.Reader, dataSize int64) *dataReader {
	return &dataReader{r: r, remaining: dataSize}
}

// Read reads payload bytes into b without reading past the end of the payload.
//
// Errors:
//   - a non-EOF error from r is wrapped as errs.IO;
//   - io.EOF from r before the payload is complete means the record is
//     truncated and is reported as errs.Corrupt wrapping io.ErrUnexpectedEOF;
//   - when the payload completes, a missing, short, or mismatched CRC is
//     reported as errs.Corrupt.
func (r *dataReader) Read(b []byte) (int, error) {
	if r.remaining == 0 {
		return 0, io.EOF
	}
	if int64(len(b)) > r.remaining {
		b = b[:r.remaining]
	}

	n, err := r.r.Read(b)
	r.crc = crc32.Update(r.crc, crc32.IEEETable, b[:n])
	r.remaining -= int64(n)

	// Report a genuine I/O failure as such, rather than letting a CRC read
	// against the failing reader misclassify it as corruption.
	if err != nil && !errors.Is(err, io.EOF) {
		return n, errs.IO.WithErr(err)
	}

	if r.remaining == 0 {
		// Payload complete: the CRC must follow, whether or not r signalled
		// EOF alongside the final payload bytes (checkCRC fails if it's absent).
		return n, r.checkCRC()
	}

	if err != nil {
		// EOF before the payload was complete. Swallowing it here would make
		// Read return (0, nil) indefinitely and hang io.ReadFull / io.Copy.
		return n, errs.Corrupt.WithErr(io.ErrUnexpectedEOF)
	}
	return n, nil
}

// checkCRC reads the 4-byte little-endian CRC-32 that follows the payload and
// compares it with the CRC accumulated over the payload.
//
// io.ReadFull is required: a single Read may legally return fewer than 4 bytes
// with a nil error, or all 4 bytes together with io.EOF. A bare Read would
// flag both cases as corruption.
func (r *dataReader) checkCRC() error {
	var buf [4]byte
	if _, err := io.ReadFull(r.r, buf[:]); err != nil {
		return errs.Corrupt.WithErr(err)
	}
	if binary.LittleEndian.Uint32(buf[:]) != r.crc {
		return errs.Corrupt.WithMsg("crc mismatch")
	}
	return nil
}