Initial commit
This commit is contained in:
136
lib/atomicheader/atomicheader.go
Normal file
136
lib/atomicheader/atomicheader.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package atomicheader
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"hash/crc32"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"os"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
PageSize = 512
|
||||
AvailabePageSize = 508
|
||||
|
||||
ReservedBytes = PageSize * 4
|
||||
|
||||
offsetSwitch = 1 * PageSize
|
||||
offset1 = 2 * PageSize
|
||||
offset2 = 3 * PageSize
|
||||
)
|
||||
|
||||
type Handler struct {
|
||||
lock sync.Mutex
|
||||
switchPage []byte // At offsetSwitch.
|
||||
page []byte // Page buffer is re-used for reading and writing.
|
||||
|
||||
currentPage int64 // Either 0 or 1.
|
||||
f *os.File
|
||||
}
|
||||
|
||||
func Init(f *os.File) error {
|
||||
if err := f.Truncate(ReservedBytes); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
switchPage := make([]byte, PageSize)
|
||||
switchPage[0] = 2
|
||||
if _, err := f.WriteAt(switchPage, offsetSwitch); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Open(f *os.File) (*Handler, error) {
|
||||
switchPage := make([]byte, PageSize)
|
||||
|
||||
if _, err := f.ReadAt(switchPage, offsetSwitch); err != nil {
|
||||
return nil, errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
h := &Handler{
|
||||
switchPage: switchPage,
|
||||
page: make([]byte, PageSize),
|
||||
currentPage: int64(switchPage[0]),
|
||||
f: f,
|
||||
}
|
||||
|
||||
if h.currentPage != 1 && h.currentPage != 2 {
|
||||
return nil, errs.Corrupt.WithMsg("invalid page id: %d", h.currentPage)
|
||||
}
|
||||
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// Read reads the currently active header page.
|
||||
func (h *Handler) Read(read func(page []byte) error) error {
|
||||
h.lock.Lock()
|
||||
defer h.lock.Unlock()
|
||||
|
||||
if _, err := h.f.ReadAt(h.page, h.currentOffset()); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
computedCRC := crc32.ChecksumIEEE(h.page[:PageSize-4])
|
||||
storedCRC := binary.LittleEndian.Uint32(h.page[PageSize-4:])
|
||||
if computedCRC != storedCRC {
|
||||
return errs.Corrupt.WithMsg("checksum mismatch")
|
||||
}
|
||||
|
||||
return read(h.page)
|
||||
}
|
||||
|
||||
// Write writes the currently active header page. The page buffer given to the
|
||||
// function may contain old data, so the caller may need to zero some bytes if
|
||||
// necessary.
|
||||
func (h *Handler) Write(update func(page []byte) error) error {
|
||||
h.lock.Lock()
|
||||
defer h.lock.Unlock()
|
||||
|
||||
if err := update(h.page); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
crc := crc32.ChecksumIEEE(h.page[:PageSize-4])
|
||||
binary.LittleEndian.PutUint32(h.page[PageSize-4:], crc)
|
||||
|
||||
newPageNum := 1 + h.currentPage%2
|
||||
newOffset := h.getOffset(newPageNum)
|
||||
|
||||
if _, err := h.f.WriteAt(h.page, newOffset); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
if err := h.f.Sync(); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
h.switchPage[0] = byte(newPageNum)
|
||||
if _, err := h.f.WriteAt(h.switchPage, offsetSwitch); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
if err := h.f.Sync(); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
h.currentPage = newPageNum
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (h *Handler) currentOffset() int64 {
|
||||
return h.getOffset(h.currentPage)
|
||||
}
|
||||
|
||||
func (h *Handler) getOffset(pageNum int64) int64 {
|
||||
switch pageNum {
|
||||
case 1:
|
||||
return offset1
|
||||
case 2:
|
||||
return offset2
|
||||
default:
|
||||
panic("Invalid page number.")
|
||||
}
|
||||
}
|
||||
121
lib/atomicheader/atomicheader_test.go
Normal file
121
lib/atomicheader/atomicheader_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package atomicheader
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func NewForTesting(t *testing.T) (*Handler, func()) {
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
f, err := os.Create(filepath.Join(tmpDir, "h"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := Init(f); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
h, err := Open(f)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return h, func() {
|
||||
f.Close()
|
||||
os.RemoveAll(tmpDir)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomicHeaderSimple(t *testing.T) {
|
||||
h, cleanup := NewForTesting(t)
|
||||
defer cleanup()
|
||||
|
||||
err := h.Write(func(page []byte) error {
|
||||
for i := range page[:AvailabePageSize] {
|
||||
page[i] = byte(i) % 11
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = h.Read(func(page []byte) error {
|
||||
for i := range page[:AvailabePageSize] {
|
||||
if page[i] != byte(i)%11 {
|
||||
t.Fatal(i, page[i], byte(i)%11)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAtomicHeaderThreaded(t *testing.T) {
|
||||
h, cleanup := NewForTesting(t)
|
||||
defer cleanup()
|
||||
|
||||
expectedValue := byte(0)
|
||||
|
||||
writeErr := make(chan error, 1)
|
||||
stop := make(chan struct{})
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-stop:
|
||||
writeErr <- nil
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
err := h.Write(func(page []byte) error {
|
||||
if page[0] != expectedValue {
|
||||
return errors.New("Unexpected current value.")
|
||||
}
|
||||
|
||||
expectedValue++
|
||||
page[0] = expectedValue
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
writeErr <- err
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Millisecond / 13)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 0; i < 2000; i++ {
|
||||
time.Sleep(time.Millisecond)
|
||||
err := h.Read(func(page []byte) error {
|
||||
if page[0] != expectedValue {
|
||||
t.Fatal(page[0], expectedValue)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
close(stop)
|
||||
wg.Wait()
|
||||
|
||||
if err := <-writeErr; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
121
lib/errs/error.go
Normal file
121
lib/errs/error.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package errs
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"runtime/debug"
|
||||
)
|
||||
|
||||
type Error struct {
|
||||
msg string
|
||||
code int64
|
||||
collection string
|
||||
index string
|
||||
stackTrace string
|
||||
err error // Wrapped error
|
||||
}
|
||||
|
||||
func NewErr(code int64, msg string) *Error {
|
||||
return &Error{
|
||||
msg: msg,
|
||||
code: code,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
if e.collection != "" || e.index != "" {
|
||||
return fmt.Sprintf(`[%d] (%s/%s) %s`, e.code, e.collection, e.index, e.msg)
|
||||
} else {
|
||||
return fmt.Sprintf("[%d] %s", e.code, e.msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (e *Error) Is(rhs error) bool {
|
||||
e2, ok := rhs.(*Error)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return e.code == e2.code
|
||||
}
|
||||
|
||||
func (e *Error) WithErr(err error) *Error {
|
||||
if e2, ok := err.(*Error); ok && e2.code == e.code {
|
||||
return e2
|
||||
}
|
||||
|
||||
e2 := e.WithMsg(err.Error())
|
||||
e2.err = err
|
||||
return e2
|
||||
}
|
||||
|
||||
func (e *Error) Unwrap() error {
|
||||
if e.err != nil {
|
||||
return e.err
|
||||
}
|
||||
return e
|
||||
}
|
||||
|
||||
func (e *Error) WithMsg(msg string, args ...any) *Error {
|
||||
err := *e
|
||||
err.msg += ": " + fmt.Sprintf(msg, args...)
|
||||
if len(err.stackTrace) == 0 {
|
||||
err.stackTrace = string(debug.Stack())
|
||||
}
|
||||
return &err
|
||||
}
|
||||
|
||||
func (e *Error) WithCollection(s string) *Error {
|
||||
err := *e
|
||||
err.collection = s
|
||||
return &err
|
||||
}
|
||||
|
||||
func (e *Error) WithIndex(s string) *Error {
|
||||
err := *e
|
||||
err.index = s
|
||||
return &err
|
||||
}
|
||||
|
||||
func (e *Error) msgTruncacted() string {
|
||||
if len(e.msg) > 255 {
|
||||
return e.msg[:255]
|
||||
}
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func (e *Error) Write(w io.Writer) error {
|
||||
msg := e.msgTruncacted()
|
||||
|
||||
if err := binary.Write(w, binary.LittleEndian, e.code); err != nil {
|
||||
return IO.WithErr(err)
|
||||
}
|
||||
|
||||
if _, err := w.Write([]byte{byte(len(msg))}); err != nil {
|
||||
return err
|
||||
}
|
||||
_, err := w.Write([]byte(msg))
|
||||
return err
|
||||
}
|
||||
|
||||
func (e *Error) Read(r io.Reader) error {
|
||||
var (
|
||||
size uint8
|
||||
)
|
||||
|
||||
if err := binary.Read(r, binary.LittleEndian, &e.code); err != nil {
|
||||
return IO.WithErr(err)
|
||||
}
|
||||
|
||||
if err := binary.Read(r, binary.LittleEndian, &size); err != nil {
|
||||
return IO.WithErr(err)
|
||||
}
|
||||
|
||||
msgBuf := make([]byte, size)
|
||||
if _, err := io.ReadFull(r, msgBuf); err != nil {
|
||||
return IO.WithErr(err)
|
||||
}
|
||||
|
||||
e.msg = string(msgBuf)
|
||||
return nil
|
||||
}
|
||||
26
lib/errs/error_test.go
Normal file
26
lib/errs/error_test.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package errs
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestError_Simple(t *testing.T) {
|
||||
e := Archived
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
|
||||
if err := e.Write(b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
e2 := &Error{}
|
||||
if err := e2.Read(b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(*e, *e2) {
|
||||
t.Fatal("not equal")
|
||||
}
|
||||
}
|
||||
21
lib/errs/errors.go
Normal file
21
lib/errs/errors.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package errs
|
||||
|
||||
var (
|
||||
Archived = NewErr(100, "archived")
|
||||
EOFArchived = NewErr(101, "EOF-archived")
|
||||
IO = NewErr(102, "IO error")
|
||||
NotFound = NewErr(103, "not found")
|
||||
Locked = NewErr(104, "locked")
|
||||
NotAuthorized = NewErr(105, "not authorized")
|
||||
NotAllowed = NewErr(106, "not allowed")
|
||||
Stopped = NewErr(107, "stopped")
|
||||
Timeout = NewErr(108, "timeout")
|
||||
Duplicate = NewErr(109, "duplicate")
|
||||
ReadOnly = NewErr(110, "read only")
|
||||
Encoding = NewErr(111, "encoding")
|
||||
Closed = NewErr(112, "closed")
|
||||
InvalidPath = NewErr(200, "invalid path")
|
||||
Corrupt = NewErr(666, "corrupt")
|
||||
Fatal = NewErr(1053, "fatal")
|
||||
Unexpected = NewErr(999, "unexpected")
|
||||
)
|
||||
22
lib/errs/fmt.go
Normal file
22
lib/errs/fmt.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package errs
|
||||
|
||||
import "fmt"
|
||||
|
||||
func FmtDetails(err error) string {
|
||||
e, ok := err.(*Error)
|
||||
if !ok {
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
var s string
|
||||
if e.collection != "" || e.index != "" {
|
||||
s = fmt.Sprintf(`[%d] (%s/%s) %s`, e.code, e.collection, e.index, e.msg)
|
||||
} else {
|
||||
s = fmt.Sprintf("[%d] %s", e.code, e.msg)
|
||||
}
|
||||
if len(e.stackTrace) != 0 {
|
||||
s += "\n\nStack Trace:\n" + e.stackTrace + "\n"
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
58
lib/flock/flock.go
Normal file
58
lib/flock/flock.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package flock
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// Lock gets an exclusive lock on the file at the given path. If the file
|
||||
// doesn't exist, it's created.
|
||||
func Lock(path string) (*os.File, error) {
|
||||
return lock(path, unix.LOCK_EX)
|
||||
}
|
||||
|
||||
// TryLock will return a nil file if the file is already locked.
|
||||
func TryLock(path string) (*os.File, error) {
|
||||
return lock(path, unix.LOCK_EX|unix.LOCK_NB)
|
||||
}
|
||||
|
||||
func LockFile(f *os.File) error {
|
||||
_, err := lockFile(f, unix.LOCK_EX)
|
||||
return err
|
||||
}
|
||||
|
||||
// Returns true if the lock was successfully acquired.
|
||||
func TryLockFile(f *os.File) (bool, error) {
|
||||
return lockFile(f, unix.LOCK_EX|unix.LOCK_NB)
|
||||
}
|
||||
|
||||
func lockFile(f *os.File, flags int) (bool, error) {
|
||||
if err := unix.Flock(int(f.Fd()), flags); err != nil {
|
||||
if flags&unix.LOCK_NB != 0 && errors.Is(err, unix.EAGAIN) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func lock(path string, flags int) (*os.File, error) {
|
||||
perm := os.O_CREATE | os.O_RDWR
|
||||
f, err := os.OpenFile(path, perm, 0600)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ok, err := lockFile(f, flags)
|
||||
if err != nil || !ok {
|
||||
f.Close()
|
||||
f = nil
|
||||
}
|
||||
return f, err
|
||||
}
|
||||
|
||||
// Unlock releases the lock acquired via the Lock function.
|
||||
func Unlock(f *os.File) error {
|
||||
return f.Close()
|
||||
}
|
||||
85
lib/httpconn/client.go
Normal file
85
lib/httpconn/client.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package httpconn
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrInvalidStatus = errors.New("invalid status")
|
||||
|
||||
func Dial(rawURL string) (net.Conn, error) {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return nil, errs.Unexpected.WithErr(err)
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
case "https":
|
||||
return DialHTTPS(u.Host+":443", u.Path)
|
||||
case "http":
|
||||
return DialHTTP(u.Host, u.Path)
|
||||
default:
|
||||
return nil, errs.Unexpected.WithMsg("Unknown scheme: " + u.Scheme)
|
||||
}
|
||||
}
|
||||
|
||||
func DialHTTPS(host, path string) (net.Conn, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
d := tls.Dialer{}
|
||||
conn, err := d.DialContext(ctx, "tcp", host)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return nil, errs.IO.WithErr(err)
|
||||
}
|
||||
return finishDialing(conn, host, path)
|
||||
}
|
||||
|
||||
func DialHTTPSWithIP(ip, host, path string) (net.Conn, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
d := tls.Dialer{Config: &tls.Config{ServerName: host}}
|
||||
conn, err := d.DialContext(ctx, "tcp", ip)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return nil, errs.IO.WithErr(err)
|
||||
}
|
||||
return finishDialing(conn, host, path)
|
||||
}
|
||||
|
||||
func DialHTTP(host, path string) (net.Conn, error) {
|
||||
conn, err := net.Dial("tcp", host)
|
||||
if err != nil {
|
||||
return nil, errs.IO.WithErr(err)
|
||||
}
|
||||
return finishDialing(conn, host, path)
|
||||
}
|
||||
|
||||
func finishDialing(conn net.Conn, host, path string) (net.Conn, error) {
|
||||
conn.SetDeadline(time.Now().Add(10 * time.Second))
|
||||
|
||||
io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n")
|
||||
io.WriteString(conn, "Host: "+host+"\n\n")
|
||||
|
||||
// Require successful HTTP response before using the conn.
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
if resp.Status != "200 OK" {
|
||||
conn.Close()
|
||||
return nil, errs.IO.WithMsg("invalid status: %s", resp.Status)
|
||||
}
|
||||
|
||||
conn.SetDeadline(time.Time{})
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
42
lib/httpconn/conn_test.go
Normal file
42
lib/httpconn/conn_test.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package httpconn
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/nettest"
|
||||
)
|
||||
|
||||
func TestNetTest_TestConn(t *testing.T) {
|
||||
nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) {
|
||||
|
||||
connCh := make(chan net.Conn, 1)
|
||||
doneCh := make(chan bool)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/connect", func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := Accept(w, r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
connCh <- conn
|
||||
<-doneCh
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
|
||||
c1, err = DialHTTP(strings.TrimPrefix(srv.URL, "http://"), "/connect")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
c2 = <-connCh
|
||||
|
||||
return c1, c2, func() {
|
||||
doneCh <- true
|
||||
srv.Close()
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
32
lib/httpconn/server.go
Normal file
32
lib/httpconn/server.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package httpconn
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Accept(w http.ResponseWriter, r *http.Request) (net.Conn, error) {
|
||||
if r.Method != "CONNECT" {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
io.WriteString(w, "405 must CONNECT\n")
|
||||
return nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _ = io.WriteString(conn, "HTTP/1.0 200 OK\n\n")
|
||||
conn.SetDeadline(time.Time{})
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
32
lib/idgen/gen.go
Normal file
32
lib/idgen/gen.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package idgen
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
lock sync.Mutex
|
||||
ts uint64 = uint64(time.Now().Unix())
|
||||
counter uint64 = 1
|
||||
counterMax uint64 = 1 << 28
|
||||
)
|
||||
|
||||
// Next can generate ~268M ints per second for ~1000 years.
|
||||
func Next() uint64 {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
tt := uint64(time.Now().Unix())
|
||||
if tt > ts {
|
||||
ts = tt
|
||||
counter = 1
|
||||
} else {
|
||||
counter++
|
||||
if counter == counterMax {
|
||||
panic("Too many IDs.")
|
||||
}
|
||||
}
|
||||
|
||||
return ts<<28 + counter
|
||||
}
|
||||
11
lib/idgen/gen_test.go
Normal file
11
lib/idgen/gen_test.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package idgen
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkNext(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
Next()
|
||||
}
|
||||
}
|
||||
51
lib/rep/functions.go
Normal file
51
lib/rep/functions.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package rep
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"net"
|
||||
"path/filepath"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func lockFilePath(rootDir string) string {
|
||||
return filepath.Join(rootDir, "lock")
|
||||
}
|
||||
|
||||
func walRootDir(rootDir string) string {
|
||||
return filepath.Join(rootDir, "wal")
|
||||
}
|
||||
|
||||
func stateFilePath(rootDir string) string {
|
||||
return filepath.Join(rootDir, "state")
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func sendJSON(
|
||||
item any,
|
||||
conn net.Conn,
|
||||
timeout time.Duration,
|
||||
) error {
|
||||
|
||||
buf := bufPoolGet()
|
||||
defer bufPoolPut(buf)
|
||||
|
||||
if err := json.NewEncoder(buf).Encode(item); err != nil {
|
||||
return errs.Unexpected.WithErr(err)
|
||||
}
|
||||
|
||||
sizeBuf := make([]byte, 2)
|
||||
binary.LittleEndian.PutUint16(sizeBuf, uint16(buf.Len()))
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
buffers := net.Buffers{sizeBuf, buf.Bytes()}
|
||||
if _, err := buffers.WriteTo(conn); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
178
lib/rep/http-client.go
Normal file
178
lib/rep/http-client.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package rep
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"git.crumpington.com/public/jldb/lib/httpconn"
|
||||
"git.crumpington.com/public/jldb/lib/wal"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type client struct {
|
||||
// Mutex-protected variables.
|
||||
lock sync.Mutex
|
||||
closed bool
|
||||
conn net.Conn
|
||||
|
||||
// The following are constant.
|
||||
endpoint string
|
||||
psk []byte
|
||||
timeout time.Duration
|
||||
|
||||
buf []byte
|
||||
}
|
||||
|
||||
func newClient(endpoint, psk string, timeout time.Duration) *client {
|
||||
b := make([]byte, 256)
|
||||
copy(b, []byte(psk))
|
||||
|
||||
return &client{
|
||||
endpoint: endpoint,
|
||||
psk: b,
|
||||
timeout: timeout,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) GetInfo() (info Info, err error) {
|
||||
err = c.withConn(cmdGetInfo, func(conn net.Conn) error {
|
||||
return c.recvJSON(&info, conn, c.timeout)
|
||||
})
|
||||
return info, err
|
||||
}
|
||||
|
||||
func (c *client) RecvState(recv func(net.Conn) error) error {
|
||||
return c.withConn(cmdSendState, recv)
|
||||
}
|
||||
|
||||
func (c *client) StreamWAL(w *wal.WAL) error {
|
||||
return c.withConn(cmdStreamWAL, func(conn net.Conn) error {
|
||||
return w.Recv(conn, c.timeout)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *client) Close() {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
c.closed = true
|
||||
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (c *client) writeCmd(cmd byte) error {
|
||||
c.conn.SetWriteDeadline(time.Now().Add(c.timeout))
|
||||
if _, err := c.conn.Write([]byte{cmd}); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) dial() error {
|
||||
c.conn = nil
|
||||
|
||||
conn, err := httpconn.Dial(c.endpoint)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(c.timeout))
|
||||
if _, err := conn.Write(c.psk); err != nil {
|
||||
conn.Close()
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) withConn(cmd byte, fn func(net.Conn) error) error {
|
||||
conn, err := c.getConn(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := fn(conn); err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *client) getConn(cmd byte) (net.Conn, error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if c.closed {
|
||||
return nil, errs.IO.WithErr(io.EOF)
|
||||
}
|
||||
|
||||
dialed := false
|
||||
|
||||
if c.conn == nil {
|
||||
if err := c.dial(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dialed = true
|
||||
}
|
||||
|
||||
if err := c.writeCmd(cmd); err != nil {
|
||||
if dialed {
|
||||
c.conn = nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.dial(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.writeCmd(cmd); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return c.conn, nil
|
||||
}
|
||||
|
||||
func (c *client) recvJSON(
|
||||
item any,
|
||||
conn net.Conn,
|
||||
timeout time.Duration,
|
||||
) error {
|
||||
|
||||
if cap(c.buf) < 2 {
|
||||
c.buf = make([]byte, 0, 1024)
|
||||
}
|
||||
buf := c.buf[:2]
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
size := binary.LittleEndian.Uint16(buf)
|
||||
|
||||
if cap(buf) < int(size) {
|
||||
buf = make([]byte, size)
|
||||
c.buf = buf
|
||||
}
|
||||
buf = buf[:size]
|
||||
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(buf, item); err != nil {
|
||||
return errs.Unexpected.WithErr(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
79
lib/rep/http-handler.go
Normal file
79
lib/rep/http-handler.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package rep
|
||||
|
||||
import (
|
||||
"crypto/subtle"
|
||||
"git.crumpington.com/public/jldb/lib/httpconn"
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
cmdGetInfo = 10
|
||||
cmdSendState = 20
|
||||
cmdStreamWAL = 30
|
||||
)
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
func (rep *Replicator) Handle(w http.ResponseWriter, r *http.Request) {
|
||||
logf := func(pattern string, args ...any) {
|
||||
log.Printf("[HTTP-HANDLER] "+pattern, args...)
|
||||
}
|
||||
|
||||
conn, err := httpconn.Accept(w, r)
|
||||
if err != nil {
|
||||
logf("Failed to accept connection: %s", err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
psk := make([]byte, 256)
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(rep.conf.NetTimeout))
|
||||
if _, err := conn.Read(psk); err != nil {
|
||||
logf("Failed to read PSK: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
expected := rep.pskBytes
|
||||
if subtle.ConstantTimeCompare(expected, psk) != 1 {
|
||||
logf("PSK mismatch.")
|
||||
return
|
||||
}
|
||||
|
||||
cmd := make([]byte, 1)
|
||||
|
||||
for {
|
||||
conn.SetReadDeadline(time.Now().Add(rep.conf.NetTimeout))
|
||||
if _, err := conn.Read(cmd); err != nil {
|
||||
logf("Read failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
switch cmd[0] {
|
||||
|
||||
case cmdGetInfo:
|
||||
if err := sendJSON(rep.Info(), conn, rep.conf.NetTimeout); err != nil {
|
||||
logf("Failed to send info: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
case cmdSendState:
|
||||
|
||||
if err := rep.sendState(conn); err != nil {
|
||||
if !rep.stopped() {
|
||||
logf("Failed to send state: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
case cmdStreamWAL:
|
||||
err := rep.wal.Send(conn, rep.conf.NetTimeout)
|
||||
if !rep.stopped() {
|
||||
logf("Failed when sending WAL: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
9
lib/rep/info.go
Normal file
9
lib/rep/info.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package rep
|
||||
|
||||
type Info struct {
|
||||
AppSeqNum int64 // Page file sequence number.
|
||||
AppTimestampMS int64 // Page file timestamp.
|
||||
WALFirstSeqNum int64 // WAL min sequence number.
|
||||
WALLastSeqNum int64 // WAL max sequence number.
|
||||
WALLastTimestampMS int64 // WAL timestamp.
|
||||
}
|
||||
20
lib/rep/localstate.go
Normal file
20
lib/rep/localstate.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package rep
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
type localState struct {
|
||||
SeqNum int64
|
||||
TimestampMS int64
|
||||
}
|
||||
|
||||
func (h localState) writeTo(b []byte) {
|
||||
binary.LittleEndian.PutUint64(b[0:8], uint64(h.SeqNum))
|
||||
binary.LittleEndian.PutUint64(b[8:16], uint64(h.TimestampMS))
|
||||
}
|
||||
|
||||
func (h *localState) readFrom(b []byte) {
|
||||
h.SeqNum = int64(binary.LittleEndian.Uint64(b[0:8]))
|
||||
h.TimestampMS = int64(binary.LittleEndian.Uint64(b[8:16]))
|
||||
}
|
||||
21
lib/rep/pools.go
Normal file
21
lib/rep/pools.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package rep
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var bufPool = sync.Pool{
|
||||
New: func() any {
|
||||
return &bytes.Buffer{}
|
||||
},
|
||||
}
|
||||
|
||||
func bufPoolGet() *bytes.Buffer {
|
||||
return bufPool.Get().(*bytes.Buffer)
|
||||
}
|
||||
|
||||
func bufPoolPut(b *bytes.Buffer) {
|
||||
b.Reset()
|
||||
bufPool.Put(b)
|
||||
}
|
||||
41
lib/rep/rep-sendrecv.go
Normal file
41
lib/rep/rep-sendrecv.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package rep
|
||||
|
||||
import (
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (rep *Replicator) sendState(conn net.Conn) error {
|
||||
state := rep.getState()
|
||||
|
||||
buf := make([]byte, 512)
|
||||
state.writeTo(buf)
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(rep.conf.NetTimeout))
|
||||
if _, err := conn.Write(buf); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
conn.SetWriteDeadline(time.Time{})
|
||||
|
||||
return rep.app.SendState(conn)
|
||||
}
|
||||
|
||||
func (rep *Replicator) recvState(conn net.Conn) error {
|
||||
buf := make([]byte, 512)
|
||||
conn.SetReadDeadline(time.Now().Add(rep.conf.NetTimeout))
|
||||
if _, err := io.ReadFull(conn, buf); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
if err := rep.app.RecvState(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := localState{}
|
||||
state.readFrom(buf)
|
||||
|
||||
return rep.setState(state)
|
||||
}
|
||||
184
lib/rep/replicator-open.go
Normal file
184
lib/rep/replicator-open.go
Normal file
@@ -0,0 +1,184 @@
|
||||
package rep
|
||||
|
||||
import (
|
||||
"git.crumpington.com/public/jldb/lib/atomicheader"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"git.crumpington.com/public/jldb/lib/flock"
|
||||
"git.crumpington.com/public/jldb/lib/wal"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (rep *Replicator) loadConfigDefaults() {
|
||||
conf := rep.conf
|
||||
|
||||
if conf.NetTimeout <= 0 {
|
||||
conf.NetTimeout = time.Minute
|
||||
}
|
||||
if conf.WALSegMinCount <= 0 {
|
||||
conf.WALSegMinCount = 1024
|
||||
}
|
||||
if conf.WALSegMaxAgeSec <= 0 {
|
||||
conf.WALSegMaxAgeSec = 3600
|
||||
}
|
||||
if conf.WALSegGCAgeSec <= 0 {
|
||||
conf.WALSegGCAgeSec = 7 * 86400
|
||||
}
|
||||
|
||||
rep.conf = conf
|
||||
|
||||
rep.pskBytes = make([]byte, 256)
|
||||
copy(rep.pskBytes, []byte(conf.ReplicationPSK))
|
||||
}
|
||||
|
||||
func (rep *Replicator) initDirectories() error {
|
||||
if err := os.MkdirAll(walRootDir(rep.conf.RootDir), 0700); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rep *Replicator) acquireLock() error {
|
||||
lockFile, err := flock.TryLock(lockFilePath(rep.conf.RootDir))
|
||||
if err != nil {
|
||||
return errs.IO.WithMsg("locked: %s", lockFilePath(rep.conf.RootDir))
|
||||
}
|
||||
if lockFile == nil {
|
||||
return errs.Locked
|
||||
}
|
||||
rep.lockFile = lockFile
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rep *Replicator) loadLocalState() error {
|
||||
f, err := os.OpenFile(stateFilePath(rep.conf.RootDir), os.O_RDWR|os.O_CREATE, 0600)
|
||||
if err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
if info.Size() < atomicheader.ReservedBytes {
|
||||
if err := atomicheader.Init(f); err != nil {
|
||||
f.Close()
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
}
|
||||
|
||||
rep.stateHandler, err = atomicheader.Open(f)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
rep.stateFile = f
|
||||
var state localState
|
||||
|
||||
err = rep.stateHandler.Read(func(page []byte) error {
|
||||
state.readFrom(page)
|
||||
return nil
|
||||
})
|
||||
if err == nil {
|
||||
rep.state.Store(&state)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write a clean state.
|
||||
state = localState{}
|
||||
rep.state.Store(&state)
|
||||
return rep.stateHandler.Write(func(page []byte) error {
|
||||
state.writeTo(page)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (rep *Replicator) walConfig() wal.Config {
|
||||
return wal.Config{
|
||||
SegMinCount: rep.conf.WALSegMinCount,
|
||||
SegMaxAgeSec: rep.conf.WALSegMaxAgeSec,
|
||||
}
|
||||
}
|
||||
|
||||
func (rep *Replicator) openWAL() (err error) {
|
||||
rep.wal, err = wal.Open(walRootDir(rep.conf.RootDir), rep.walConfig())
|
||||
if err != nil {
|
||||
rep.wal, err = wal.Create(walRootDir(rep.conf.RootDir), 1, rep.walConfig())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rep *Replicator) recvStateIfNecessary() error {
|
||||
if rep.conf.Primary {
|
||||
return nil
|
||||
}
|
||||
|
||||
sInfo := rep.Info()
|
||||
pInfo, err := rep.client.GetInfo()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if pInfo.WALFirstSeqNum <= sInfo.WALLastSeqNum {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Make a new WAL.
|
||||
rep.wal.Close()
|
||||
|
||||
if err = rep.client.RecvState(rep.recvState); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
state := rep.getState()
|
||||
|
||||
rep.wal, err = wal.Create(walRootDir(rep.conf.RootDir), state.SeqNum+1, rep.walConfig())
|
||||
return err
|
||||
}
|
||||
|
||||
// Replays un-acked entries in the WAL. Acks after all records are replayed.
|
||||
func (rep *Replicator) replay() error {
|
||||
state := rep.getState()
|
||||
it, err := rep.wal.Iterator(state.SeqNum + 1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer it.Close()
|
||||
|
||||
for it.Next(0) {
|
||||
rec := it.Record()
|
||||
if err := rep.app.Replay(rec); err != nil {
|
||||
return err
|
||||
}
|
||||
state.SeqNum = rec.SeqNum
|
||||
state.TimestampMS = rec.TimestampMS
|
||||
}
|
||||
|
||||
if it.Error() != nil {
|
||||
return it.Error()
|
||||
}
|
||||
|
||||
return rep.ack(state.SeqNum, state.TimestampMS)
|
||||
}
|
||||
|
||||
func (rep *Replicator) startWALGC() {
|
||||
rep.done.Add(1)
|
||||
go rep.runWALGC()
|
||||
}
|
||||
|
||||
func (rep *Replicator) startWALFollower() {
|
||||
rep.done.Add(1)
|
||||
go rep.runWALFollower()
|
||||
}
|
||||
|
||||
func (rep *Replicator) startWALRecvr() {
|
||||
rep.done.Add(1)
|
||||
go rep.runWALRecvr()
|
||||
}
|
||||
66
lib/rep/replicator-walfollower.go
Normal file
66
lib/rep/replicator-walfollower.go
Normal file
@@ -0,0 +1,66 @@
|
||||
package rep
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (rep *Replicator) runWALFollower() {
|
||||
defer rep.done.Done()
|
||||
|
||||
for {
|
||||
rep.followOnce()
|
||||
|
||||
select {
|
||||
case <-rep.stop:
|
||||
return
|
||||
default:
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rep *Replicator) followOnce() {
|
||||
logf := func(pattern string, args ...any) {
|
||||
log.Printf("[WAL-FOLLOWER] "+pattern, args...)
|
||||
}
|
||||
|
||||
state := rep.getState()
|
||||
it, err := rep.wal.Iterator(state.SeqNum + 1)
|
||||
if err != nil {
|
||||
logf("Failed to create WAL iterator: %v", err)
|
||||
return
|
||||
}
|
||||
defer it.Close()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-rep.stop:
|
||||
logf("Stopped")
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
if it.Next(time.Second) {
|
||||
rec := it.Record()
|
||||
|
||||
if err := rep.app.Apply(rec); err != nil {
|
||||
logf("App failed to apply change: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := rep.ack(rec.SeqNum, rec.TimestampMS); err != nil {
|
||||
logf("App failed to update local state: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case rep.appendNotify <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
} else if it.Error() != nil {
|
||||
logf("Iteration error: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
28
lib/rep/replicator-walgc.go
Normal file
28
lib/rep/replicator-walgc.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package rep
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (rep *Replicator) runWALGC() {
|
||||
defer rep.done.Done()
|
||||
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ticker.C:
|
||||
state := rep.getState()
|
||||
before := time.Now().Unix() - rep.conf.WALSegMaxAgeSec
|
||||
if err := rep.wal.DeleteBefore(before, state.SeqNum); err != nil {
|
||||
log.Printf("[WAL-GC] failed to delete wal segments: %v", err)
|
||||
}
|
||||
// OK
|
||||
case <-rep.stop:
|
||||
log.Print("[WAL-GC] Stopped")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
38
lib/rep/replicator-walrecvr.go
Normal file
38
lib/rep/replicator-walrecvr.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package rep
|
||||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (rep *Replicator) runWALRecvr() {
|
||||
go func() {
|
||||
<-rep.stop
|
||||
rep.client.Close()
|
||||
}()
|
||||
|
||||
defer rep.done.Done()
|
||||
|
||||
for {
|
||||
rep.runWALRecvrOnce()
|
||||
select {
|
||||
case <-rep.stop:
|
||||
log.Print("[WAL-RECVR] Stopped")
|
||||
return
|
||||
default:
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (rep *Replicator) runWALRecvrOnce() {
|
||||
logf := func(pattern string, args ...any) {
|
||||
log.Printf("[WAL-RECVR] "+pattern, args...)
|
||||
}
|
||||
|
||||
if err := rep.client.StreamWAL(rep.wal); err != nil {
|
||||
if !rep.stopped() {
|
||||
logf("Recv failed: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
235
lib/rep/replicator.go
Normal file
235
lib/rep/replicator.go
Normal file
@@ -0,0 +1,235 @@
|
||||
package rep
|
||||
|
||||
import (
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/atomicheader"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"git.crumpington.com/public/jldb/lib/wal"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
RootDir string
|
||||
Primary bool
|
||||
ReplicationPSK string
|
||||
NetTimeout time.Duration // Default is 1 minute.
|
||||
|
||||
// WAL settings.
|
||||
WALSegMinCount int64 // Minimum Change sets in a segment. Default is 1024.
|
||||
WALSegMaxAgeSec int64 // Maximum age of a segment. Default is 1 hour.
|
||||
WALSegGCAgeSec int64 // Segment age for garbage collection. Default is 7 days.
|
||||
|
||||
// If true, Append won't return until a successful App.Apply.
|
||||
SynchronousAppend bool
|
||||
|
||||
// Necessary for secondary.
|
||||
PrimaryEndpoint string
|
||||
}
|
||||
|
||||
type App struct {
|
||||
// SendState: The primary may need to send storage state to a secondary node.
|
||||
SendState func(conn net.Conn) error
|
||||
|
||||
// (1) RecvState: Secondary nodes may need to load state from the primary if the
|
||||
// WAL is too far behind.
|
||||
RecvState func(conn net.Conn) error
|
||||
|
||||
// (2) InitStorage: Prepare application storage for possible calls to
|
||||
// Replay.
|
||||
InitStorage func() error
|
||||
|
||||
// (3) Replay: write the change to storage. Replay must be idempotent.
|
||||
Replay func(rec wal.Record) error
|
||||
|
||||
// (4) LoadFromStorage: load the application's state from it's persistent
|
||||
// storage.
|
||||
LoadFromStorage func() error
|
||||
|
||||
// (5) Apply: write the change to persistent storage. Apply must be
|
||||
// idempotent. In normal operation each change is applied exactly once.
|
||||
Apply func(rec wal.Record) error
|
||||
}
|
||||
|
||||
type Replicator struct {
|
||||
app App
|
||||
conf Config
|
||||
|
||||
lockFile *os.File
|
||||
pskBytes []byte
|
||||
wal *wal.WAL
|
||||
|
||||
appendNotify chan struct{}
|
||||
|
||||
// lock protects state. The lock is held when replaying (R), following (R),
|
||||
// and sending state (W).
|
||||
stateFile *os.File
|
||||
state *atomic.Pointer[localState]
|
||||
stateHandler *atomicheader.Handler
|
||||
|
||||
stop chan struct{}
|
||||
done *sync.WaitGroup
|
||||
|
||||
client *client // For secondary connection to primary.
|
||||
}
|
||||
|
||||
func Open(app App, conf Config) (*Replicator, error) {
|
||||
rep := &Replicator{
|
||||
app: app,
|
||||
conf: conf,
|
||||
state: &atomic.Pointer[localState]{},
|
||||
stop: make(chan struct{}),
|
||||
done: &sync.WaitGroup{},
|
||||
appendNotify: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
rep.loadConfigDefaults()
|
||||
|
||||
rep.state.Store(&localState{})
|
||||
rep.client = newClient(rep.conf.PrimaryEndpoint, rep.conf.ReplicationPSK, rep.conf.NetTimeout)
|
||||
|
||||
if err := rep.initDirectories(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := rep.acquireLock(); err != nil {
|
||||
rep.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := rep.loadLocalState(); err != nil {
|
||||
rep.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := rep.openWAL(); err != nil {
|
||||
rep.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := rep.recvStateIfNecessary(); err != nil {
|
||||
rep.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := rep.app.InitStorage(); err != nil {
|
||||
rep.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := rep.replay(); err != nil {
|
||||
rep.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := rep.app.LoadFromStorage(); err != nil {
|
||||
rep.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rep.startWALGC()
|
||||
rep.startWALFollower()
|
||||
|
||||
if !rep.conf.Primary {
|
||||
rep.startWALRecvr()
|
||||
}
|
||||
|
||||
return rep, nil
|
||||
}
|
||||
|
||||
func (rep *Replicator) Append(size int64, r io.Reader) (int64, int64, error) {
|
||||
if !rep.conf.Primary {
|
||||
return 0, 0, errs.NotAllowed.WithMsg("cannot write to secondary")
|
||||
}
|
||||
|
||||
seqNum, timestampMS, err := rep.wal.Append(size, r)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
if !rep.conf.SynchronousAppend {
|
||||
return seqNum, timestampMS, nil
|
||||
}
|
||||
|
||||
<-rep.appendNotify
|
||||
return seqNum, timestampMS, nil
|
||||
}
|
||||
|
||||
func (rep *Replicator) Primary() bool {
|
||||
return rep.conf.Primary
|
||||
}
|
||||
|
||||
// TODO: Probably remove this.
|
||||
// The caller may call Ack after Apply to acknowledge that the change has also
|
||||
// been applied to the caller's application. Alternatively, the caller may use
|
||||
// follow to apply changes to their application state.
|
||||
func (rep *Replicator) ack(seqNum, timestampMS int64) error {
|
||||
state := rep.getState()
|
||||
state.SeqNum = seqNum
|
||||
state.TimestampMS = timestampMS
|
||||
return rep.setState(state)
|
||||
}
|
||||
|
||||
func (rep *Replicator) getState() localState {
|
||||
return *rep.state.Load()
|
||||
}
|
||||
|
||||
func (rep *Replicator) setState(state localState) error {
|
||||
err := rep.stateHandler.Write(func(page []byte) error {
|
||||
state.writeTo(page)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rep.state.Store(&state)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rep *Replicator) Info() Info {
|
||||
state := rep.getState()
|
||||
walInfo := rep.wal.Info()
|
||||
|
||||
return Info{
|
||||
AppSeqNum: state.SeqNum,
|
||||
AppTimestampMS: state.TimestampMS,
|
||||
WALFirstSeqNum: walInfo.FirstSeqNum,
|
||||
WALLastSeqNum: walInfo.LastSeqNum,
|
||||
WALLastTimestampMS: walInfo.LastTimestampMS,
|
||||
}
|
||||
}
|
||||
|
||||
func (rep *Replicator) Close() error {
|
||||
if rep.stopped() {
|
||||
return nil
|
||||
}
|
||||
|
||||
close(rep.stop)
|
||||
rep.done.Wait()
|
||||
|
||||
if rep.lockFile != nil {
|
||||
rep.lockFile.Close()
|
||||
}
|
||||
|
||||
if rep.wal != nil {
|
||||
rep.wal.Close()
|
||||
}
|
||||
|
||||
if rep.client != nil {
|
||||
rep.client.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rep *Replicator) stopped() bool {
|
||||
select {
|
||||
case <-rep.stop:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
128
lib/rep/testapp-harness_test.go
Normal file
128
lib/rep/testapp-harness_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package rep
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestAppHarnessRun(t *testing.T) {
|
||||
TestAppHarness{}.Run(t)
|
||||
}
|
||||
|
||||
type TestAppHarness struct {
|
||||
}
|
||||
|
||||
func (h TestAppHarness) Run(t *testing.T) {
|
||||
val := reflect.ValueOf(h)
|
||||
typ := val.Type()
|
||||
for i := 0; i < typ.NumMethod(); i++ {
|
||||
method := typ.Method(i)
|
||||
|
||||
if !strings.HasPrefix(method.Name, "Test") {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Run(method.Name, func(t *testing.T) {
|
||||
//t.Parallel()
|
||||
rootDir := t.TempDir()
|
||||
|
||||
app1 := newApp(t, rand.Int63(), Config{
|
||||
Primary: true,
|
||||
RootDir: filepath.Join(rootDir, "app1"),
|
||||
ReplicationPSK: "123",
|
||||
WALSegMinCount: 1,
|
||||
WALSegMaxAgeSec: 1,
|
||||
WALSegGCAgeSec: 1,
|
||||
})
|
||||
defer app1.Close()
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/rep/", app1.rep.Handle)
|
||||
testServer := httptest.NewServer(mux)
|
||||
defer testServer.Close()
|
||||
|
||||
app2 := newApp(t, rand.Int63(), Config{
|
||||
Primary: false,
|
||||
RootDir: filepath.Join(rootDir, "app2"),
|
||||
ReplicationPSK: "123",
|
||||
PrimaryEndpoint: testServer.URL + "/rep/",
|
||||
WALSegMinCount: 1,
|
||||
WALSegMaxAgeSec: 1,
|
||||
WALSegGCAgeSec: 1,
|
||||
})
|
||||
|
||||
val.MethodByName(method.Name).Call([]reflect.Value{
|
||||
reflect.ValueOf(t),
|
||||
reflect.ValueOf(app1),
|
||||
reflect.ValueOf(app2),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (TestAppHarness) TestRandomUpdates(t *testing.T, app1, app2 *TestApp) {
|
||||
go app1.UpdateRandomFor(4 * time.Second)
|
||||
app2.WaitForEOF()
|
||||
app1.AssertEqual(t, app2)
|
||||
}
|
||||
|
||||
/*
|
||||
func (TestAppHarness) TestRandomUpdatesReplay(t *testing.T, app1, app2 *TestApp) {
|
||||
app1.UpdateRandomFor(4 * time.Second)
|
||||
app2.WaitForEOF()
|
||||
|
||||
app1.Close()
|
||||
app1 = newApp(t, app1.ID, app1.rep.conf)
|
||||
|
||||
app1.AssertEqual(t, app2)
|
||||
info := app1.rep.Info()
|
||||
if info.AppSeqNum != 0 {
|
||||
t.Fatal(info)
|
||||
}
|
||||
}
|
||||
|
||||
func (TestAppHarness) TestRandomUpdatesAck(t *testing.T, app1, app2 *TestApp) {
|
||||
go app1.UpdateRandomFor(4 * time.Second)
|
||||
app2.WaitForEOF()
|
||||
app1.AssertEqual(t, app2)
|
||||
info := app1.rep.Info()
|
||||
if info.AppSeqNum == 0 || info.AppSeqNum != info.WALLastSeqNum {
|
||||
t.Fatal(info)
|
||||
}
|
||||
}
|
||||
|
||||
func (TestAppHarness) TestWriteThenOpenFollower(t *testing.T, app1, app2 *TestApp) {
|
||||
app2.Close()
|
||||
app1.UpdateRandomFor(4 * time.Second)
|
||||
|
||||
app2 = newApp(t, app2.ID, app2.rep.conf)
|
||||
app2.WaitForEOF()
|
||||
app1.AssertEqual(t, app2)
|
||||
}
|
||||
|
||||
func (TestAppHarness) TestUpdateOpenFollowerConcurrently(t *testing.T, app1, app2 *TestApp) {
|
||||
app2.Close()
|
||||
go app1.UpdateRandomFor(4 * time.Second)
|
||||
time.Sleep(2 * time.Second)
|
||||
app2 = newApp(t, app2.ID, app2.rep.conf)
|
||||
app2.WaitForEOF()
|
||||
app1.AssertEqual(t, app2)
|
||||
}
|
||||
|
||||
func (TestAppHarness) TestUpdateCloseOpenFollowerConcurrently(t *testing.T, app1, app2 *TestApp) {
|
||||
go app1.UpdateRandomFor(4 * time.Second)
|
||||
|
||||
time.Sleep(time.Second)
|
||||
app2.Close()
|
||||
time.Sleep(time.Second)
|
||||
app2 = newApp(t, app2.ID, app2.rep.conf)
|
||||
app2.WaitForEOF()
|
||||
app1.AssertEqual(t, app2)
|
||||
}
|
||||
*/
|
||||
239
lib/rep/testapp_test.go
Normal file
239
lib/rep/testapp_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package rep
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/wal"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type TestCmd struct {
|
||||
Set int64 // 1 for set, 0 for delete
|
||||
Key int64
|
||||
Val int64
|
||||
}
|
||||
|
||||
func (c TestCmd) marshal() []byte {
|
||||
b := make([]byte, 24)
|
||||
binary.LittleEndian.PutUint64(b, uint64(c.Set))
|
||||
binary.LittleEndian.PutUint64(b[8:], uint64(c.Key))
|
||||
binary.LittleEndian.PutUint64(b[16:], uint64(c.Val))
|
||||
return b
|
||||
}
|
||||
|
||||
func (c *TestCmd) unmarshal(b []byte) {
|
||||
c.Set = int64(binary.LittleEndian.Uint64(b))
|
||||
c.Key = int64(binary.LittleEndian.Uint64(b[8:]))
|
||||
c.Val = int64(binary.LittleEndian.Uint64(b[16:]))
|
||||
}
|
||||
|
||||
func CmdFromRec(rec wal.Record) TestCmd {
|
||||
cmd := TestCmd{}
|
||||
|
||||
buf, err := io.ReadAll(rec.Reader)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if len(buf) != 24 {
|
||||
panic(len(buf))
|
||||
}
|
||||
cmd.unmarshal(buf)
|
||||
return cmd
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
var storage = map[int64]map[int64]int64{}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type TestApp struct {
|
||||
ID int64
|
||||
storage map[int64]int64
|
||||
|
||||
rep *Replicator
|
||||
|
||||
lock sync.Mutex
|
||||
m map[int64]int64
|
||||
}
|
||||
|
||||
func newApp(t *testing.T, id int64, conf Config) *TestApp {
|
||||
t.Helper()
|
||||
a := &TestApp{
|
||||
ID: id,
|
||||
m: map[int64]int64{},
|
||||
}
|
||||
|
||||
var err error
|
||||
a.rep, err = Open(App{
|
||||
SendState: a.sendState,
|
||||
RecvState: a.recvState,
|
||||
InitStorage: a.initStorage,
|
||||
Replay: a.replay,
|
||||
LoadFromStorage: a.loadFromStorage,
|
||||
Apply: a.apply,
|
||||
}, conf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *TestApp) _set(k, v int64) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
a.m[k] = v
|
||||
}
|
||||
|
||||
func (a *TestApp) _del(k int64) {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
delete(a.m, k)
|
||||
}
|
||||
|
||||
func (a *TestApp) Get(k int64) int64 {
|
||||
a.lock.Lock()
|
||||
defer a.lock.Unlock()
|
||||
return a.m[k]
|
||||
}
|
||||
|
||||
func (app *TestApp) Close() {
|
||||
app.rep.Close()
|
||||
}
|
||||
|
||||
func (app *TestApp) Set(k, v int64) error {
|
||||
cmd := TestCmd{Set: 1, Key: k, Val: v}
|
||||
if _, _, err := app.rep.Append(24, bytes.NewBuffer(cmd.marshal())); err != nil {
|
||||
return err
|
||||
}
|
||||
app._set(k, v)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (app *TestApp) Del(k int64) error {
|
||||
cmd := TestCmd{Set: 0, Key: k, Val: 0}
|
||||
if _, _, err := app.rep.Append(24, bytes.NewBuffer(cmd.marshal())); err != nil {
|
||||
return err
|
||||
}
|
||||
app._del(k)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (app *TestApp) UpdateRandomFor(dt time.Duration) {
|
||||
tStart := time.Now()
|
||||
for time.Since(tStart) < dt {
|
||||
if rand.Float32() < 0.5 {
|
||||
if err := app.Set(1+rand.Int63n(10), 1+rand.Int63n(10)); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
} else {
|
||||
if err := app.Del(1 + rand.Int63n(10)); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
app.Set(999, 999)
|
||||
}
|
||||
|
||||
func (app *TestApp) WaitForEOF() {
|
||||
for app.Get(999) != 999 {
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
}
|
||||
|
||||
func (app *TestApp) AssertEqual(t *testing.T, rhs *TestApp) {
|
||||
app.lock.Lock()
|
||||
defer app.lock.Unlock()
|
||||
rhs.lock.Lock()
|
||||
defer rhs.lock.Unlock()
|
||||
|
||||
if len(app.m) != len(rhs.m) {
|
||||
t.Fatal(len(app.m), len(rhs.m))
|
||||
}
|
||||
|
||||
for k := range app.m {
|
||||
if app.m[k] != rhs.m[k] {
|
||||
t.Fatal(k, app.m[k], rhs.m[k])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (app *TestApp) sendState(conn net.Conn) error {
|
||||
app.lock.Lock()
|
||||
b, _ := json.Marshal(app.m)
|
||||
app.lock.Unlock()
|
||||
|
||||
_, err := conn.Write(b)
|
||||
return err
|
||||
}
|
||||
|
||||
func (app *TestApp) recvState(conn net.Conn) error {
|
||||
m := map[int64]int64{}
|
||||
if err := json.NewDecoder(conn).Decode(&m); err != nil {
|
||||
return err
|
||||
}
|
||||
storage[app.ID] = m
|
||||
return nil
|
||||
}
|
||||
|
||||
func (app *TestApp) initStorage() error {
|
||||
if _, ok := storage[app.ID]; !ok {
|
||||
storage[app.ID] = map[int64]int64{}
|
||||
}
|
||||
app.storage = storage[app.ID]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (app *TestApp) replay(rec wal.Record) error {
|
||||
cmd := CmdFromRec(rec)
|
||||
if cmd.Set != 0 {
|
||||
app.storage[cmd.Key] = cmd.Val
|
||||
} else {
|
||||
delete(app.storage, cmd.Key)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (app *TestApp) loadFromStorage() error {
|
||||
app.m = map[int64]int64{}
|
||||
for k, v := range app.storage {
|
||||
app.m[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (app *TestApp) apply(rec wal.Record) error {
|
||||
cmd := CmdFromRec(rec)
|
||||
if cmd.Set != 0 {
|
||||
app.storage[cmd.Key] = cmd.Val
|
||||
} else {
|
||||
delete(app.storage, cmd.Key)
|
||||
}
|
||||
|
||||
// For primary, only update storage.
|
||||
if app.rep.Primary() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// For secondary, update the map.
|
||||
if cmd.Set != 0 {
|
||||
app._set(cmd.Key, cmd.Val)
|
||||
} else {
|
||||
app._del(cmd.Key)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
33
lib/testutil/limitwriter.go
Normal file
33
lib/testutil/limitwriter.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
func NewLimitWriter(w io.Writer, limit int) *LimitWriter {
|
||||
return &LimitWriter{
|
||||
w: w,
|
||||
limit: limit,
|
||||
}
|
||||
}
|
||||
|
||||
type LimitWriter struct {
|
||||
w io.Writer
|
||||
limit int
|
||||
written int
|
||||
}
|
||||
|
||||
func (lw *LimitWriter) Write(buf []byte) (int, error) {
|
||||
n, err := lw.w.Write(buf)
|
||||
if err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
||||
lw.written += n
|
||||
if lw.written > lw.limit {
|
||||
return n, os.ErrClosed
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
79
lib/testutil/testconn.go
Normal file
79
lib/testutil/testconn.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Network struct {
|
||||
lock sync.Mutex
|
||||
// Current connections.
|
||||
cConn net.Conn
|
||||
sConn net.Conn
|
||||
|
||||
acceptQ chan net.Conn
|
||||
}
|
||||
|
||||
func NewNetwork() *Network {
|
||||
return &Network{
|
||||
acceptQ: make(chan net.Conn, 1),
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Network) Dial() net.Conn {
|
||||
cc, sc := net.Pipe()
|
||||
func() {
|
||||
n.lock.Lock()
|
||||
defer n.lock.Unlock()
|
||||
if n.cConn != nil {
|
||||
n.cConn.Close()
|
||||
n.cConn = nil
|
||||
}
|
||||
select {
|
||||
case n.acceptQ <- sc:
|
||||
n.cConn = cc
|
||||
default:
|
||||
cc = nil
|
||||
}
|
||||
}()
|
||||
return cc
|
||||
}
|
||||
|
||||
func (n *Network) Accept() net.Conn {
|
||||
var sc net.Conn
|
||||
select {
|
||||
case sc = <-n.acceptQ:
|
||||
case <-time.After(time.Second):
|
||||
return nil
|
||||
}
|
||||
|
||||
func() {
|
||||
n.lock.Lock()
|
||||
defer n.lock.Unlock()
|
||||
if n.sConn != nil {
|
||||
n.sConn.Close()
|
||||
n.sConn = nil
|
||||
}
|
||||
n.sConn = sc
|
||||
}()
|
||||
return sc
|
||||
}
|
||||
|
||||
func (n *Network) CloseClient() {
|
||||
n.lock.Lock()
|
||||
defer n.lock.Unlock()
|
||||
if n.cConn != nil {
|
||||
n.cConn.Close()
|
||||
n.cConn = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (n *Network) CloseServer() {
|
||||
n.lock.Lock()
|
||||
defer n.lock.Unlock()
|
||||
if n.sConn != nil {
|
||||
n.sConn.Close()
|
||||
n.sConn = nil
|
||||
}
|
||||
}
|
||||
10
lib/testutil/util.go
Normal file
10
lib/testutil/util.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package testutil
|
||||
|
||||
import "testing"
|
||||
|
||||
func AssertNotNil(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
53
lib/wal/corrupt_test.go
Normal file
53
lib/wal/corrupt_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCorruptWAL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
wal, err := Create(tmpDir, 100, Config{
|
||||
SegMinCount: 1024,
|
||||
SegMaxAgeSec: 3600,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer wal.Close()
|
||||
|
||||
appendRandomRecords(t, wal, 100)
|
||||
|
||||
f := wal.seg.f
|
||||
info, err := f.Stat()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
offset := info.Size() / 2
|
||||
if _, err := f.WriteAt([]byte{1, 2, 3, 4, 5, 6, 7, 8}, offset); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
it, err := wal.Iterator(-1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer it.Close()
|
||||
|
||||
for it.Next(0) {
|
||||
rec := it.Record()
|
||||
if _, err := io.ReadAll(rec.Reader); err != nil {
|
||||
if errs.Corrupt.Is(err) {
|
||||
return
|
||||
}
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if !errs.Corrupt.Is(it.Error()) {
|
||||
t.Fatal(it.Error())
|
||||
}
|
||||
}
|
||||
28
lib/wal/design.go
Normal file
28
lib/wal/design.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type Info struct {
|
||||
FirstSeqNum int64
|
||||
LastSeqNum int64
|
||||
LastTimestampMS int64
|
||||
}
|
||||
|
||||
type Iterator interface {
|
||||
// Next will return false if no record is available during the timeout
|
||||
// period, or if an error is encountered. After Next returns false, the
|
||||
// caller should check the return value of the Error function.
|
||||
Next(timeout time.Duration) bool
|
||||
|
||||
// Call Record after Next returns true to get the next record.
|
||||
Record() Record
|
||||
|
||||
// The caller must call Close on the iterator so clean-up can be performed.
|
||||
Close()
|
||||
|
||||
// Call Error to see if there was an error during the previous call to Next
|
||||
// if Next returned false.
|
||||
Error() error
|
||||
}
|
||||
94
lib/wal/gc_test.go
Normal file
94
lib/wal/gc_test.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestDeleteBefore(t *testing.T) {
|
||||
t.Parallel()
|
||||
firstSeqNum := rand.Int63n(9288389)
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
wal, err := Create(tmpDir, firstSeqNum, Config{
|
||||
SegMinCount: 10,
|
||||
SegMaxAgeSec: 1,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer wal.Close()
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := writeRandomWithEOF(wal, 8*time.Second)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
info := wal.Info()
|
||||
if info.FirstSeqNum != firstSeqNum {
|
||||
t.Fatal(info)
|
||||
}
|
||||
|
||||
lastSeqNum := info.LastSeqNum
|
||||
lastTimestampMS := info.LastTimestampMS
|
||||
|
||||
err = wal.DeleteBefore((info.LastTimestampMS/1000)-4, lastSeqNum+100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
info = wal.Info()
|
||||
if info.FirstSeqNum == firstSeqNum || info.LastSeqNum != lastSeqNum || info.LastTimestampMS != lastTimestampMS {
|
||||
t.Fatal(info)
|
||||
}
|
||||
|
||||
header := wal.header
|
||||
if header.FirstSegmentID >= header.LastSegmentID {
|
||||
t.Fatal(header)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteBeforeOnlyOneSegment(t *testing.T) {
|
||||
t.Parallel()
|
||||
firstSeqNum := rand.Int63n(9288389)
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
wal, err := Create(tmpDir, firstSeqNum, Config{
|
||||
SegMinCount: 10,
|
||||
SegMaxAgeSec: 10,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer wal.Close()
|
||||
|
||||
if err := writeRandomWithEOF(wal, time.Second); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
header := wal.header
|
||||
if header.FirstSegmentID != header.LastSegmentID {
|
||||
t.Fatal(header)
|
||||
}
|
||||
|
||||
lastSeqNum := wal.Info().LastSeqNum
|
||||
|
||||
err = wal.DeleteBefore(time.Now().Unix()+1, lastSeqNum+100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
header = wal.header
|
||||
if header.FirstSegmentID != header.LastSegmentID {
|
||||
t.Fatal(header)
|
||||
}
|
||||
}
|
||||
391
lib/wal/generic_test.go
Normal file
391
lib/wal/generic_test.go
Normal file
@@ -0,0 +1,391 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"math/rand"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type waLog interface {
|
||||
Append(int64, io.Reader) (int64, int64, error)
|
||||
appendRecord(Record) (int64, int64, error)
|
||||
Iterator(int64) (Iterator, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
func TestGenericTestHarness_segment(t *testing.T) {
|
||||
t.Parallel()
|
||||
(&GenericTestHarness{
|
||||
New: func(tmpDir string, firstSeqNum int64) (waLog, error) {
|
||||
l, err := createSegment(filepath.Join(tmpDir, "x"), 1, firstSeqNum, 12345)
|
||||
return l, err
|
||||
},
|
||||
}).Run(t)
|
||||
}
|
||||
|
||||
func TestGenericTestHarness_wal(t *testing.T) {
|
||||
t.Parallel()
|
||||
(&GenericTestHarness{
|
||||
New: func(tmpDir string, firstSeqNum int64) (waLog, error) {
|
||||
l, err := Create(tmpDir, firstSeqNum, Config{
|
||||
SegMinCount: 1,
|
||||
SegMaxAgeSec: 1,
|
||||
})
|
||||
return l, err
|
||||
},
|
||||
}).Run(t)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type GenericTestHarness struct {
|
||||
New func(tmpDir string, firstSeqNum int64) (waLog, error)
|
||||
}
|
||||
|
||||
func (h *GenericTestHarness) Run(t *testing.T) {
|
||||
val := reflect.ValueOf(h)
|
||||
typ := val.Type()
|
||||
for i := 0; i < typ.NumMethod(); i++ {
|
||||
method := typ.Method(i)
|
||||
|
||||
if !strings.HasPrefix(method.Name, "Test") {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Run(method.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
firstSeqNum := rand.Int63n(23423)
|
||||
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
wal, err := h.New(tmpDir, firstSeqNum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer wal.Close()
|
||||
|
||||
val.MethodByName(method.Name).Call([]reflect.Value{
|
||||
reflect.ValueOf(t),
|
||||
reflect.ValueOf(firstSeqNum),
|
||||
reflect.ValueOf(wal),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (h *GenericTestHarness) TestBasic(t *testing.T, firstSeqNum int64, wal waLog) {
|
||||
expected := appendRandomRecords(t, wal, 123)
|
||||
|
||||
for i := 0; i < 123; i++ {
|
||||
it, err := wal.Iterator(firstSeqNum + int64(i))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
checkIteratorMatches(t, it, expected[i:])
|
||||
|
||||
it.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GenericTestHarness) TestAppendNotFound(t *testing.T, firstSeqNum int64, wal waLog) {
|
||||
recs := appendRandomRecords(t, wal, 123)
|
||||
lastSeqNum := recs[len(recs)-1].SeqNum
|
||||
|
||||
it, err := wal.Iterator(firstSeqNum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
it.Close()
|
||||
|
||||
it, err = wal.Iterator(lastSeqNum + 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
it.Close()
|
||||
|
||||
if _, err = wal.Iterator(firstSeqNum - 1); !errs.NotFound.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err = wal.Iterator(lastSeqNum + 2); !errs.NotFound.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GenericTestHarness) TestNextAfterClose(t *testing.T, firstSeqNum int64, wal waLog) {
|
||||
appendRandomRecords(t, wal, 123)
|
||||
|
||||
it, err := wal.Iterator(firstSeqNum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer it.Close()
|
||||
|
||||
if !it.Next(0) {
|
||||
t.Fatal("Should be next")
|
||||
}
|
||||
|
||||
if err := wal.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if it.Next(0) {
|
||||
t.Fatal("Shouldn't be next")
|
||||
}
|
||||
|
||||
if !errs.Closed.Is(it.Error()) {
|
||||
t.Fatal(it.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GenericTestHarness) TestNextTimeout(t *testing.T, firstSeqNum int64, wal waLog) {
|
||||
recs := appendRandomRecords(t, wal, 123)
|
||||
|
||||
it, err := wal.Iterator(firstSeqNum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer it.Close()
|
||||
|
||||
for range recs {
|
||||
if !it.Next(0) {
|
||||
t.Fatal("Expected next")
|
||||
}
|
||||
}
|
||||
|
||||
if it.Next(time.Millisecond) {
|
||||
t.Fatal("Unexpected next")
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GenericTestHarness) TestNextNotify(t *testing.T, firstSeqNum int64, wal waLog) {
|
||||
it, err := wal.Iterator(firstSeqNum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer it.Close()
|
||||
|
||||
recsC := make(chan []RawRecord, 1)
|
||||
|
||||
go func() {
|
||||
time.Sleep(time.Second)
|
||||
recsC <- appendRandomRecords(t, wal, 1)
|
||||
}()
|
||||
|
||||
if !it.Next(time.Hour) {
|
||||
t.Fatal("expected next")
|
||||
}
|
||||
|
||||
recs := <-recsC
|
||||
rec := it.Record()
|
||||
if rec.SeqNum != recs[0].SeqNum {
|
||||
t.Fatal(rec)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GenericTestHarness) TestNextArchived(t *testing.T, firstSeqNum int64, wal waLog) {
|
||||
type archiver interface {
|
||||
Archive() error
|
||||
}
|
||||
|
||||
arch, ok := wal.(archiver)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
recs := appendRandomRecords(t, wal, 10)
|
||||
|
||||
it, err := wal.Iterator(firstSeqNum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer it.Close()
|
||||
|
||||
if err := arch.Archive(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for i, expected := range recs {
|
||||
if !it.Next(time.Millisecond) {
|
||||
t.Fatal(i, "no next")
|
||||
}
|
||||
|
||||
rec := it.Record()
|
||||
if rec.SeqNum != expected.SeqNum {
|
||||
t.Fatal(rec, expected)
|
||||
}
|
||||
}
|
||||
|
||||
if it.Next(time.Minute) {
|
||||
t.Fatal("unexpected next")
|
||||
}
|
||||
|
||||
if !errs.EOFArchived.Is(it.Error()) {
|
||||
t.Fatal(it.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GenericTestHarness) TestWriteReadConcurrent(t *testing.T, firstSeqNum int64, wal waLog) {
|
||||
N := 1200
|
||||
|
||||
writeErr := make(chan error, 1)
|
||||
|
||||
dataSize := int64(4)
|
||||
makeData := func(i int) []byte {
|
||||
data := make([]byte, 4)
|
||||
binary.LittleEndian.PutUint32(data, uint32(i))
|
||||
return data
|
||||
}
|
||||
|
||||
go func() {
|
||||
for i := 0; i < N; i++ {
|
||||
|
||||
seqNum, _, err := wal.Append(dataSize, bytes.NewBuffer(makeData(i)))
|
||||
if err != nil {
|
||||
writeErr <- err
|
||||
return
|
||||
}
|
||||
|
||||
if seqNum != int64(i)+firstSeqNum {
|
||||
writeErr <- errors.New("Incorrect seq num")
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
writeErr <- nil
|
||||
}()
|
||||
|
||||
it, err := wal.Iterator(firstSeqNum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer it.Close()
|
||||
|
||||
for i := 0; i < N; i++ {
|
||||
if !it.Next(time.Minute) {
|
||||
t.Fatal("expected next", i, it.Error(), it.Record())
|
||||
}
|
||||
|
||||
expectedData := makeData(i)
|
||||
rec := it.Record()
|
||||
|
||||
data, err := io.ReadAll(rec.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(data, expectedData) {
|
||||
t.Fatal(data, expectedData)
|
||||
}
|
||||
}
|
||||
|
||||
if err := <-writeErr; err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GenericTestHarness) TestAppendAfterClose(t *testing.T, firstSeqNum int64, wal waLog) {
|
||||
if _, _, err := wal.Append(4, bytes.NewBuffer([]byte{1, 2, 3, 4})); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wal.Close()
|
||||
|
||||
_, _, err := wal.Append(4, bytes.NewBuffer([]byte{1, 2, 3, 4}))
|
||||
if !errs.Closed.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GenericTestHarness) TestIterateNegativeOne(t *testing.T, firstSeqNum int64, wal waLog) {
|
||||
recs := appendRandomRecords(t, wal, 10)
|
||||
|
||||
it1, err := wal.Iterator(firstSeqNum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer it1.Close()
|
||||
|
||||
it2, err := wal.Iterator(-1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer it2.Close()
|
||||
|
||||
if !it1.Next(0) {
|
||||
t.Fatal(0)
|
||||
}
|
||||
if !it2.Next(0) {
|
||||
t.Fatal(0)
|
||||
}
|
||||
|
||||
r1 := it1.Record()
|
||||
r2 := it2.Record()
|
||||
|
||||
if r1.SeqNum != r2.SeqNum || r1.SeqNum != firstSeqNum || r1.SeqNum != recs[0].SeqNum {
|
||||
t.Fatal(r1.SeqNum, r2.SeqNum, firstSeqNum, recs[0].SeqNum)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GenericTestHarness) TestIteratorAfterClose(t *testing.T, firstSeqNum int64, wal waLog) {
|
||||
appendRandomRecords(t, wal, 10)
|
||||
wal.Close()
|
||||
|
||||
if _, err := wal.Iterator(-1); !errs.Closed.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GenericTestHarness) TestIteratorNextWithError(t *testing.T, firstSeqNum int64, wal waLog) {
|
||||
appendRandomRecords(t, wal, 10)
|
||||
it, err := wal.Iterator(-1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wal.Close()
|
||||
|
||||
it.Next(0)
|
||||
if !errs.Closed.Is(it.Error()) {
|
||||
t.Fatal(it.Error())
|
||||
}
|
||||
|
||||
it.Next(0)
|
||||
if !errs.Closed.Is(it.Error()) {
|
||||
t.Fatal(it.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GenericTestHarness) TestIteratorConcurrentClose(t *testing.T, firstSeqNum int64, wal waLog) {
|
||||
it, err := wal.Iterator(-1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
writeRandomWithEOF(wal, 3*time.Second)
|
||||
wal.Close()
|
||||
}()
|
||||
|
||||
for it.Next(time.Hour) {
|
||||
// Skip.
|
||||
}
|
||||
|
||||
// Error may be Closed or NotFound.
|
||||
if !errs.Closed.Is(it.Error()) && !errs.NotFound.Is(it.Error()) {
|
||||
t.Fatal(it.Error())
|
||||
}
|
||||
}
|
||||
125
lib/wal/io.go
Normal file
125
lib/wal/io.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
)
|
||||
|
||||
func ioErrOrEOF(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return err
|
||||
}
|
||||
return errs.IO.WithErr(err)
|
||||
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type readAtReader struct {
|
||||
f io.ReaderAt
|
||||
offset int64
|
||||
}
|
||||
|
||||
func readerAtToReader(f io.ReaderAt, offset int64) io.Reader {
|
||||
return &readAtReader{f: f, offset: offset}
|
||||
}
|
||||
|
||||
func (r *readAtReader) Read(b []byte) (int, error) {
|
||||
n, err := r.f.ReadAt(b, r.offset)
|
||||
r.offset += int64(n)
|
||||
return n, ioErrOrEOF(err)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type writeAtWriter struct {
|
||||
w io.WriterAt
|
||||
offset int64
|
||||
}
|
||||
|
||||
func writerAtToWriter(w io.WriterAt, offset int64) io.Writer {
|
||||
return &writeAtWriter{w: w, offset: offset}
|
||||
}
|
||||
|
||||
func (w *writeAtWriter) Write(b []byte) (int, error) {
|
||||
n, err := w.w.WriteAt(b, w.offset)
|
||||
w.offset += int64(n)
|
||||
return n, ioErrOrEOF(err)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type crcWriter struct {
|
||||
w io.Writer
|
||||
crc uint32
|
||||
}
|
||||
|
||||
func newCRCWriter(w io.Writer) *crcWriter {
|
||||
return &crcWriter{w: w}
|
||||
}
|
||||
|
||||
func (w *crcWriter) Write(b []byte) (int, error) {
|
||||
n, err := w.w.Write(b)
|
||||
w.crc = crc32.Update(w.crc, crc32.IEEETable, b[:n])
|
||||
return n, ioErrOrEOF(err)
|
||||
}
|
||||
|
||||
func (w *crcWriter) CRC() uint32 {
|
||||
return w.crc
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type dataReader struct {
|
||||
r io.Reader
|
||||
remaining int64
|
||||
crc uint32
|
||||
}
|
||||
|
||||
func newDataReader(r io.Reader, dataSize int64) *dataReader {
|
||||
return &dataReader{r: r, remaining: dataSize}
|
||||
}
|
||||
|
||||
func (r *dataReader) Read(b []byte) (int, error) {
|
||||
if r.remaining == 0 {
|
||||
return 0, io.EOF
|
||||
}
|
||||
|
||||
if int64(len(b)) > r.remaining {
|
||||
b = b[:r.remaining]
|
||||
}
|
||||
|
||||
n, err := r.r.Read(b)
|
||||
r.crc = crc32.Update(r.crc, crc32.IEEETable, b[:n])
|
||||
r.remaining -= int64(n)
|
||||
|
||||
if r.remaining == 0 {
|
||||
if err := r.checkCRC(); err != nil {
|
||||
return n, err
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
return n, errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *dataReader) checkCRC() error {
|
||||
buf := make([]byte, 4)
|
||||
if _, err := r.r.Read(buf); err != nil {
|
||||
return errs.Corrupt.WithErr(err)
|
||||
}
|
||||
crc := binary.LittleEndian.Uint32(buf)
|
||||
if crc != r.crc {
|
||||
return errs.Corrupt.WithMsg("crc mismatch")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
79
lib/wal/notify.go
Normal file
79
lib/wal/notify.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package wal
|
||||
|
||||
import "sync"
|
||||
|
||||
type segmentState struct {
|
||||
Closed bool
|
||||
Archived bool
|
||||
FirstSeqNum int64
|
||||
LastSeqNum int64
|
||||
}
|
||||
|
||||
func newSegmentState(closed bool, header segmentHeader) segmentState {
|
||||
return segmentState{
|
||||
Closed: closed,
|
||||
Archived: header.ArchivedAt != 0,
|
||||
FirstSeqNum: header.FirstSeqNum,
|
||||
LastSeqNum: header.LastSeqNum,
|
||||
}
|
||||
}
|
||||
|
||||
type notifyMux struct {
|
||||
lock sync.Mutex
|
||||
nextID int64
|
||||
recvrs map[int64]chan segmentState
|
||||
}
|
||||
|
||||
type stateRecvr struct {
|
||||
// Each recvr will always get the most recent sequence number on C.
|
||||
// When the segment is closed, a -1 is sent.
|
||||
C chan segmentState
|
||||
Close func()
|
||||
}
|
||||
|
||||
func newNotifyMux() *notifyMux {
|
||||
return ¬ifyMux{
|
||||
recvrs: map[int64]chan segmentState{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m *notifyMux) NewRecvr(header segmentHeader) stateRecvr {
|
||||
state := newSegmentState(false, header)
|
||||
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
m.nextID++
|
||||
|
||||
recvrID := m.nextID
|
||||
|
||||
recvr := stateRecvr{
|
||||
C: make(chan segmentState, 1),
|
||||
Close: func() {
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
delete(m.recvrs, recvrID)
|
||||
},
|
||||
}
|
||||
|
||||
recvr.C <- state
|
||||
m.recvrs[recvrID] = recvr.C
|
||||
|
||||
return recvr
|
||||
}
|
||||
|
||||
func (m *notifyMux) Notify(closed bool, header segmentHeader) {
|
||||
|
||||
state := newSegmentState(closed, header)
|
||||
|
||||
m.lock.Lock()
|
||||
defer m.lock.Unlock()
|
||||
|
||||
for _, c := range m.recvrs {
|
||||
select {
|
||||
case c <- state:
|
||||
case <-c:
|
||||
c <- state
|
||||
}
|
||||
}
|
||||
}
|
||||
90
lib/wal/record.go
Normal file
90
lib/wal/record.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
)
|
||||
|
||||
const recordHeaderSize = 28
|
||||
|
||||
type Record struct {
|
||||
SeqNum int64
|
||||
TimestampMS int64
|
||||
DataSize int64
|
||||
Reader io.Reader
|
||||
}
|
||||
|
||||
func (rec Record) writeHeaderTo(w io.Writer) (int, error) {
|
||||
buf := make([]byte, recordHeaderSize)
|
||||
binary.LittleEndian.PutUint64(buf[0:], uint64(rec.SeqNum))
|
||||
binary.LittleEndian.PutUint64(buf[8:], uint64(rec.TimestampMS))
|
||||
binary.LittleEndian.PutUint64(buf[16:], uint64(rec.DataSize))
|
||||
crc := crc32.ChecksumIEEE(buf[:recordHeaderSize-4])
|
||||
binary.LittleEndian.PutUint32(buf[24:], crc)
|
||||
|
||||
n, err := w.Write(buf)
|
||||
if err != nil {
|
||||
err = errs.IO.WithErr(err)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (rec *Record) readHeaderFrom(r io.Reader) error {
|
||||
buf := make([]byte, recordHeaderSize)
|
||||
if _, err := io.ReadFull(r, buf); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
crc := crc32.ChecksumIEEE(buf[:recordHeaderSize-4])
|
||||
stored := binary.LittleEndian.Uint32(buf[recordHeaderSize-4:])
|
||||
if crc != stored {
|
||||
return errs.Corrupt.WithMsg("checksum mismatch")
|
||||
}
|
||||
|
||||
rec.SeqNum = int64(binary.LittleEndian.Uint64(buf[0:]))
|
||||
rec.TimestampMS = int64(binary.LittleEndian.Uint64(buf[8:]))
|
||||
rec.DataSize = int64(binary.LittleEndian.Uint64(buf[16:]))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rec Record) serializedSize() int64 {
|
||||
return recordHeaderSize + rec.DataSize + 4 // 4 for data CRC32.
|
||||
}
|
||||
|
||||
func (rec Record) writeTo(w io.Writer) (int64, error) {
|
||||
nn, err := rec.writeHeaderTo(w)
|
||||
if err != nil {
|
||||
return int64(nn), err
|
||||
}
|
||||
|
||||
n := int64(nn)
|
||||
|
||||
// Write the data.
|
||||
crcW := newCRCWriter(w)
|
||||
n2, err := io.CopyN(crcW, rec.Reader, rec.DataSize)
|
||||
n += n2
|
||||
if err != nil {
|
||||
return n, errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
// Write the data crc value.
|
||||
err = binary.Write(w, binary.LittleEndian, crcW.CRC())
|
||||
if err != nil {
|
||||
return n, errs.IO.WithErr(err)
|
||||
}
|
||||
n += 4
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (rec *Record) readFrom(r io.Reader) error {
|
||||
if err := rec.readHeaderFrom(r); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rec.Reader = newDataReader(r, rec.DataSize)
|
||||
return nil
|
||||
}
|
||||
171
lib/wal/record_test.go
Normal file
171
lib/wal/record_test.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"git.crumpington.com/public/jldb/lib/testutil"
|
||||
"math/rand"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func NewRecordForTesting() Record {
|
||||
data := randData()
|
||||
return Record{
|
||||
SeqNum: rand.Int63(),
|
||||
TimestampMS: rand.Int63(),
|
||||
DataSize: int64(len(data)),
|
||||
Reader: bytes.NewBuffer(data),
|
||||
}
|
||||
}
|
||||
|
||||
func AssertRecordHeadersEqual(t *testing.T, r1, r2 Record) {
|
||||
t.Helper()
|
||||
eq := r1.SeqNum == r2.SeqNum &&
|
||||
r1.TimestampMS == r2.TimestampMS &&
|
||||
r1.DataSize == r2.DataSize
|
||||
if !eq {
|
||||
t.Fatal(r1, r2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordWriteHeaderToReadHeaderFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
rec1 := NewRecordForTesting()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
n, err := rec1.writeHeaderTo(b)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if n != recordHeaderSize {
|
||||
t.Fatal(n)
|
||||
}
|
||||
|
||||
rec2 := Record{}
|
||||
if err := rec2.readHeaderFrom(b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
AssertRecordHeadersEqual(t, rec1, rec2)
|
||||
}
|
||||
|
||||
func TestRecordWriteHeaderToEOF(t *testing.T) {
|
||||
t.Parallel()
|
||||
rec := NewRecordForTesting()
|
||||
|
||||
for limit := 1; limit < recordHeaderSize; limit++ {
|
||||
buf := &bytes.Buffer{}
|
||||
w := testutil.NewLimitWriter(buf, limit)
|
||||
|
||||
n, err := rec.writeHeaderTo(w)
|
||||
if !errs.IO.Is(err) {
|
||||
t.Fatal(limit, n, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordReadHeaderFromError(t *testing.T) {
|
||||
t.Parallel()
|
||||
rec := NewRecordForTesting()
|
||||
|
||||
for limit := 1; limit < recordHeaderSize; limit++ {
|
||||
b := &bytes.Buffer{}
|
||||
if _, err := rec.writeHeaderTo(b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
r := io.LimitReader(b, int64(limit))
|
||||
if err := rec.readFrom(r); !errs.IO.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordReadHeaderFromCorrupt(t *testing.T) {
|
||||
t.Parallel()
|
||||
rec := NewRecordForTesting()
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
|
||||
for i := 0; i < recordHeaderSize; i++ {
|
||||
if _, err := rec.writeHeaderTo(b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
b.Bytes()[i]++
|
||||
if err := rec.readHeaderFrom(b); !errs.Corrupt.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordWriteToReadFrom(t *testing.T) {
|
||||
t.Parallel()
|
||||
r1 := NewRecordForTesting()
|
||||
data := randData()
|
||||
r1.Reader = bytes.NewBuffer(bytes.Clone(data))
|
||||
r1.DataSize = int64(len(data))
|
||||
|
||||
r2 := Record{}
|
||||
|
||||
b := &bytes.Buffer{}
|
||||
if _, err := r1.writeTo(b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := r2.readFrom(b); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
AssertRecordHeadersEqual(t, r1, r2)
|
||||
|
||||
data2, err := io.ReadAll(r2.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !bytes.Equal(data, data2) {
|
||||
t.Fatal(data, data2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordReadFromCorrupt(t *testing.T) {
|
||||
t.Parallel()
|
||||
data := randData()
|
||||
r1 := NewRecordForTesting()
|
||||
|
||||
for i := 0; i < int(r1.serializedSize()); i++ {
|
||||
r1.Reader = bytes.NewBuffer(data)
|
||||
r1.DataSize = int64(len(data))
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
r1.writeTo(buf)
|
||||
buf.Bytes()[i]++
|
||||
|
||||
r2 := Record{}
|
||||
if err := r2.readFrom(buf); err != nil {
|
||||
if !errs.Corrupt.Is(err) {
|
||||
t.Fatal(i, err)
|
||||
}
|
||||
continue // OK.
|
||||
}
|
||||
|
||||
if _, err := io.ReadAll(r2.Reader); !errs.Corrupt.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRecordWriteToError(t *testing.T) {
|
||||
t.Parallel()
|
||||
data := randData()
|
||||
r1 := NewRecordForTesting()
|
||||
r1.Reader = bytes.NewBuffer(data)
|
||||
r1.DataSize = int64(len(data))
|
||||
|
||||
for i := 0; i < int(r1.serializedSize()); i++ {
|
||||
w := testutil.NewLimitWriter(&bytes.Buffer{}, i)
|
||||
r1.Reader = bytes.NewBuffer(data)
|
||||
if _, err := r1.writeTo(w); !errs.IO.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
44
lib/wal/segment-header.go
Normal file
44
lib/wal/segment-header.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package wal
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
type segmentHeader struct {
|
||||
CreatedAt int64
|
||||
ArchivedAt int64
|
||||
FirstSeqNum int64
|
||||
LastSeqNum int64 // FirstSeqNum - 1 if empty.
|
||||
LastTimestampMS int64 // 0 if empty.
|
||||
InsertAt int64
|
||||
}
|
||||
|
||||
func (h segmentHeader) WriteTo(b []byte) {
|
||||
vals := []int64{
|
||||
h.CreatedAt,
|
||||
h.ArchivedAt,
|
||||
h.FirstSeqNum,
|
||||
h.LastSeqNum,
|
||||
h.LastTimestampMS,
|
||||
h.InsertAt,
|
||||
}
|
||||
|
||||
for _, val := range vals {
|
||||
binary.LittleEndian.PutUint64(b[0:8], uint64(val))
|
||||
b = b[8:]
|
||||
}
|
||||
}
|
||||
|
||||
func (h *segmentHeader) ReadFrom(b []byte) {
|
||||
ptrs := []*int64{
|
||||
&h.CreatedAt,
|
||||
&h.ArchivedAt,
|
||||
&h.FirstSeqNum,
|
||||
&h.LastSeqNum,
|
||||
&h.LastTimestampMS,
|
||||
&h.InsertAt,
|
||||
}
|
||||
|
||||
for _, ptr := range ptrs {
|
||||
*ptr = int64(binary.LittleEndian.Uint64(b[0:8]))
|
||||
b = b[8:]
|
||||
}
|
||||
}
|
||||
165
lib/wal/segment-iterator.go
Normal file
165
lib/wal/segment-iterator.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"git.crumpington.com/public/jldb/lib/atomicheader"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
type segmentIterator struct {
|
||||
f *os.File
|
||||
|
||||
recvr stateRecvr
|
||||
state segmentState
|
||||
|
||||
offset int64
|
||||
err error
|
||||
rec Record
|
||||
|
||||
ticker *time.Ticker // Ticker if timeout has been set.
|
||||
tickerC <-chan time.Time // Ticker channel if timeout has been set.
|
||||
}
|
||||
|
||||
func newSegmentIterator(
|
||||
f *os.File,
|
||||
fromSeqNum int64,
|
||||
recvr stateRecvr,
|
||||
) (
|
||||
Iterator,
|
||||
error,
|
||||
) {
|
||||
it := &segmentIterator{
|
||||
f: f,
|
||||
recvr: recvr,
|
||||
state: <-recvr.C,
|
||||
}
|
||||
|
||||
if err := it.seekToSeqNum(fromSeqNum); err != nil {
|
||||
it.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
it.rec.SeqNum = fromSeqNum - 1
|
||||
|
||||
it.ticker = time.NewTicker(time.Second)
|
||||
it.tickerC = it.ticker.C
|
||||
|
||||
return it, nil
|
||||
}
|
||||
|
||||
func (it *segmentIterator) seekToSeqNum(fromSeqNum int64) error {
|
||||
|
||||
state := it.state
|
||||
|
||||
// Is the requested sequence number out-of-range?
|
||||
if fromSeqNum < state.FirstSeqNum || fromSeqNum > state.LastSeqNum+1 {
|
||||
return errs.NotFound.WithMsg("sequence number not in segment")
|
||||
}
|
||||
|
||||
// Seek to start.
|
||||
it.offset = atomicheader.ReservedBytes
|
||||
|
||||
// Seek to first seq num - we're already there.
|
||||
if fromSeqNum == it.state.FirstSeqNum {
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
if err := it.readRecord(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
it.offset += it.rec.serializedSize()
|
||||
|
||||
if it.rec.SeqNum == fromSeqNum-1 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (it *segmentIterator) Close() {
|
||||
it.f.Close()
|
||||
it.recvr.Close()
|
||||
}
|
||||
|
||||
// Next returns true if there's a record available to read via it.Record().
|
||||
//
|
||||
// If Next returns false, the caller should check the error value with
|
||||
// it.Error().
|
||||
func (it *segmentIterator) Next(timeout time.Duration) bool {
|
||||
if it.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Get new state if available.
|
||||
select {
|
||||
case it.state = <-it.recvr.C:
|
||||
default:
|
||||
}
|
||||
|
||||
if it.state.Closed {
|
||||
it.err = errs.Closed
|
||||
return false
|
||||
}
|
||||
|
||||
if it.rec.SeqNum < it.state.LastSeqNum {
|
||||
if it.err = it.readRecord(); it.err != nil {
|
||||
return false
|
||||
}
|
||||
it.offset += it.rec.serializedSize()
|
||||
return true
|
||||
}
|
||||
|
||||
if it.state.Archived {
|
||||
it.err = errs.EOFArchived
|
||||
return false
|
||||
}
|
||||
|
||||
if timeout <= 0 {
|
||||
return false // Nothing to return.
|
||||
}
|
||||
|
||||
// Wait for new record, or timeout.
|
||||
it.ticker.Reset(timeout)
|
||||
|
||||
// Get new state if available.
|
||||
select {
|
||||
case it.state = <-it.recvr.C:
|
||||
// OK
|
||||
case <-it.tickerC:
|
||||
return false // Timeout, no error.
|
||||
}
|
||||
|
||||
if it.state.Closed {
|
||||
it.err = errs.Closed
|
||||
return false
|
||||
}
|
||||
|
||||
if it.rec.SeqNum < it.state.LastSeqNum {
|
||||
if it.err = it.readRecord(); it.err != nil {
|
||||
return false
|
||||
}
|
||||
it.offset += it.rec.serializedSize()
|
||||
return true
|
||||
}
|
||||
|
||||
if it.state.Archived {
|
||||
it.err = errs.EOFArchived
|
||||
return false
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (it *segmentIterator) Record() Record {
|
||||
return it.rec
|
||||
}
|
||||
|
||||
func (it *segmentIterator) Error() error {
|
||||
return it.err
|
||||
}
|
||||
|
||||
func (it *segmentIterator) readRecord() error {
|
||||
return it.rec.readFrom(readerAtToReader(it.f, it.offset))
|
||||
}
|
||||
250
lib/wal/segment.go
Normal file
250
lib/wal/segment.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/atomicheader"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type segment struct {
|
||||
ID int64
|
||||
|
||||
lock sync.Mutex
|
||||
|
||||
closed bool
|
||||
header segmentHeader
|
||||
headWriter *atomicheader.Handler
|
||||
f *os.File
|
||||
notifyMux *notifyMux
|
||||
|
||||
// For non-archived segments.
|
||||
w *bufio.Writer
|
||||
}
|
||||
|
||||
func createSegment(path string, id, firstSeqNum, timestampMS int64) (*segment, error) {
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return nil, errs.IO.WithErr(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := atomicheader.Init(f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handler, err := atomicheader.Open(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header := segmentHeader{
|
||||
CreatedAt: time.Now().Unix(),
|
||||
FirstSeqNum: firstSeqNum,
|
||||
LastSeqNum: firstSeqNum - 1,
|
||||
LastTimestampMS: timestampMS,
|
||||
InsertAt: atomicheader.ReservedBytes,
|
||||
}
|
||||
|
||||
err = handler.Write(func(page []byte) error {
|
||||
header.WriteTo(page)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return openSegment(path, id)
|
||||
}
|
||||
|
||||
func openSegment(path string, id int64) (*segment, error) {
|
||||
f, err := os.OpenFile(path, os.O_RDWR, 0600)
|
||||
if err != nil {
|
||||
return nil, errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
handler, err := atomicheader.Open(f)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var header segmentHeader
|
||||
err = handler.Read(func(page []byte) error {
|
||||
header.ReadFrom(page)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := f.Seek(header.InsertAt, io.SeekStart); err != nil {
|
||||
f.Close()
|
||||
return nil, errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
seg := &segment{
|
||||
ID: id,
|
||||
header: header,
|
||||
headWriter: handler,
|
||||
f: f,
|
||||
notifyMux: newNotifyMux(),
|
||||
}
|
||||
|
||||
if header.ArchivedAt == 0 {
|
||||
seg.w = bufio.NewWriterSize(f, 1024*1024)
|
||||
}
|
||||
|
||||
return seg, nil
|
||||
}
|
||||
|
||||
// Append appends the data from r to the log atomically. If an error is
|
||||
// returned, the caller should check for errs.Fatal. If a fatal error occurs,
|
||||
// the segment should no longer be used.
|
||||
func (seg *segment) Append(dataSize int64, r io.Reader) (int64, int64, error) {
|
||||
return seg.appendRecord(Record{
|
||||
SeqNum: -1,
|
||||
TimestampMS: time.Now().UnixMilli(),
|
||||
DataSize: dataSize,
|
||||
Reader: r,
|
||||
})
|
||||
}
|
||||
|
||||
func (seg *segment) Header() segmentHeader {
|
||||
seg.lock.Lock()
|
||||
defer seg.lock.Unlock()
|
||||
return seg.header
|
||||
}
|
||||
|
||||
// appendRecord appends a record in an atomic fashion. Do not use the segment
|
||||
// after a fatal error.
|
||||
func (seg *segment) appendRecord(rec Record) (int64, int64, error) {
|
||||
seg.lock.Lock()
|
||||
defer seg.lock.Unlock()
|
||||
|
||||
header := seg.header // Copy.
|
||||
|
||||
if seg.closed {
|
||||
return 0, 0, errs.Closed
|
||||
}
|
||||
|
||||
if header.ArchivedAt != 0 {
|
||||
return 0, 0, errs.Archived
|
||||
}
|
||||
|
||||
if rec.SeqNum == -1 {
|
||||
rec.SeqNum = header.LastSeqNum + 1
|
||||
} else if rec.SeqNum != header.LastSeqNum+1 {
|
||||
return 0, 0, errs.Unexpected.WithMsg(
|
||||
"expected sequence number %d but got %d",
|
||||
header.LastSeqNum+1,
|
||||
rec.SeqNum)
|
||||
}
|
||||
|
||||
seg.w.Reset(writerAtToWriter(seg.f, header.InsertAt))
|
||||
|
||||
n, err := rec.writeTo(seg.w)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
if err := seg.w.Flush(); err != nil {
|
||||
return 0, 0, ioErrOrEOF(err)
|
||||
}
|
||||
|
||||
// Write new header to sync.
|
||||
header.LastSeqNum = rec.SeqNum
|
||||
header.LastTimestampMS = rec.TimestampMS
|
||||
header.InsertAt += n
|
||||
|
||||
err = seg.headWriter.Write(func(page []byte) error {
|
||||
header.WriteTo(page)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
|
||||
seg.header = header
|
||||
seg.notifyMux.Notify(false, header)
|
||||
|
||||
return rec.SeqNum, rec.TimestampMS, nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (seg *segment) Archive() error {
|
||||
seg.lock.Lock()
|
||||
defer seg.lock.Unlock()
|
||||
|
||||
header := seg.header // Copy
|
||||
if header.ArchivedAt != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
header.ArchivedAt = time.Now().Unix()
|
||||
err := seg.headWriter.Write(func(page []byte) error {
|
||||
header.WriteTo(page)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
seg.w = nil // We won't be writing any more.
|
||||
|
||||
seg.header = header
|
||||
seg.notifyMux.Notify(false, header)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (seg *segment) Iterator(fromSeqNum int64) (Iterator, error) {
|
||||
seg.lock.Lock()
|
||||
defer seg.lock.Unlock()
|
||||
|
||||
if seg.closed {
|
||||
return nil, errs.Closed
|
||||
}
|
||||
|
||||
f, err := os.Open(seg.f.Name())
|
||||
if err != nil {
|
||||
return nil, errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
header := seg.header
|
||||
if fromSeqNum == -1 {
|
||||
fromSeqNum = header.FirstSeqNum
|
||||
}
|
||||
|
||||
return newSegmentIterator(
|
||||
f,
|
||||
fromSeqNum,
|
||||
seg.notifyMux.NewRecvr(header))
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (seg *segment) Close() error {
|
||||
seg.lock.Lock()
|
||||
defer seg.lock.Unlock()
|
||||
|
||||
if seg.closed {
|
||||
return nil
|
||||
}
|
||||
|
||||
seg.closed = true
|
||||
|
||||
header := seg.header
|
||||
seg.notifyMux.Notify(true, header)
|
||||
seg.f.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
145
lib/wal/segment_test.go
Normal file
145
lib/wal/segment_test.go
Normal file
@@ -0,0 +1,145 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/atomicheader"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func newSegmentForTesting(t *testing.T) *segment {
|
||||
tmpDir := t.TempDir()
|
||||
seg, err := createSegment(filepath.Join(tmpDir, "x"), 1, 100, 200)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return seg
|
||||
}
|
||||
|
||||
func TestNewSegmentDirNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "notFound", "1245")
|
||||
|
||||
if _, err := createSegment(p, 1, 1234, 5678); !errs.IO.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenSegmentNotFound(t *testing.T) {
|
||||
t.Parallel()
|
||||
tmpDir := t.TempDir()
|
||||
p := filepath.Join(tmpDir, "notFound")
|
||||
|
||||
if _, err := openSegment(p, 1); !errs.IO.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenSegmentTruncatedFile(t *testing.T) {
|
||||
t.Parallel()
|
||||
seg := newSegmentForTesting(t)
|
||||
|
||||
path := seg.f.Name()
|
||||
if err := seg.f.Truncate(4); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
seg.Close()
|
||||
|
||||
if _, err := openSegment(path, 1); !errs.IO.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenSegmentCorruptHeader(t *testing.T) {
|
||||
t.Parallel()
|
||||
seg := newSegmentForTesting(t)
|
||||
|
||||
path := seg.f.Name()
|
||||
buf := make([]byte, atomicheader.ReservedBytes)
|
||||
crand.Read(buf)
|
||||
|
||||
if _, err := seg.f.Seek(0, io.SeekStart); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := seg.f.Write(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
seg.Close()
|
||||
|
||||
if _, err := openSegment(path, 1); !errs.Corrupt.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenSegmentCorruptHeader2(t *testing.T) {
|
||||
t.Parallel()
|
||||
seg := newSegmentForTesting(t)
|
||||
|
||||
path := seg.f.Name()
|
||||
buf := make([]byte, 1024) // 2 pages.
|
||||
crand.Read(buf)
|
||||
|
||||
if _, err := seg.f.Seek(1024, io.SeekStart); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if _, err := seg.f.Write(buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
seg.Close()
|
||||
|
||||
if _, err := openSegment(path, 1); !errs.Corrupt.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSegmentArchiveTwice(t *testing.T) {
|
||||
t.Parallel()
|
||||
seg := newSegmentForTesting(t)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := seg.Archive(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSegmentAppendArchived(t *testing.T) {
|
||||
t.Parallel()
|
||||
seg := newSegmentForTesting(t)
|
||||
|
||||
appendRandomRecords(t, seg, 8)
|
||||
|
||||
if err := seg.Archive(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, _, err := seg.Append(4, bytes.NewBuffer([]byte{1, 2, 3, 4}))
|
||||
if !errs.Archived.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSegmentAppendRecordInvalidSeqNum(t *testing.T) {
|
||||
t.Parallel()
|
||||
seg := newSegmentForTesting(t)
|
||||
|
||||
appendRandomRecords(t, seg, 8) // 109 is next.
|
||||
|
||||
_, _, err := seg.appendRecord(Record{
|
||||
SeqNum: 110,
|
||||
TimestampMS: time.Now().UnixMilli(),
|
||||
DataSize: 100,
|
||||
})
|
||||
if !errs.Unexpected.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
232
lib/wal/test-util_test.go
Normal file
232
lib/wal/test-util_test.go
Normal file
@@ -0,0 +1,232 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
crand "crypto/rand"
|
||||
"encoding/base32"
|
||||
"encoding/binary"
|
||||
"hash/crc32"
|
||||
"io"
|
||||
"math/rand"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func randString() string {
|
||||
size := 8 + rand.Intn(92)
|
||||
buf := make([]byte, size)
|
||||
if _, err := crand.Read(buf); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return base32.StdEncoding.EncodeToString(buf)
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type RawRecord struct {
|
||||
Record
|
||||
Data []byte
|
||||
DataCRC uint32
|
||||
}
|
||||
|
||||
func (rr *RawRecord) ReadFrom(t *testing.T, f *os.File, offset int64) {
|
||||
t.Helper()
|
||||
|
||||
buf := make([]byte, recordHeaderSize)
|
||||
if _, err := f.ReadAt(buf, offset); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := rr.Record.readHeaderFrom(readerAtToReader(f, offset)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rr.Data = make([]byte, rr.DataSize+4) // For data and CRC32.
|
||||
if _, err := f.ReadAt(rr.Data, offset+recordHeaderSize); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
storedCRC := binary.LittleEndian.Uint32(rr.Data[rr.DataSize:])
|
||||
computedCRC := crc32.ChecksumIEEE(rr.Data[:rr.DataSize])
|
||||
|
||||
if storedCRC != computedCRC {
|
||||
t.Fatal(storedCRC, computedCRC)
|
||||
}
|
||||
|
||||
rr.Data = rr.Data[:rr.DataSize]
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func appendRandomRecords(t *testing.T, w waLog, count int64) []RawRecord {
|
||||
t.Helper()
|
||||
|
||||
recs := make([]RawRecord, count)
|
||||
|
||||
for i := range recs {
|
||||
rec := RawRecord{
|
||||
Data: []byte(randString()),
|
||||
}
|
||||
rec.DataSize = int64(len(rec.Data))
|
||||
|
||||
seqNum, _, err := w.Append(int64(len(rec.Data)), bytes.NewBuffer(rec.Data))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
rec.SeqNum = seqNum
|
||||
|
||||
recs[i] = rec
|
||||
}
|
||||
|
||||
// Check that sequence numbers are sequential.
|
||||
seqNum := recs[0].SeqNum
|
||||
for _, rec := range recs {
|
||||
if rec.SeqNum != seqNum {
|
||||
t.Fatal(seqNum, rec)
|
||||
}
|
||||
seqNum++
|
||||
}
|
||||
|
||||
return recs
|
||||
}
|
||||
|
||||
func checkIteratorMatches(t *testing.T, it Iterator, recs []RawRecord) {
|
||||
for i, expected := range recs {
|
||||
if !it.Next(time.Millisecond) {
|
||||
t.Fatal(i, "no next")
|
||||
}
|
||||
|
||||
rec := it.Record()
|
||||
|
||||
if rec.SeqNum != expected.SeqNum {
|
||||
t.Fatal(i, rec.SeqNum, expected.SeqNum)
|
||||
}
|
||||
|
||||
if rec.DataSize != expected.DataSize {
|
||||
t.Fatal(i, rec.DataSize, expected.DataSize)
|
||||
}
|
||||
|
||||
if rec.TimestampMS == 0 {
|
||||
t.Fatal(rec.TimestampMS)
|
||||
}
|
||||
|
||||
data := make([]byte, rec.DataSize)
|
||||
if _, err := io.ReadFull(rec.Reader, data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(data, expected.Data) {
|
||||
t.Fatalf("%d %s != %s", i, data, expected.Data)
|
||||
}
|
||||
}
|
||||
|
||||
if it.Error() != nil {
|
||||
t.Fatal(it.Error())
|
||||
}
|
||||
|
||||
// Check that iterator is empty.
|
||||
if it.Next(0) {
|
||||
t.Fatal("extra", it.Record())
|
||||
}
|
||||
}
|
||||
|
||||
func randData() []byte {
|
||||
data := make([]byte, 1+rand.Intn(128))
|
||||
crand.Read(data)
|
||||
return data
|
||||
}
|
||||
|
||||
func writeRandomWithEOF(w waLog, dt time.Duration) error {
|
||||
tStart := time.Now()
|
||||
for time.Since(tStart) < dt {
|
||||
data := randData()
|
||||
_, _, err := w.Append(int64(len(data)), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
_, _, err := w.Append(3, bytes.NewBuffer([]byte("EOF")))
|
||||
return err
|
||||
}
|
||||
|
||||
func waitForEOF(t *testing.T, w *WAL) {
|
||||
t.Helper()
|
||||
|
||||
h := w.seg.Header()
|
||||
it, err := w.Iterator(h.FirstSeqNum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer it.Close()
|
||||
|
||||
for it.Next(time.Hour) {
|
||||
rec := it.Record()
|
||||
buf := make([]byte, rec.DataSize)
|
||||
if _, err := io.ReadFull(rec.Reader, buf); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if bytes.Equal(buf, []byte("EOF")) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
t.Fatal("waitForEOF", it.Error())
|
||||
}
|
||||
|
||||
func checkWALsEqual(t *testing.T, w1, w2 *WAL) {
|
||||
t.Helper()
|
||||
|
||||
info1 := w1.Info()
|
||||
info2 := w2.Info()
|
||||
|
||||
if !reflect.DeepEqual(info1, info2) {
|
||||
t.Fatal(info1, info2)
|
||||
}
|
||||
|
||||
it1, err := w1.Iterator(info1.FirstSeqNum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer it1.Close()
|
||||
|
||||
it2, err := w2.Iterator(info2.FirstSeqNum)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer it2.Close()
|
||||
|
||||
for {
|
||||
ok1 := it1.Next(time.Second)
|
||||
ok2 := it2.Next(time.Second)
|
||||
if ok1 != ok2 {
|
||||
t.Fatal(ok1, ok2)
|
||||
}
|
||||
|
||||
if !ok1 {
|
||||
return
|
||||
}
|
||||
|
||||
rec1 := it1.Record()
|
||||
rec2 := it2.Record()
|
||||
|
||||
data1, err := io.ReadAll(rec1.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data2, err := io.ReadAll(rec2.Reader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(data1, data2) {
|
||||
t.Fatal(data1, data2)
|
||||
}
|
||||
}
|
||||
}
|
||||
25
lib/wal/wal-header.go
Normal file
25
lib/wal/wal-header.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package wal
|
||||
|
||||
import "encoding/binary"
|
||||
|
||||
type walHeader struct {
|
||||
FirstSegmentID int64
|
||||
LastSegmentID int64
|
||||
}
|
||||
|
||||
func (h walHeader) WriteTo(b []byte) {
|
||||
vals := []int64{h.FirstSegmentID, h.LastSegmentID}
|
||||
|
||||
for _, val := range vals {
|
||||
binary.LittleEndian.PutUint64(b[0:8], uint64(val))
|
||||
b = b[8:]
|
||||
}
|
||||
}
|
||||
|
||||
func (h *walHeader) ReadFrom(b []byte) {
|
||||
ptrs := []*int64{&h.FirstSegmentID, &h.LastSegmentID}
|
||||
for _, ptr := range ptrs {
|
||||
*ptr = int64(binary.LittleEndian.Uint64(b[0:8]))
|
||||
b = b[8:]
|
||||
}
|
||||
}
|
||||
88
lib/wal/wal-iterator.go
Normal file
88
lib/wal/wal-iterator.go
Normal file
@@ -0,0 +1,88 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"time"
|
||||
)
|
||||
|
||||
type walIterator struct {
|
||||
// getSeg should return a segment given its ID, or return nil.
|
||||
getSeg func(id int64) (*segment, error)
|
||||
seg *segment // Our current segment.
|
||||
it Iterator // Our current segment iterator.
|
||||
seqNum int64
|
||||
err error
|
||||
}
|
||||
|
||||
func newWALIterator(
|
||||
getSeg func(id int64) (*segment, error),
|
||||
seg *segment,
|
||||
fromSeqNum int64,
|
||||
) (
|
||||
*walIterator,
|
||||
error,
|
||||
) {
|
||||
segIter, err := seg.Iterator(fromSeqNum)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &walIterator{
|
||||
getSeg: getSeg,
|
||||
seg: seg,
|
||||
it: segIter,
|
||||
seqNum: fromSeqNum,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (it *walIterator) Next(timeout time.Duration) bool {
|
||||
if it.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if it.it.Next(timeout) {
|
||||
it.seqNum++
|
||||
return true
|
||||
}
|
||||
|
||||
it.err = it.it.Error()
|
||||
if !errs.EOFArchived.Is(it.err) {
|
||||
return false
|
||||
}
|
||||
|
||||
it.it.Close()
|
||||
|
||||
id := it.seg.ID + 1
|
||||
it.seg, it.err = it.getSeg(id)
|
||||
|
||||
if it.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if it.seg == nil {
|
||||
it.err = errs.NotFound // Could be not-found, or closed.
|
||||
return false
|
||||
}
|
||||
|
||||
it.it, it.err = it.seg.Iterator(it.seqNum)
|
||||
if it.err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return it.Next(timeout)
|
||||
}
|
||||
|
||||
func (it *walIterator) Record() Record {
|
||||
return it.it.Record()
|
||||
}
|
||||
|
||||
func (it *walIterator) Error() error {
|
||||
return it.err
|
||||
}
|
||||
|
||||
func (it *walIterator) Close() {
|
||||
if it.it != nil {
|
||||
it.it.Close()
|
||||
}
|
||||
it.it = nil
|
||||
}
|
||||
60
lib/wal/wal-recv.go
Normal file
60
lib/wal/wal-recv.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func (wal *WAL) Recv(conn net.Conn, timeout time.Duration) error {
|
||||
defer conn.Close()
|
||||
|
||||
var (
|
||||
rec Record
|
||||
msgType = make([]byte, 1)
|
||||
)
|
||||
|
||||
// Send sequence number.
|
||||
seqNum := wal.Info().LastSeqNum + 1
|
||||
conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
if err := binary.Write(conn, binary.LittleEndian, seqNum); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
conn.SetWriteDeadline(time.Time{})
|
||||
|
||||
for {
|
||||
conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
|
||||
if _, err := io.ReadFull(conn, msgType); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
switch msgType[0] {
|
||||
|
||||
case msgTypeHeartbeat:
|
||||
// Nothing to do.
|
||||
|
||||
case msgTypeError:
|
||||
e := &errs.Error{}
|
||||
if err := e.Read(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return e
|
||||
|
||||
case msgTypeRec:
|
||||
if err := rec.readFrom(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, _, err := wal.appendRecord(rec); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
default:
|
||||
return errs.Unexpected.WithMsg("Unknown message type: %d", msgType[0])
|
||||
}
|
||||
}
|
||||
}
|
||||
73
lib/wal/wal-send.go
Normal file
73
lib/wal/wal-send.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
msgTypeRec = 8
|
||||
msgTypeHeartbeat = 16
|
||||
msgTypeError = 32
|
||||
)
|
||||
|
||||
func (wal *WAL) Send(conn net.Conn, timeout time.Duration) error {
|
||||
defer conn.Close()
|
||||
|
||||
var (
|
||||
seqNum int64
|
||||
heartbeatTimeout = timeout / 8
|
||||
)
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(timeout))
|
||||
if err := binary.Read(conn, binary.LittleEndian, &seqNum); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
conn.SetReadDeadline(time.Time{})
|
||||
|
||||
it, err := wal.Iterator(seqNum)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer it.Close()
|
||||
|
||||
for {
|
||||
if it.Next(heartbeatTimeout) {
|
||||
rec := it.Record()
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
|
||||
if _, err := conn.Write([]byte{msgTypeRec}); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
if _, err := rec.writeTo(conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if it.Error() != nil {
|
||||
conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
if _, err := conn.Write([]byte{msgTypeError}); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
err, ok := it.Error().(*errs.Error)
|
||||
if !ok {
|
||||
err = errs.Unexpected.WithErr(err)
|
||||
}
|
||||
err.Write(conn)
|
||||
// w.Flush()
|
||||
return err
|
||||
}
|
||||
|
||||
conn.SetWriteDeadline(time.Now().Add(timeout))
|
||||
if _, err := conn.Write([]byte{msgTypeHeartbeat}); err != nil {
|
||||
return errs.IO.WithErr(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
271
lib/wal/wal-sendrecv_test.go
Normal file
271
lib/wal/wal-sendrecv_test.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"git.crumpington.com/public/jldb/lib/testutil"
|
||||
"log"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestSendRecvHarness(t *testing.T) {
|
||||
t.Parallel()
|
||||
(&SendRecvTestHarness{}).Run(t)
|
||||
}
|
||||
|
||||
type SendRecvTestHarness struct{}
|
||||
|
||||
func (h *SendRecvTestHarness) Run(t *testing.T) {
|
||||
val := reflect.ValueOf(h)
|
||||
typ := val.Type()
|
||||
for i := 0; i < typ.NumMethod(); i++ {
|
||||
method := typ.Method(i)
|
||||
if !strings.HasPrefix(method.Name, "Test") {
|
||||
continue
|
||||
}
|
||||
|
||||
t.Run(method.Name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pDir := t.TempDir()
|
||||
sDir := t.TempDir()
|
||||
|
||||
config := Config{
|
||||
SegMinCount: 8,
|
||||
SegMaxAgeSec: 1,
|
||||
}
|
||||
|
||||
pWAL, err := Create(pDir, 1, config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer pWAL.Close()
|
||||
|
||||
sWAL, err := Create(sDir, 1, config)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer sWAL.Close()
|
||||
|
||||
nw := testutil.NewNetwork()
|
||||
defer func() {
|
||||
nw.CloseServer()
|
||||
nw.CloseClient()
|
||||
}()
|
||||
|
||||
val.MethodByName(method.Name).Call([]reflect.Value{
|
||||
reflect.ValueOf(t),
|
||||
reflect.ValueOf(pWAL),
|
||||
reflect.ValueOf(sWAL),
|
||||
reflect.ValueOf(nw),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (h *SendRecvTestHarness) TestSimple(
|
||||
t *testing.T,
|
||||
pWAL *WAL,
|
||||
sWAL *WAL,
|
||||
nw *testutil.Network,
|
||||
) {
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := writeRandomWithEOF(pWAL, 5*time.Second); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Send in the background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := nw.Accept()
|
||||
if err := pWAL.Send(conn, 8*time.Second); err != nil {
|
||||
log.Printf("Send error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Recv in the background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := nw.Dial()
|
||||
if err := sWAL.Recv(conn, 8*time.Second); err != nil {
|
||||
log.Printf("Recv error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
waitForEOF(t, sWAL)
|
||||
|
||||
nw.CloseServer()
|
||||
nw.CloseClient()
|
||||
wg.Wait()
|
||||
|
||||
checkWALsEqual(t, pWAL, sWAL)
|
||||
}
|
||||
|
||||
func (h *SendRecvTestHarness) TestWriteThenRead(
|
||||
t *testing.T,
|
||||
pWAL *WAL,
|
||||
sWAL *WAL,
|
||||
nw *testutil.Network,
|
||||
) {
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
if err := writeRandomWithEOF(pWAL, 2*time.Second); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Send in the background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := nw.Accept()
|
||||
if err := pWAL.Send(conn, 8*time.Second); err != nil {
|
||||
log.Printf("Send error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Recv in the background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := nw.Dial()
|
||||
if err := sWAL.Recv(conn, 8*time.Second); err != nil {
|
||||
log.Printf("Recv error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
waitForEOF(t, sWAL)
|
||||
|
||||
nw.CloseServer()
|
||||
nw.CloseClient()
|
||||
wg.Wait()
|
||||
|
||||
checkWALsEqual(t, pWAL, sWAL)
|
||||
}
|
||||
|
||||
func (h *SendRecvTestHarness) TestNetworkFailures(
|
||||
t *testing.T,
|
||||
pWAL *WAL,
|
||||
sWAL *WAL,
|
||||
nw *testutil.Network,
|
||||
) {
|
||||
recvDone := &atomic.Bool{}
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
writeRandomWithEOF(pWAL, 10*time.Second)
|
||||
}()
|
||||
|
||||
// Send in the background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
for {
|
||||
if recvDone.Load() {
|
||||
return
|
||||
}
|
||||
if conn := nw.Accept(); conn != nil {
|
||||
pWAL.Send(conn, 8*time.Second)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Recv in the background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for !recvDone.Load() {
|
||||
if conn := nw.Dial(); conn != nil {
|
||||
sWAL.Recv(conn, 8*time.Second)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
wg.Add(1)
|
||||
failureCount := 0
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
if recvDone.Load() {
|
||||
return
|
||||
}
|
||||
time.Sleep(time.Millisecond * time.Duration(rand.Intn(100)))
|
||||
failureCount++
|
||||
if rand.Float64() < 0.5 {
|
||||
nw.CloseClient()
|
||||
} else {
|
||||
nw.CloseServer()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
waitForEOF(t, sWAL)
|
||||
recvDone.Store(true)
|
||||
wg.Wait()
|
||||
|
||||
log.Printf("%d network failures.", failureCount)
|
||||
|
||||
if failureCount < 10 {
|
||||
t.Fatal("Expected more failures.")
|
||||
}
|
||||
|
||||
checkWALsEqual(t, pWAL, sWAL)
|
||||
}
|
||||
|
||||
func (h *SendRecvTestHarness) TestSenderClose(
|
||||
t *testing.T,
|
||||
pWAL *WAL,
|
||||
sWAL *WAL,
|
||||
nw *testutil.Network,
|
||||
) {
|
||||
wg := sync.WaitGroup{}
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if err := writeRandomWithEOF(pWAL, 5*time.Second); !errs.Closed.Is(err) {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Close primary after some time.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
time.Sleep(time.Second)
|
||||
pWAL.Close()
|
||||
}()
|
||||
|
||||
// Send in the background.
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
conn := nw.Accept()
|
||||
if err := pWAL.Send(conn, 8*time.Second); err != nil {
|
||||
log.Printf("Send error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
conn := nw.Dial()
|
||||
if err := sWAL.Recv(conn, 8*time.Second); !errs.Closed.Is(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nw.CloseServer()
|
||||
nw.CloseClient()
|
||||
wg.Wait()
|
||||
}
|
||||
321
lib/wal/wal.go
Normal file
321
lib/wal/wal.go
Normal file
@@ -0,0 +1,321 @@
|
||||
package wal
|
||||
|
||||
import (
|
||||
"io"
|
||||
"git.crumpington.com/public/jldb/lib/atomicheader"
|
||||
"git.crumpington.com/public/jldb/lib/errs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
SegMinCount int64
|
||||
SegMaxAgeSec int64
|
||||
}
|
||||
|
||||
type WAL struct {
|
||||
rootDir string
|
||||
conf Config
|
||||
|
||||
lock sync.Mutex // Protects the fields below.
|
||||
|
||||
closed bool
|
||||
header walHeader
|
||||
headerWriter *atomicheader.Handler
|
||||
f *os.File // WAL header.
|
||||
segments map[int64]*segment // Used by the iterator.
|
||||
seg *segment // Current segment.
|
||||
}
|
||||
|
||||
func Create(rootDir string, firstSeqNum int64, conf Config) (*WAL, error) {
|
||||
w := &WAL{rootDir: rootDir, conf: conf}
|
||||
|
||||
seg, err := createSegment(w.segmentPath(1), 1, firstSeqNum, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer seg.Close()
|
||||
|
||||
f, err := os.Create(w.headerPath())
|
||||
if err != nil {
|
||||
return nil, errs.IO.WithErr(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := atomicheader.Init(f); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handler, err := atomicheader.Open(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
header := walHeader{
|
||||
FirstSegmentID: 1,
|
||||
LastSegmentID: 1,
|
||||
}
|
||||
|
||||
err = handler.Write(func(page []byte) error {
|
||||
header.WriteTo(page)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return Open(rootDir, conf)
|
||||
}
|
||||
|
||||
func Open(rootDir string, conf Config) (*WAL, error) {
|
||||
w := &WAL{rootDir: rootDir, conf: conf}
|
||||
|
||||
f, err := os.OpenFile(w.headerPath(), os.O_RDWR, 0600)
|
||||
if err != nil {
|
||||
return nil, errs.IO.WithErr(err)
|
||||
}
|
||||
|
||||
handler, err := atomicheader.Open(f)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var header walHeader
|
||||
err = handler.Read(func(page []byte) error {
|
||||
header.ReadFrom(page)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
w.header = header
|
||||
w.headerWriter = handler
|
||||
w.f = f
|
||||
w.segments = map[int64]*segment{}
|
||||
|
||||
for segID := header.FirstSegmentID; segID < header.LastSegmentID+1; segID++ {
|
||||
segID := segID
|
||||
seg, err := openSegment(w.segmentPath(segID), segID)
|
||||
if err != nil {
|
||||
w.Close()
|
||||
return nil, err
|
||||
}
|
||||
w.segments[segID] = seg
|
||||
}
|
||||
|
||||
w.seg = w.segments[header.LastSegmentID]
|
||||
if err := w.grow(); err != nil {
|
||||
w.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return w, nil
|
||||
}
|
||||
|
||||
func (w *WAL) Close() error {
|
||||
w.lock.Lock()
|
||||
defer w.lock.Unlock()
|
||||
|
||||
if w.closed {
|
||||
return nil
|
||||
}
|
||||
w.closed = true
|
||||
|
||||
for _, seg := range w.segments {
|
||||
seg.Close()
|
||||
delete(w.segments, seg.ID)
|
||||
}
|
||||
|
||||
w.f.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WAL) Info() (info Info) {
|
||||
w.lock.Lock()
|
||||
defer w.lock.Unlock()
|
||||
|
||||
h := w.header
|
||||
|
||||
info.FirstSeqNum = w.segments[h.FirstSegmentID].Header().FirstSeqNum
|
||||
|
||||
lastHeader := w.segments[h.LastSegmentID].Header()
|
||||
info.LastSeqNum = lastHeader.LastSeqNum
|
||||
info.LastTimestampMS = lastHeader.LastTimestampMS
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func (w *WAL) Append(dataSize int64, r io.Reader) (int64, int64, error) {
|
||||
return w.appendRecord(Record{
|
||||
SeqNum: -1,
|
||||
TimestampMS: time.Now().UnixMilli(),
|
||||
DataSize: dataSize,
|
||||
Reader: r,
|
||||
})
|
||||
}
|
||||
|
||||
func (w *WAL) appendRecord(rec Record) (int64, int64, error) {
|
||||
w.lock.Lock()
|
||||
defer w.lock.Unlock()
|
||||
|
||||
if w.closed {
|
||||
return 0, 0, errs.Closed
|
||||
}
|
||||
|
||||
if err := w.grow(); err != nil {
|
||||
return 0, 0, err
|
||||
}
|
||||
return w.seg.appendRecord(rec)
|
||||
}
|
||||
|
||||
func (w *WAL) Iterator(fromSeqNum int64) (Iterator, error) {
|
||||
w.lock.Lock()
|
||||
defer w.lock.Unlock()
|
||||
|
||||
if w.closed {
|
||||
return nil, errs.Closed
|
||||
}
|
||||
|
||||
header := w.header
|
||||
var seg *segment
|
||||
|
||||
getSeg := func(id int64) (*segment, error) {
|
||||
w.lock.Lock()
|
||||
defer w.lock.Unlock()
|
||||
if w.closed {
|
||||
return nil, errs.Closed
|
||||
}
|
||||
return w.segments[id], nil
|
||||
}
|
||||
|
||||
if fromSeqNum == -1 {
|
||||
seg = w.segments[header.FirstSegmentID]
|
||||
return newWALIterator(getSeg, seg, fromSeqNum)
|
||||
}
|
||||
|
||||
// Seek to the appropriate segment.
|
||||
seg = w.segments[header.FirstSegmentID]
|
||||
for seg != nil {
|
||||
h := seg.Header()
|
||||
if fromSeqNum >= h.FirstSeqNum && fromSeqNum <= h.LastSeqNum+1 {
|
||||
return newWALIterator(getSeg, seg, fromSeqNum)
|
||||
}
|
||||
seg = w.segments[seg.ID+1]
|
||||
}
|
||||
|
||||
return nil, errs.NotFound
|
||||
}
|
||||
|
||||
func (w *WAL) DeleteBefore(timestamp, keepSeqNum int64) error {
|
||||
for {
|
||||
seg, err := w.removeSeg(timestamp, keepSeqNum)
|
||||
if err != nil || seg == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id := seg.ID
|
||||
os.RemoveAll(w.segmentPath(id))
|
||||
seg.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *WAL) removeSeg(timestamp, keepSeqNum int64) (*segment, error) {
|
||||
w.lock.Lock()
|
||||
defer w.lock.Unlock()
|
||||
|
||||
header := w.header
|
||||
if header.FirstSegmentID == header.LastSegmentID {
|
||||
return nil, nil // Nothing to delete now.
|
||||
}
|
||||
|
||||
id := header.FirstSegmentID
|
||||
seg := w.segments[id]
|
||||
if seg == nil {
|
||||
return nil, errs.Unexpected.WithMsg("segment %d not found", id)
|
||||
}
|
||||
|
||||
segHeader := seg.Header()
|
||||
if seg == w.seg || segHeader.ArchivedAt > timestamp {
|
||||
return nil, nil // Nothing to delete now.
|
||||
}
|
||||
|
||||
if segHeader.LastSeqNum >= keepSeqNum {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
header.FirstSegmentID = id + 1
|
||||
err := w.headerWriter.Write(func(page []byte) error {
|
||||
header.WriteTo(page)
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
w.header = header
|
||||
delete(w.segments, id)
|
||||
|
||||
return seg, nil
|
||||
}
|
||||
|
||||
func (w *WAL) grow() error {
|
||||
segHeader := w.seg.Header()
|
||||
|
||||
if segHeader.ArchivedAt == 0 {
|
||||
if (segHeader.LastSeqNum - segHeader.FirstSeqNum) < w.conf.SegMinCount {
|
||||
return nil
|
||||
}
|
||||
if time.Now().Unix()-segHeader.CreatedAt < w.conf.SegMaxAgeSec {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
newSegID := w.seg.ID + 1
|
||||
firstSeqNum := segHeader.LastSeqNum + 1
|
||||
timestampMS := segHeader.LastTimestampMS
|
||||
|
||||
newSeg, err := createSegment(w.segmentPath(newSegID), newSegID, firstSeqNum, timestampMS)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
walHeader := w.header
|
||||
walHeader.LastSegmentID = newSegID
|
||||
|
||||
err = w.headerWriter.Write(func(page []byte) error {
|
||||
walHeader.WriteTo(page)
|
||||
return nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
newSeg.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
if err := w.seg.Archive(); err != nil {
|
||||
newSeg.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
w.seg = newSeg
|
||||
w.segments[newSeg.ID] = newSeg
|
||||
w.header = walHeader
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *WAL) headerPath() string {
|
||||
return filepath.Join(w.rootDir, "header")
|
||||
}
|
||||
|
||||
func (w *WAL) segmentPath(segID int64) string {
|
||||
return filepath.Join(w.rootDir, "seg."+strconv.FormatInt(segID, 10))
|
||||
}
|
||||
Reference in New Issue
Block a user