commit 71eb6b0c7ef39b1cf189012237a309054fe3285c Author: jdl Date: Fri Oct 13 11:43:27 2023 +0200 Initial commit diff --git a/README.md b/README.md new file mode 100644 index 0000000..877eca0 --- /dev/null +++ b/README.md @@ -0,0 +1,9 @@ +# jldb + +Replicated in-memory database and file store. + +## TODO + +* [ ] mdb: tests for sanitize and validate functions +* [ ] Test: lib/wal iterator w/ corrupt file (random corruptions) +* [ ] Test: lib/wal io.go diff --git a/dep-graph.sh b/dep-graph.sh new file mode 100755 index 0000000..8562fd7 --- /dev/null +++ b/dep-graph.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +godepgraph \ + -s \ + -p github.com \ + ./$1 > .deps.dot && + xdot .deps.dot + +rm .deps.dot diff --git a/fstore/browser.go b/fstore/browser.go new file mode 100644 index 0000000..bfe71fb --- /dev/null +++ b/fstore/browser.go @@ -0,0 +1,136 @@ +package fstore + +import ( + "embed" + "io" + "git.crumpington.com/public/jldb/fstore/pages" + "net/http" + "os" + "path/filepath" +) + +//go:embed static/* +var staticFS embed.FS + +type browser struct { + store *Store +} + +func (s *Store) ServeBrowser(listenAddr string) error { + b := &browser{s} + http.HandleFunc("/", b.handle) + http.Handle("/static/", http.FileServer(http.FS(staticFS))) + return http.ListenAndServe(listenAddr, nil) +} + +func (b *browser) handle(w http.ResponseWriter, r *http.Request) { + switch r.Method { + case http.MethodGet: + b.handleGet(w, r) + case http.MethodPost: + b.handlePOST(w, r) + default: + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + } +} + +func (b *browser) handleGet(w http.ResponseWriter, r *http.Request) { + path := cleanPath(r.URL.Path) + if err := validatePath(path); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + info, err := b.store.Stat(path) + if err != nil { + if os.IsNotExist(err) { + pages.Page{Path: path}.Render(w) + return + } + http.Error(w, err.Error(), http.StatusNotFound) + return + } + + if !info.IsDir() { + b.store.Serve(w, r, path) + return + } + + dirs, files, err := b.store.List(path) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + pages.Page{ + Path: path, + Dirs: dirs, + Files: files, + }.Render(w) +} + +// Handle actions: +// - upload (multipart), +// - delete +func (b *browser) handlePOST(w http.ResponseWriter, r *http.Request) { + if err := r.ParseMultipartForm(1024 * 1024); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + switch r.Form.Get("Action") { + case "Upload": + b.handlePOSTUpload(w, r) + case "Delete": + b.handlePOSTDelete(w, r) + default: + http.Error(w, "unknown action", http.StatusBadRequest) + } +} + +func (b *browser) handlePOSTUpload(w http.ResponseWriter, r *http.Request) { + file, handler, err := r.FormFile("File") + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + defer file.Close() + + relativePath := handler.Filename + if p := r.Form.Get("Path"); p != "" { + relativePath = p + } + fullPath := filepath.Join(r.URL.Path, relativePath) + + tmpPath := b.store.GetTempFilePath() + defer os.RemoveAll(tmpPath) + + tmpFile, err := os.Create(tmpPath) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer tmpFile.Close() + + if _, err := io.Copy(tmpFile, file); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + if err := b.store.Store(tmpPath, fullPath); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + 
} + + http.Redirect(w, r, filepath.Dir(fullPath), http.StatusSeeOther) +} + +func (b *browser) handlePOSTDelete(w http.ResponseWriter, r *http.Request) { + path := r.Form.Get("Path") + if err := b.store.Remove(path); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + http.Redirect(w, r, filepath.Dir(path), http.StatusSeeOther) +} diff --git a/fstore/command.go b/fstore/command.go new file mode 100644 index 0000000..07ffa81 --- /dev/null +++ b/fstore/command.go @@ -0,0 +1,64 @@ +package fstore + +import ( + "bytes" + "encoding/binary" + "io" + "git.crumpington.com/public/jldb/lib/errs" +) + +type command struct { + Path string + Store bool + TempID uint64 + FileSize int64 // In bytes. + File io.Reader +} + +func (c command) Reader(buf *bytes.Buffer) (int64, io.Reader) { + buf.Reset() + vars := []any{ + uint32(len(c.Path)), + c.Store, + c.TempID, + c.FileSize, + } + + for _, v := range vars { + binary.Write(buf, binary.LittleEndian, v) + } + buf.Write([]byte(c.Path)) + + if c.Store { + return int64(buf.Len()) + c.FileSize, io.MultiReader(buf, c.File) + } else { + return int64(buf.Len()), buf + } +} + +func (c *command) ReadFrom(r io.Reader) error { + pathLen := uint32(0) + + ptrs := []any{ + &pathLen, + &c.Store, + &c.TempID, + &c.FileSize, + } + + for _, ptr := range ptrs { + if err := binary.Read(r, binary.LittleEndian, ptr); err != nil { + return errs.IO.WithErr(err) + } + } + + pathBuf := make([]byte, pathLen) + if _, err := r.Read(pathBuf); err != nil { + return errs.IO.WithErr(err) + } + + c.Path = string(pathBuf) + c.File = r + + return nil +} diff --git a/fstore/pages/page.go b/fstore/pages/page.go new file mode 100644 index 0000000..cb2686c --- /dev/null +++ b/fstore/pages/page.go @@ -0,0 +1,95 @@ +package pages + +import ( + "html/template" + "io" + "path/filepath" +) + +type Page struct { + Path string + Dirs []string + Files []string + + // Created in Render. + BreadCrumbs []Directory +} + +func (ctx Page) FullPath(dir string) string { + return filepath.Join(ctx.Path, dir) +} + +func (ctx Page) Render(w io.Writer) { + crumbs := []Directory{} + current := Directory(ctx.Path) + + for current != "/" { + crumbs = append([]Directory{current}, crumbs...) + current = Directory(filepath.Dir(string(current))) + } + + ctx.BreadCrumbs = crumbs + authRegisterTmpl.Execute(w, ctx) +} + +var authRegisterTmpl = template.Must(template.New("").Parse(` + + + + + + +

+ root / {{range .BreadCrumbs}}{{.Name}} / {{end -}} +

+ +
+
+ Upload a file + + + + +
+
+ + ../
+ {{range .Dirs}} + {{.}}/
+ {{end}} + + {{range .Files}} +
+ + + [X] +
+ {{.}} +
+ {{end}} + + + +`)) + +type Directory string + +func (d Directory) Name() string { + return filepath.Base(string(d)) +} diff --git a/fstore/paths.go b/fstore/paths.go new file mode 100644 index 0000000..0efe19a --- /dev/null +++ b/fstore/paths.go @@ -0,0 +1,63 @@ +package fstore + +import ( + "git.crumpington.com/public/jldb/lib/errs" + "path/filepath" + "strconv" +) + +func filesRootPath(rootDir string) string { + return filepath.Clean(filepath.Join(rootDir, "files")) +} + +func repDirPath(rootDir string) string { + return filepath.Clean(filepath.Join(rootDir, "rep")) +} + +func tempDirPath(rootDir string) string { + return filepath.Clean(filepath.Join(rootDir, "tmp")) +} + +func tempFilePath(inDir string, id uint64) string { + return filepath.Join(inDir, strconv.FormatUint(id, 10)) +} + +func validatePath(p string) error { + if len(p) == 0 { + return errs.InvalidPath.WithMsg("empty path") + } + + if p[0] != '/' { + return errs.InvalidPath.WithMsg("path must be absolute") + } + + for _, c := range p { + switch c { + case '/', '-', '_', '.': + continue + default: + } + + if c >= 'a' && c <= 'z' { + continue + } + + if c >= '0' && c <= '9' { + continue + } + + return errs.InvalidPath.WithMsg("invalid character in path: %s", string([]rune{c})) + } + return nil +} + +func cleanPath(p string) string { + if len(p) == 0 { + return "/" + } + + if p[0] != '/' { + p = "/" + p + } + return filepath.Clean(p) +} diff --git a/fstore/paths_test.go b/fstore/paths_test.go new file mode 100644 index 0000000..38cfb20 --- /dev/null +++ b/fstore/paths_test.go @@ -0,0 +1,55 @@ +package fstore + +import "testing" + +func TestValidatePath_valid(t *testing.T) { + cases := []string{ + "/", + "/a/z/0.9/a-b_c.d/", + "/a/../b", + "/x/abcdefghijklmnopqrstuvwxyz/0123456789/_.-", + } + + for _, s := range cases { + if err := validatePath(s); err != nil { + t.Fatal(s, err) + } + } +} + +func TestValidatePath_invalid(t *testing.T) { + cases := []string{ + "", + "/A", + "/a/b/~xyz/", + "/a\\b", + "a/b/c", + } + + for _, s := range cases { + if err := validatePath(s); err == nil { + t.Fatal(s) + } + } +} + +func TestCleanPath(t *testing.T) { + type Case struct { + In, Out string + } + + cases := []Case{ + {"", "/"}, + {"../", "/"}, + {"a/b", "/a/b"}, + {"/a/b/../../../", "/"}, + {"a/b/../../../", "/"}, + } + + for _, c := range cases { + out := cleanPath(c.In) + if out != c.Out { + t.Fatal(c.In, out, c.Out) + } + } +} diff --git a/fstore/stateutil_test.go b/fstore/stateutil_test.go new file mode 100644 index 0000000..96f0891 --- /dev/null +++ b/fstore/stateutil_test.go @@ -0,0 +1,50 @@ +package fstore + +import ( + "path/filepath" + "strings" +) + +type StoreState struct { + Path string + IsDir bool + Dirs map[string]StoreState + Files map[string]StoreState + FileData string +} + +func NewStoreState(in map[string]string) StoreState { + root := StoreState{ + Path: "/", + IsDir: true, + Dirs: map[string]StoreState{}, + Files: map[string]StoreState{}, + } + + for path, fileData := range in { + slugs := strings.Split(path[1:], "/") // Remove leading slash. + + parent := root + + // Add directories. 
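+		// All but the last slug name intermediate directories; the final slug is
+		// the file name. Dirs and Files are maps, so writes made through the
+		// copied parent value still land in the shared tree.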
+ for _, part := range slugs[:len(slugs)-1] { + if _, ok := parent.Dirs[part]; !ok { + parent.Dirs[part] = StoreState{ + Path: filepath.Join(parent.Path, part), + IsDir: true, + Dirs: map[string]StoreState{}, + Files: map[string]StoreState{}, + } + } + parent = parent.Dirs[part] + } + + parent.Files[slugs[len(slugs)-1]] = StoreState{ + Path: path, + IsDir: false, + FileData: fileData, + } + } + + return root +} diff --git a/fstore/static/css/pure-min.css b/fstore/static/css/pure-min.css new file mode 100644 index 0000000..acdc431 --- /dev/null +++ b/fstore/static/css/pure-min.css @@ -0,0 +1,11 @@ +/*! +Pure v3.0.0 +Copyright 2013 Yahoo! +Licensed under the BSD License. +https://github.com/pure-css/pure/blob/master/LICENSE +*/ +/*! +normalize.css v | MIT License | https://necolas.github.io/normalize.css/ +Copyright (c) Nicolas Gallagher and Jonathan Neal +*/ +/*! normalize.css v8.0.1 | MIT License | github.com/necolas/normalize.css */html{line-height:1.15;-webkit-text-size-adjust:100%}body{margin:0}main{display:block}h1{font-size:2em;margin:.67em 0}hr{box-sizing:content-box;height:0;overflow:visible}pre{font-family:monospace,monospace;font-size:1em}a{background-color:transparent}abbr[title]{border-bottom:none;text-decoration:underline;-webkit-text-decoration:underline dotted;text-decoration:underline dotted}b,strong{font-weight:bolder}code,kbd,samp{font-family:monospace,monospace;font-size:1em}small{font-size:80%}sub,sup{font-size:75%;line-height:0;position:relative;vertical-align:baseline}sub{bottom:-.25em}sup{top:-.5em}img{border-style:none}button,input,optgroup,select,textarea{font-family:inherit;font-size:100%;line-height:1.15;margin:0}button,input{overflow:visible}button,select{text-transform:none}[type=button],[type=reset],[type=submit],button{-webkit-appearance:button}[type=button]::-moz-focus-inner,[type=reset]::-moz-focus-inner,[type=submit]::-moz-focus-inner,button::-moz-focus-inner{border-style:none;padding:0}[type=button]:-moz-focusring,[type=reset]:-moz-focusring,[type=submit]:-moz-focusring,button:-moz-focusring{outline:1px dotted ButtonText}fieldset{padding:.35em .75em .625em}legend{box-sizing:border-box;color:inherit;display:table;max-width:100%;padding:0;white-space:normal}progress{vertical-align:baseline}textarea{overflow:auto}[type=checkbox],[type=radio]{box-sizing:border-box;padding:0}[type=number]::-webkit-inner-spin-button,[type=number]::-webkit-outer-spin-button{height:auto}[type=search]{-webkit-appearance:textfield;outline-offset:-2px}[type=search]::-webkit-search-decoration{-webkit-appearance:none}::-webkit-file-upload-button{-webkit-appearance:button;font:inherit}details{display:block}summary{display:list-item}template{display:none}[hidden]{display:none}html{font-family:sans-serif}.hidden,[hidden]{display:none!important}.pure-img{max-width:100%;height:auto;display:block}.pure-g{display:flex;flex-flow:row 
wrap;align-content:flex-start}.pure-u{display:inline-block;vertical-align:top}.pure-u-1,.pure-u-1-1,.pure-u-1-12,.pure-u-1-2,.pure-u-1-24,.pure-u-1-3,.pure-u-1-4,.pure-u-1-5,.pure-u-1-6,.pure-u-1-8,.pure-u-10-24,.pure-u-11-12,.pure-u-11-24,.pure-u-12-24,.pure-u-13-24,.pure-u-14-24,.pure-u-15-24,.pure-u-16-24,.pure-u-17-24,.pure-u-18-24,.pure-u-19-24,.pure-u-2-24,.pure-u-2-3,.pure-u-2-5,.pure-u-20-24,.pure-u-21-24,.pure-u-22-24,.pure-u-23-24,.pure-u-24-24,.pure-u-3-24,.pure-u-3-4,.pure-u-3-5,.pure-u-3-8,.pure-u-4-24,.pure-u-4-5,.pure-u-5-12,.pure-u-5-24,.pure-u-5-5,.pure-u-5-6,.pure-u-5-8,.pure-u-6-24,.pure-u-7-12,.pure-u-7-24,.pure-u-7-8,.pure-u-8-24,.pure-u-9-24{display:inline-block;letter-spacing:normal;word-spacing:normal;vertical-align:top;text-rendering:auto}.pure-u-1-24{width:4.1667%}.pure-u-1-12,.pure-u-2-24{width:8.3333%}.pure-u-1-8,.pure-u-3-24{width:12.5%}.pure-u-1-6,.pure-u-4-24{width:16.6667%}.pure-u-1-5{width:20%}.pure-u-5-24{width:20.8333%}.pure-u-1-4,.pure-u-6-24{width:25%}.pure-u-7-24{width:29.1667%}.pure-u-1-3,.pure-u-8-24{width:33.3333%}.pure-u-3-8,.pure-u-9-24{width:37.5%}.pure-u-2-5{width:40%}.pure-u-10-24,.pure-u-5-12{width:41.6667%}.pure-u-11-24{width:45.8333%}.pure-u-1-2,.pure-u-12-24{width:50%}.pure-u-13-24{width:54.1667%}.pure-u-14-24,.pure-u-7-12{width:58.3333%}.pure-u-3-5{width:60%}.pure-u-15-24,.pure-u-5-8{width:62.5%}.pure-u-16-24,.pure-u-2-3{width:66.6667%}.pure-u-17-24{width:70.8333%}.pure-u-18-24,.pure-u-3-4{width:75%}.pure-u-19-24{width:79.1667%}.pure-u-4-5{width:80%}.pure-u-20-24,.pure-u-5-6{width:83.3333%}.pure-u-21-24,.pure-u-7-8{width:87.5%}.pure-u-11-12,.pure-u-22-24{width:91.6667%}.pure-u-23-24{width:95.8333%}.pure-u-1,.pure-u-1-1,.pure-u-24-24,.pure-u-5-5{width:100%}.pure-button{display:inline-block;line-height:normal;white-space:nowrap;vertical-align:middle;text-align:center;cursor:pointer;-webkit-user-drag:none;-webkit-user-select:none;user-select:none;box-sizing:border-box}.pure-button::-moz-focus-inner{padding:0;border:0}.pure-button-group{letter-spacing:-.31em;text-rendering:optimizespeed}.opera-only :-o-prefocus,.pure-button-group{word-spacing:-0.43em}.pure-button-group .pure-button{letter-spacing:normal;word-spacing:normal;vertical-align:top;text-rendering:auto}.pure-button{font-family:inherit;font-size:100%;padding:.5em 1em;color:rgba(0,0,0,.8);border:none transparent;background-color:#e6e6e6;text-decoration:none;border-radius:2px}.pure-button-hover,.pure-button:focus,.pure-button:hover{background-image:linear-gradient(transparent,rgba(0,0,0,.05) 40%,rgba(0,0,0,.1))}.pure-button:focus{outline:0}.pure-button-active,.pure-button:active{box-shadow:0 0 0 1px rgba(0,0,0,.15) inset,0 0 6px rgba(0,0,0,.2) inset;border-color:#000}.pure-button-disabled,.pure-button-disabled:active,.pure-button-disabled:focus,.pure-button-disabled:hover,.pure-button[disabled]{border:none;background-image:none;opacity:.4;cursor:not-allowed;box-shadow:none;pointer-events:none}.pure-button-hidden{display:none}.pure-button-primary,.pure-button-selected,a.pure-button-primary,a.pure-button-selected{background-color:#0078e7;color:#fff}.pure-button-group .pure-button{margin:0;border-radius:0;border-right:1px solid rgba(0,0,0,.2)}.pure-button-group .pure-button:first-child{border-top-left-radius:2px;border-bottom-left-radius:2px}.pure-button-group .pure-button:last-child{border-top-right-radius:2px;border-bottom-right-radius:2px;border-right:none}.pure-form input[type=color],.pure-form input[type=date],.pure-form input[type=datetime-local],.pure-form 
input[type=datetime],.pure-form input[type=email],.pure-form input[type=month],.pure-form input[type=number],.pure-form input[type=password],.pure-form input[type=search],.pure-form input[type=tel],.pure-form input[type=text],.pure-form input[type=time],.pure-form input[type=url],.pure-form input[type=week],.pure-form select,.pure-form textarea{padding:.5em .6em;display:inline-block;border:1px solid #ccc;box-shadow:inset 0 1px 3px #ddd;border-radius:4px;vertical-align:middle;box-sizing:border-box}.pure-form input:not([type]){padding:.5em .6em;display:inline-block;border:1px solid #ccc;box-shadow:inset 0 1px 3px #ddd;border-radius:4px;box-sizing:border-box}.pure-form input[type=color]{padding:.2em .5em}.pure-form input[type=color]:focus,.pure-form input[type=date]:focus,.pure-form input[type=datetime-local]:focus,.pure-form input[type=datetime]:focus,.pure-form input[type=email]:focus,.pure-form input[type=month]:focus,.pure-form input[type=number]:focus,.pure-form input[type=password]:focus,.pure-form input[type=search]:focus,.pure-form input[type=tel]:focus,.pure-form input[type=text]:focus,.pure-form input[type=time]:focus,.pure-form input[type=url]:focus,.pure-form input[type=week]:focus,.pure-form select:focus,.pure-form textarea:focus{outline:0;border-color:#129fea}.pure-form input:not([type]):focus{outline:0;border-color:#129fea}.pure-form input[type=checkbox]:focus,.pure-form input[type=file]:focus,.pure-form input[type=radio]:focus{outline:thin solid #129FEA;outline:1px auto #129FEA}.pure-form .pure-checkbox,.pure-form .pure-radio{margin:.5em 0;display:block}.pure-form input[type=color][disabled],.pure-form input[type=date][disabled],.pure-form input[type=datetime-local][disabled],.pure-form input[type=datetime][disabled],.pure-form input[type=email][disabled],.pure-form input[type=month][disabled],.pure-form input[type=number][disabled],.pure-form input[type=password][disabled],.pure-form input[type=search][disabled],.pure-form input[type=tel][disabled],.pure-form input[type=text][disabled],.pure-form input[type=time][disabled],.pure-form input[type=url][disabled],.pure-form input[type=week][disabled],.pure-form select[disabled],.pure-form textarea[disabled]{cursor:not-allowed;background-color:#eaeded;color:#cad2d3}.pure-form input:not([type])[disabled]{cursor:not-allowed;background-color:#eaeded;color:#cad2d3}.pure-form input[readonly],.pure-form select[readonly],.pure-form textarea[readonly]{background-color:#eee;color:#777;border-color:#ccc}.pure-form input:focus:invalid,.pure-form select:focus:invalid,.pure-form textarea:focus:invalid{color:#b94a48;border-color:#e9322d}.pure-form input[type=checkbox]:focus:invalid:focus,.pure-form input[type=file]:focus:invalid:focus,.pure-form input[type=radio]:focus:invalid:focus{outline-color:#e9322d}.pure-form select{height:2.25em;border:1px solid #ccc;background-color:#fff}.pure-form select[multiple]{height:auto}.pure-form label{margin:.5em 0 .2em}.pure-form fieldset{margin:0;padding:.35em 0 .75em;border:0}.pure-form legend{display:block;width:100%;padding:.3em 0;margin-bottom:.3em;color:#333;border-bottom:1px solid #e5e5e5}.pure-form-stacked input[type=color],.pure-form-stacked input[type=date],.pure-form-stacked input[type=datetime-local],.pure-form-stacked input[type=datetime],.pure-form-stacked input[type=email],.pure-form-stacked input[type=file],.pure-form-stacked input[type=month],.pure-form-stacked input[type=number],.pure-form-stacked input[type=password],.pure-form-stacked input[type=search],.pure-form-stacked 
input[type=tel],.pure-form-stacked input[type=text],.pure-form-stacked input[type=time],.pure-form-stacked input[type=url],.pure-form-stacked input[type=week],.pure-form-stacked label,.pure-form-stacked select,.pure-form-stacked textarea{display:block;margin:.25em 0}.pure-form-stacked input:not([type]){display:block;margin:.25em 0}.pure-form-aligned input,.pure-form-aligned select,.pure-form-aligned textarea,.pure-form-message-inline{display:inline-block;vertical-align:middle}.pure-form-aligned textarea{vertical-align:top}.pure-form-aligned .pure-control-group{margin-bottom:.5em}.pure-form-aligned .pure-control-group label{text-align:right;display:inline-block;vertical-align:middle;width:10em;margin:0 1em 0 0}.pure-form-aligned .pure-controls{margin:1.5em 0 0 11em}.pure-form .pure-input-rounded,.pure-form input.pure-input-rounded{border-radius:2em;padding:.5em 1em}.pure-form .pure-group fieldset{margin-bottom:10px}.pure-form .pure-group input,.pure-form .pure-group textarea{display:block;padding:10px;margin:0 0 -1px;border-radius:0;position:relative;top:-1px}.pure-form .pure-group input:focus,.pure-form .pure-group textarea:focus{z-index:3}.pure-form .pure-group input:first-child,.pure-form .pure-group textarea:first-child{top:1px;border-radius:4px 4px 0 0;margin:0}.pure-form .pure-group input:first-child:last-child,.pure-form .pure-group textarea:first-child:last-child{top:1px;border-radius:4px;margin:0}.pure-form .pure-group input:last-child,.pure-form .pure-group textarea:last-child{top:-2px;border-radius:0 0 4px 4px;margin:0}.pure-form .pure-group button{margin:.35em 0}.pure-form .pure-input-1{width:100%}.pure-form .pure-input-3-4{width:75%}.pure-form .pure-input-2-3{width:66%}.pure-form .pure-input-1-2{width:50%}.pure-form .pure-input-1-3{width:33%}.pure-form .pure-input-1-4{width:25%}.pure-form-message-inline{display:inline-block;padding-left:.3em;color:#666;vertical-align:middle;font-size:.875em}.pure-form-message{display:block;color:#666;font-size:.875em}@media only screen and (max-width :480px){.pure-form button[type=submit]{margin:.7em 0 0}.pure-form input:not([type]),.pure-form input[type=color],.pure-form input[type=date],.pure-form input[type=datetime-local],.pure-form input[type=datetime],.pure-form input[type=email],.pure-form input[type=month],.pure-form input[type=number],.pure-form input[type=password],.pure-form input[type=search],.pure-form input[type=tel],.pure-form input[type=text],.pure-form input[type=time],.pure-form input[type=url],.pure-form input[type=week],.pure-form label{margin-bottom:.3em;display:block}.pure-group input:not([type]),.pure-group input[type=color],.pure-group input[type=date],.pure-group input[type=datetime-local],.pure-group input[type=datetime],.pure-group input[type=email],.pure-group input[type=month],.pure-group input[type=number],.pure-group input[type=password],.pure-group input[type=search],.pure-group input[type=tel],.pure-group input[type=text],.pure-group input[type=time],.pure-group input[type=url],.pure-group input[type=week]{margin-bottom:0}.pure-form-aligned .pure-control-group label{margin-bottom:.3em;text-align:left;display:block;width:100%}.pure-form-aligned .pure-controls{margin:1.5em 0 0 0}.pure-form-message,.pure-form-message-inline{display:block;font-size:.75em;padding:.2em 0 
.8em}}.pure-menu{box-sizing:border-box}.pure-menu-fixed{position:fixed;left:0;top:0;z-index:3}.pure-menu-item,.pure-menu-list{position:relative}.pure-menu-list{list-style:none;margin:0;padding:0}.pure-menu-item{padding:0;margin:0;height:100%}.pure-menu-heading,.pure-menu-link{display:block;text-decoration:none;white-space:nowrap}.pure-menu-horizontal{width:100%;white-space:nowrap}.pure-menu-horizontal .pure-menu-list{display:inline-block}.pure-menu-horizontal .pure-menu-heading,.pure-menu-horizontal .pure-menu-item,.pure-menu-horizontal .pure-menu-separator{display:inline-block;vertical-align:middle}.pure-menu-item .pure-menu-item{display:block}.pure-menu-children{display:none;position:absolute;left:100%;top:0;margin:0;padding:0;z-index:3}.pure-menu-horizontal .pure-menu-children{left:0;top:auto;width:inherit}.pure-menu-active>.pure-menu-children,.pure-menu-allow-hover:hover>.pure-menu-children{display:block;position:absolute}.pure-menu-has-children>.pure-menu-link:after{padding-left:.5em;content:"\25B8";font-size:small}.pure-menu-horizontal .pure-menu-has-children>.pure-menu-link:after{content:"\25BE"}.pure-menu-scrollable{overflow-y:scroll;overflow-x:hidden}.pure-menu-scrollable .pure-menu-list{display:block}.pure-menu-horizontal.pure-menu-scrollable .pure-menu-list{display:inline-block}.pure-menu-horizontal.pure-menu-scrollable{white-space:nowrap;overflow-y:hidden;overflow-x:auto;padding:.5em 0}.pure-menu-horizontal .pure-menu-children .pure-menu-separator,.pure-menu-separator{background-color:#ccc;height:1px;margin:.3em 0}.pure-menu-horizontal .pure-menu-separator{width:1px;height:1.3em;margin:0 .3em}.pure-menu-horizontal .pure-menu-children .pure-menu-separator{display:block;width:auto}.pure-menu-heading{text-transform:uppercase;color:#565d64}.pure-menu-link{color:#777}.pure-menu-children{background-color:#fff}.pure-menu-heading,.pure-menu-link{padding:.5em 1em}.pure-menu-disabled{opacity:.5}.pure-menu-disabled .pure-menu-link:hover{background-color:transparent;cursor:default}.pure-menu-active>.pure-menu-link,.pure-menu-link:focus,.pure-menu-link:hover{background-color:#eee}.pure-menu-selected>.pure-menu-link,.pure-menu-selected>.pure-menu-link:visited{color:#000}.pure-table{border-collapse:collapse;border-spacing:0;empty-cells:show;border:1px solid #cbcbcb}.pure-table caption{color:#000;font:italic 85%/1 arial,sans-serif;padding:1em 0;text-align:center}.pure-table td,.pure-table th{border-left:1px solid #cbcbcb;border-width:0 0 0 1px;font-size:inherit;margin:0;overflow:visible;padding:.5em 1em}.pure-table thead{background-color:#e0e0e0;color:#000;text-align:left;vertical-align:bottom}.pure-table td{background-color:transparent}.pure-table-odd td{background-color:#f2f2f2}.pure-table-striped tr:nth-child(2n-1) td{background-color:#f2f2f2}.pure-table-bordered td{border-bottom:1px solid #cbcbcb}.pure-table-bordered tbody>tr:last-child>td{border-bottom-width:0}.pure-table-horizontal td,.pure-table-horizontal th{border-width:0 0 1px 0;border-bottom:1px solid #cbcbcb}.pure-table-horizontal tbody>tr:last-child>td{border-bottom-width:0} \ No newline at end of file diff --git a/fstore/store-commands.go b/fstore/store-commands.go new file mode 100644 index 0000000..6889e2a --- /dev/null +++ b/fstore/store-commands.go @@ -0,0 +1,82 @@ +package fstore + +import ( + "git.crumpington.com/public/jldb/lib/errs" + "git.crumpington.com/public/jldb/lib/idgen" + "log" + "os" + "path/filepath" +) + +func (s *Store) applyStoreFromReader(cmd command) error { + tmpPath := tempFilePath(s.tmpDir, idgen.Next()) 
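+	// Copy the payload into a unique temp file, fsync it, and only then rename
+	// it into place so readers never observe a partially written file.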
+ + f, err := os.Create(tmpPath) + if err != nil { + return errs.IO.WithErr(err) + } + defer f.Close() + + n, err := f.ReadFrom(cmd.File) + if err != nil { + return errs.IO.WithErr(err) + } + if n != cmd.FileSize { + return errs.IO.WithMsg("expected to %d bytes, but got %d", cmd.FileSize, n) + } + + if err := f.Sync(); err != nil { + return errs.IO.WithErr(err) + } + + fullPath := filepath.Join(s.filesRoot, cmd.Path) + + if err := os.MkdirAll(filepath.Dir(fullPath), 0700); err != nil { + return errs.IO.WithErr(err) + } + + if err := os.Rename(tmpPath, fullPath); err != nil { + return errs.IO.WithErr(err) + } + + return nil +} + +func (s *Store) applyStoreFromTempID(cmd command) error { + tmpPath := tempFilePath(s.tmpDir, cmd.TempID) + fullPath := filepath.Join(s.filesRoot, cmd.Path) + + info, err := os.Stat(tmpPath) + if err != nil || info.Size() != cmd.FileSize { + log.Printf("[STORE] Primary falling back on reader copy: %v", err) + return s.applyStoreFromReader(cmd) + } + + if err := os.MkdirAll(filepath.Dir(fullPath), 0700); err != nil { + return errs.IO.WithErr(err) + } + + if err := os.Rename(tmpPath, fullPath); err != nil { + return errs.IO.WithErr(err) + } + return nil +} + +func (s *Store) applyRemove(cmd command) error { + finalPath := filepath.Join(s.filesRoot, cmd.Path) + if err := os.Remove(finalPath); err != nil { + if !os.IsNotExist(err) { + return errs.IO.WithErr(err) + } + } + + parent := filepath.Dir(finalPath) + for parent != s.filesRoot { + if err := os.Remove(parent); err != nil { + return nil + } + parent = filepath.Dir(parent) + } + + return nil +} diff --git a/fstore/store-harness_test.go b/fstore/store-harness_test.go new file mode 100644 index 0000000..be8bf45 --- /dev/null +++ b/fstore/store-harness_test.go @@ -0,0 +1,136 @@ +package fstore + +import ( + "net/http" + "net/http/httptest" + "reflect" + "strings" + "testing" + "time" +) + +func TestStoreHarness(t *testing.T) { + StoreTestHarness{}.Run(t) +} + +type StoreTestHarness struct{} + +func (h StoreTestHarness) Run(t *testing.T) { + val := reflect.ValueOf(h) + typ := val.Type() + for i := 0; i < typ.NumMethod(); i++ { + method := typ.Method(i) + + if !strings.HasPrefix(method.Name, "Test") { + continue + } + + t.Run(method.Name, func(t *testing.T) { + t.Parallel() + + rootDir := t.TempDir() + + primary, err := Open(Config{ + RootDir: rootDir, + Primary: true, + WALSegMinCount: 1, + WALSegMaxAgeSec: 1, + WALSegGCAgeSec: 2, + }) + if err != nil { + t.Fatal(err) + } + defer primary.Close() + + mux := http.NewServeMux() + mux.HandleFunc("/rep/", primary.Handle) + testServer := httptest.NewServer(mux) + defer testServer.Close() + + rootDir2 := t.TempDir() + secondary, err := Open(Config{ + RootDir: rootDir2, + Primary: false, + PrimaryEndpoint: testServer.URL + "/rep/", + }) + if err != nil { + t.Fatal(err) + } + defer secondary.Close() + + val.MethodByName(method.Name).Call([]reflect.Value{ + reflect.ValueOf(t), + reflect.ValueOf(primary), + reflect.ValueOf(secondary), + }) + }) + } +} + +func (StoreTestHarness) TestBasic(t *testing.T, primary, secondary *Store) { + stateChan := make(chan map[string]string, 1) + go func() { + stateChan <- primary.WriteRandomFor(4 * time.Second) + }() + + state := <-stateChan + secondary.WaitForParity(primary) + + primary.AssertState(t, state) + secondary.AssertState(t, state) +} + +func (StoreTestHarness) TestWriteThenFollow(t *testing.T, primary, secondary *Store) { + secondary.Close() + stateChan := make(chan map[string]string, 1) + go func() { + stateChan <- 
primary.WriteRandomFor(4 * time.Second) + }() + + state := <-stateChan + + var err error + secondary, err = Open(secondary.conf) + if err != nil { + t.Fatal(err) + } + + secondary.WaitForParity(primary) + + primary.AssertState(t, state) + secondary.AssertState(t, state) +} + +func (StoreTestHarness) TestCloseAndOpenFollowerConcurrently(t *testing.T, primary, secondary *Store) { + secondary.Close() + stateChan := make(chan map[string]string, 1) + go func() { + stateChan <- primary.WriteRandomFor(8 * time.Second) + }() + + var err error + + for i := 0; i < 4; i++ { + time.Sleep(time.Second) + + secondary, err = Open(secondary.conf) + if err != nil { + t.Fatal(err) + } + time.Sleep(time.Second) + + secondary.Close() + } + + secondary, err = Open(secondary.conf) + if err != nil { + t.Fatal(err) + } + + state := <-stateChan + + secondary.WaitForParity(primary) + + primary.AssertState(t, state) + secondary.AssertState(t, state) +} diff --git a/fstore/store-rep.go b/fstore/store-rep.go new file mode 100644 index 0000000..4d3d563 --- /dev/null +++ b/fstore/store-rep.go @@ -0,0 +1,171 @@ +package fstore + +import ( + "encoding/binary" + "errors" + "io" + "io/fs" + "git.crumpington.com/public/jldb/lib/errs" + "git.crumpington.com/public/jldb/lib/wal" + "net" + "os" + "path/filepath" + "time" +) + +func (s *Store) repSendState(conn net.Conn) error { + err := filepath.Walk(s.filesRoot, func(path string, info fs.FileInfo, err error) error { + if err != nil { + // Skip deleted files. + if os.IsNotExist(err) { + return nil + } + return err + } + if info.IsDir() { + return nil + } + + f, err := os.Open(path) + if err != nil { + return err + } + defer f.Close() + + relPath, err := filepath.Rel(s.filesRoot, path) + if err != nil { + return err + } + + conn.SetWriteDeadline(time.Now().Add(s.conf.NetTimeout)) + if err := binary.Write(conn, binary.LittleEndian, int32(len(relPath))); err != nil { + return err + } + if _, err := conn.Write([]byte(relPath)); err != nil { + return err + } + if err := binary.Write(conn, binary.LittleEndian, int64(info.Size())); err != nil { + return err + } + + conn.SetWriteDeadline(time.Now().Add(s.conf.NetTimeout)) + if _, err := io.CopyN(conn, f, info.Size()); err != nil { + return err + } + + return nil + }) + + if err != nil { + return errs.IO.WithErr(err) + } + + conn.SetWriteDeadline(time.Now().Add(s.conf.NetTimeout)) + if err := binary.Write(conn, binary.LittleEndian, int32(0)); err != nil { + return errs.IO.WithErr(err) + } + + return nil +} + +func (s *Store) repRecvState(conn net.Conn) error { + var ( + errorDone = errors.New("Done") + pathLen = int32(0) + fileSize = int64(0) + pathBuf = make([]byte, 1024) + ) + + for { + + err := func() error { + conn.SetReadDeadline(time.Now().Add(s.conf.NetTimeout)) + if err := binary.Read(conn, binary.LittleEndian, &pathLen); err != nil { + return err + } + if pathLen == 0 { + return errorDone + } + + if cap(pathBuf) < int(pathLen) { + pathBuf = make([]byte, pathLen) + } + pathBuf = pathBuf[:pathLen] + + if _, err := io.ReadFull(conn, pathBuf); err != nil { + return err + } + + fullPath := filepath.Join(s.filesRoot, string(pathBuf)) + + if err := os.MkdirAll(filepath.Dir(fullPath), 0700); err != nil { + return err + } + + if err := binary.Read(conn, binary.LittleEndian, &fileSize); err != nil { + return err + } + + f, err := os.Create(fullPath) + if err != nil { + return err + } + defer f.Close() + + conn.SetReadDeadline(time.Now().Add(s.conf.NetTimeout)) + if _, err = io.CopyN(f, conn, fileSize); err != nil { + return err + } + + 
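+			// Flush the received file to disk before moving on to the next entry.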
return f.Sync() + }() + + if err != nil { + if err == errorDone { + return nil + } + return errs.IO.WithErr(err) + } + } +} + +func (s *Store) repInitStorage() (err error) { + if err := os.MkdirAll(s.filesRoot, 0700); err != nil { + return errs.IO.WithErr(err) + } + if err := os.MkdirAll(s.tmpDir, 0700); err != nil { + return errs.IO.WithErr(err) + } + return nil +} + +func (s *Store) repReplay(rec wal.Record) (err error) { + cmd := command{} + if err := cmd.ReadFrom(rec.Reader); err != nil { + return err + } + if cmd.Store { + return s.applyStoreFromReader(cmd) + } + return s.applyRemove(cmd) +} + +func (s *Store) repLoadFromStorage() (err error) { + // Nothing to do. + return nil +} + +func (s *Store) repApply(rec wal.Record) (err error) { + cmd := command{} + if err := cmd.ReadFrom(rec.Reader); err != nil { + return err + } + + if cmd.Store { + if s.conf.Primary { + return s.applyStoreFromTempID(cmd) + } + return s.applyStoreFromReader(cmd) + } + return s.applyRemove(cmd) +} diff --git a/fstore/store-testrunner_test.go b/fstore/store-testrunner_test.go new file mode 100644 index 0000000..50e5fef --- /dev/null +++ b/fstore/store-testrunner_test.go @@ -0,0 +1,119 @@ +package fstore + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +type StoreTestCase struct { + Name string + Update func(t *testing.T, s *Store) error + ExpectedError error + State map[string]string +} + +func TestRunnerTestCases(t *testing.T) { + t.Helper() + rootDir := t.TempDir() + + store, err := Open(Config{ + RootDir: rootDir, + Primary: true, + }) + if err != nil { + t.Fatal(err) + } + defer store.Close() + + mux := http.NewServeMux() + mux.HandleFunc("/rep/", store.Handle) + testServer := httptest.NewServer(mux) + defer testServer.Close() + + rootDir2 := t.TempDir() + secondary, err := Open(Config{ + RootDir: rootDir2, + Primary: false, + PrimaryEndpoint: testServer.URL + "/rep/", + }) + if err != nil { + t.Fatal(err) + } + defer secondary.Close() + + for _, testCase := range storeTestCases { + testCase := testCase + t.Run(testCase.Name, func(t *testing.T) { + testRunnerRunTestCase(t, store, secondary, testCase) + }) + } +} + +func testRunnerRunTestCase(t *testing.T, store, secondary *Store, testCase StoreTestCase) { + err := testCase.Update(t, store) + if !errors.Is(err, testCase.ExpectedError) { + t.Fatal(testCase.Name, err, testCase.ExpectedError) + } + + store.AssertState(t, testCase.State) + + pInfo := store.rep.Info() + for { + sInfo := secondary.rep.Info() + if sInfo.AppSeqNum == pInfo.AppSeqNum { + break + } + time.Sleep(time.Millisecond) + } + + secondary.AssertState(t, testCase.State) +} + +var storeTestCases = []StoreTestCase{ + { + Name: "store a file", + Update: func(t *testing.T, s *Store) error { + return s.StoreString("hello world", "/a/b/c") + }, + ExpectedError: nil, + State: map[string]string{ + "/a/b/c": "hello world", + }, + }, { + Name: "store more files", + Update: func(t *testing.T, s *Store) error { + if err := s.StoreString("good bye", "/a/b/x"); err != nil { + return err + } + return s.StoreString("xxx", "/x") + }, + ExpectedError: nil, + State: map[string]string{ + "/a/b/c": "hello world", + "/a/b/x": "good bye", + "/x": "xxx", + }, + }, { + Name: "remove a file", + Update: func(t *testing.T, s *Store) error { + return s.Remove("/x") + }, + ExpectedError: nil, + State: map[string]string{ + "/a/b/c": "hello world", + "/a/b/x": "good bye", + }, + }, { + Name: "remove another file", + Update: func(t *testing.T, s *Store) error { + return s.Remove("/a/b/c") + 
}, + ExpectedError: nil, + State: map[string]string{ + "/a/b/x": "good bye", + }, + }, +} diff --git a/fstore/store.go b/fstore/store.go new file mode 100644 index 0000000..e2d13de --- /dev/null +++ b/fstore/store.go @@ -0,0 +1,279 @@ +package fstore + +import ( + "bytes" + "io" + "git.crumpington.com/public/jldb/lib/errs" + "git.crumpington.com/public/jldb/lib/idgen" + "git.crumpington.com/public/jldb/lib/rep" + "net/http" + "os" + "path/filepath" + "strconv" + "sync" + "time" +) + +type Config struct { + RootDir string + Primary bool + ReplicationPSK string + NetTimeout time.Duration // Default is 1 minute. + + // WAL settings. + WALSegMinCount int64 // Minimum Change sets in a segment. Default is 1024. + WALSegMaxAgeSec int64 // Maximum age of a segment. Default is 1 hour. + WALSegGCAgeSec int64 // Segment age for garbage collection. Default is 7 days. + + // For use by secondary. + PrimaryEndpoint string +} + +func (c Config) repConfig() rep.Config { + return rep.Config{ + RootDir: repDirPath(c.RootDir), + Primary: c.Primary, + ReplicationPSK: c.ReplicationPSK, + NetTimeout: c.NetTimeout, + WALSegMinCount: c.WALSegMinCount, + WALSegMaxAgeSec: c.WALSegMaxAgeSec, + WALSegGCAgeSec: c.WALSegGCAgeSec, + PrimaryEndpoint: c.PrimaryEndpoint, + SynchronousAppend: true, + } +} + +type Store struct { + lock sync.Mutex + buf *bytes.Buffer + rep *rep.Replicator + conf Config + filesRoot string // Absolute, no trailing slash. + tmpDir string // Absolute, no trailing slash. +} + +func Open(conf Config) (*Store, error) { + if conf.NetTimeout <= 0 { + conf.NetTimeout = time.Minute + } + + s := &Store{ + buf: &bytes.Buffer{}, + conf: conf, + filesRoot: filesRootPath(conf.RootDir), + tmpDir: tempDirPath(conf.RootDir), + } + + var err error + + repConf := s.conf.repConfig() + + s.rep, err = rep.Open( + rep.App{ + SendState: s.repSendState, + RecvState: s.repRecvState, + InitStorage: s.repInitStorage, + Replay: s.repReplay, + LoadFromStorage: s.repLoadFromStorage, + Apply: s.repApply, + }, + repConf) + + if err != nil { + return nil, err + } + + return s, nil +} + +func (s *Store) GetTempFilePath() string { + return tempFilePath(s.tmpDir, idgen.Next()) +} + +func (s *Store) StoreString(str string, finalPath string) error { + return s.StoreBytes([]byte(str), finalPath) +} + +func (s *Store) StoreBytes(b []byte, finalPath string) error { + tmpPath := s.GetTempFilePath() + + if err := os.WriteFile(tmpPath, b, 0600); err != nil { + return err + } + defer os.RemoveAll(tmpPath) + + return s.Store(tmpPath, finalPath) +} + +func (s *Store) Store(tmpPath, finalPath string) error { + if !s.conf.Primary { + return errs.NotAuthorized.WithMsg("not primary") + } + + userPath, _, err := s.cleanAndValidatePath(finalPath) + if err != nil { + return err + } + + idStr := filepath.Base(tmpPath) + tmpID, _ := strconv.ParseUint(idStr, 10, 64) + + s.lock.Lock() + defer s.lock.Unlock() + + f, err := os.Open(tmpPath) + if err != nil { + return err + } + defer f.Close() + + fi, err := f.Stat() + if err != nil { + return err + } + + cmd := command{ + Path: userPath, + Store: true, + TempID: tmpID, + FileSize: fi.Size(), + File: f, + } + + size, reader := cmd.Reader(s.buf) + + _, _, err = s.rep.Append(size, reader) + return err +} + +func (s *Store) Remove(filePath string) error { + if !s.conf.Primary { + return errs.NotAuthorized.WithMsg("not primary") + } + + userPath, _, err := s.cleanAndValidatePath(filePath) + if err != nil { + return err + } + + s.lock.Lock() + defer s.lock.Unlock() + + cmd := command{ + Path: userPath, + 
Store: false, + TempID: 0, + FileSize: 0, + } + + size, reader := cmd.Reader(s.buf) + + _, _, err = s.rep.Append(size, reader) + return err +} + +func (s *Store) List(p string) (dirs, files []string, err error) { + _, fullPath, err := s.cleanAndValidatePath(p) + if err != nil { + return nil, nil, err + } + + fi, err := os.Stat(fullPath) + if err != nil { + if os.IsNotExist(err) { + return nil, nil, nil + } + return nil, nil, err + } + + if !fi.IsDir() { + return nil, []string{fi.Name()}, nil + } + + entries, err := os.ReadDir(fullPath) + if err != nil { + return nil, nil, err + } + + for _, e := range entries { + if e.IsDir() { + dirs = append(dirs, e.Name()) + } else { + files = append(files, e.Name()) + } + } + + return dirs, files, nil +} + +func (s *Store) Stat(p string) (os.FileInfo, error) { + _, fullPath, err := s.cleanAndValidatePath(p) + if err != nil { + return nil, err + } + + fi, err := os.Stat(fullPath) + if err != nil { + return nil, err + } + + return fi, nil +} + +func (s *Store) WriteTo(w io.Writer, filePath string) error { + _, fullPath, err := s.cleanAndValidatePath(filePath) + if err != nil { + return err + } + + f, err := os.Open(fullPath) + if err != nil { + return err + } + defer f.Close() + + if _, err := io.Copy(w, f); err != nil { + return err + } + + return nil +} + +func (s *Store) Serve(w http.ResponseWriter, r *http.Request, p string) { + _, fullPath, err := s.cleanAndValidatePath(p) + if err != nil { + http.Error(w, "not found", http.StatusNotFound) + return + } + http.ServeFile(w, r, fullPath) +} + +func (s *Store) ServeFallback(w http.ResponseWriter, r *http.Request, paths ...string) { + for _, p := range paths { + _, fullPath, err := s.cleanAndValidatePath(p) + if err == nil { + fi, err := os.Stat(fullPath) + if err == nil && !fi.IsDir() { + http.ServeFile(w, r, fullPath) + return + } + } + } + http.Error(w, "not found", http.StatusNotFound) +} + +func (s *Store) Handle(w http.ResponseWriter, r *http.Request) { + s.rep.Handle(w, r) +} + +func (s *Store) Close() error { + return s.rep.Close() +} + +func (s *Store) cleanAndValidatePath(in string) (userPath, fullPath string, err error) { + userPath = cleanPath(in) + if err := validatePath(userPath); err != nil { + return "", "", err + } + return userPath, filepath.Join(s.filesRoot, userPath), nil +} diff --git a/fstore/store_test.go b/fstore/store_test.go new file mode 100644 index 0000000..912969d --- /dev/null +++ b/fstore/store_test.go @@ -0,0 +1,91 @@ +package fstore + +import ( + "bytes" + "math/rand" + "path/filepath" + "strconv" + "testing" + "time" +) + +func (s *Store) ReadString(t *testing.T, filePath string) string { + buf := &bytes.Buffer{} + if err := s.WriteTo(buf, filePath); err != nil { + t.Fatal(err) + } + return buf.String() +} + +func (s *Store) AssertState(t *testing.T, in map[string]string) { + state := NewStoreState(in) + s.AssertStateDir(t, state) +} + +func (s *Store) AssertStateDir(t *testing.T, dir StoreState) { + dirs, files, err := s.List(dir.Path) + if err != nil { + t.Fatal(err) + } + + // check file lengths. + if len(files) != len(dir.Files) { + t.Fatal(files, dir.Files) + } + + // check dir lengths. 
+ if len(dirs) != len(dir.Dirs) { + t.Fatal(dirs, dir.Dirs) + } + + for _, file := range dir.Files { + expectedContent := file.FileData + actualContent := s.ReadString(t, file.Path) + if expectedContent != actualContent { + t.Fatal(expectedContent, actualContent) + } + } + + for _, dir := range dir.Dirs { + s.AssertStateDir(t, dir) + } +} + +func (s *Store) WriteRandomFor(dt time.Duration) map[string]string { + state := map[string]string{} + tStart := time.Now() + for time.Since(tStart) < dt { + slug1 := strconv.FormatInt(rand.Int63n(10), 10) + slug2 := strconv.FormatInt(rand.Int63n(10), 10) + + path := filepath.Join("/", slug1, slug2) + + if rand.Float32() < 0.05 { + if err := s.Remove(path); err != nil { + panic(err) + } + delete(state, path) + } else { + data := randString() + state[path] = data + if err := s.StoreString(data, path); err != nil { + panic(err) + } + } + + time.Sleep(time.Millisecond) + } + + return state +} + +func (s *Store) WaitForParity(rhs *Store) { + for { + i1 := s.rep.Info() + i2 := rhs.rep.Info() + if i1.AppSeqNum == i2.AppSeqNum { + return + } + time.Sleep(time.Millisecond) + } +} diff --git a/fstore/templates/page.html b/fstore/templates/page.html new file mode 100644 index 0000000..1a70262 --- /dev/null +++ b/fstore/templates/page.html @@ -0,0 +1,6 @@ + + + +
{{.}}
+ + diff --git a/fstore/test_util.go b/fstore/test_util.go new file mode 100644 index 0000000..2b6064e --- /dev/null +++ b/fstore/test_util.go @@ -0,0 +1,16 @@ +package fstore + +import ( + crand "crypto/rand" + "encoding/base32" + "math/rand" +) + +func randString() string { + size := 8 + rand.Intn(92) + buf := make([]byte, size) + if _, err := crand.Read(buf); err != nil { + panic(err) + } + return base32.StdEncoding.EncodeToString(buf) +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..31caf80 --- /dev/null +++ b/go.mod @@ -0,0 +1,9 @@ +module git.crumpington.com/public/jldb + +go 1.21.1 + +require ( + github.com/google/btree v1.1.2 + golang.org/x/net v0.15.0 + golang.org/x/sys v0.12.0 +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5921adb --- /dev/null +++ b/go.sum @@ -0,0 +1,6 @@ +github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= +github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= +golang.org/x/net v0.15.0 h1:ugBLEUaxABaB5AJqW9enI0ACdci2RUd4eP51NTBvuJ8= +golang.org/x/net v0.15.0/go.mod h1:idbUs1IY1+zTqbi8yxTbhexhEEk5ur9LInksu6HrEpk= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/lib/atomicheader/atomicheader.go b/lib/atomicheader/atomicheader.go new file mode 100644 index 0000000..9404dab --- /dev/null +++ b/lib/atomicheader/atomicheader.go @@ -0,0 +1,136 @@ +package atomicheader + +import ( + "encoding/binary" + "hash/crc32" + "git.crumpington.com/public/jldb/lib/errs" + "os" + "sync" +) + +const ( + PageSize = 512 + AvailabePageSize = 508 + + ReservedBytes = PageSize * 4 + + offsetSwitch = 1 * PageSize + offset1 = 2 * PageSize + offset2 = 3 * PageSize +) + +type Handler struct { + lock sync.Mutex + switchPage []byte // At offsetSwitch. + page []byte // Page buffer is re-used for reading and writing. + + currentPage int64 // Either 0 or 1. + f *os.File +} + +func Init(f *os.File) error { + if err := f.Truncate(ReservedBytes); err != nil { + return errs.IO.WithErr(err) + } + + switchPage := make([]byte, PageSize) + switchPage[0] = 2 + if _, err := f.WriteAt(switchPage, offsetSwitch); err != nil { + return errs.IO.WithErr(err) + } + return nil +} + +func Open(f *os.File) (*Handler, error) { + switchPage := make([]byte, PageSize) + + if _, err := f.ReadAt(switchPage, offsetSwitch); err != nil { + return nil, errs.IO.WithErr(err) + } + + h := &Handler{ + switchPage: switchPage, + page: make([]byte, PageSize), + currentPage: int64(switchPage[0]), + f: f, + } + + if h.currentPage != 1 && h.currentPage != 2 { + return nil, errs.Corrupt.WithMsg("invalid page id: %d", h.currentPage) + } + + return h, nil +} + +// Read reads the currently active header page. +func (h *Handler) Read(read func(page []byte) error) error { + h.lock.Lock() + defer h.lock.Unlock() + + if _, err := h.f.ReadAt(h.page, h.currentOffset()); err != nil { + return errs.IO.WithErr(err) + } + + computedCRC := crc32.ChecksumIEEE(h.page[:PageSize-4]) + storedCRC := binary.LittleEndian.Uint32(h.page[PageSize-4:]) + if computedCRC != storedCRC { + return errs.Corrupt.WithMsg("checksum mismatch") + } + + return read(h.page) +} + +// Write writes the currently active header page. The page buffer given to the +// function may contain old data, so the caller may need to zero some bytes if +// necessary. 
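+//
+// A minimal usage sketch (assuming h came from Open on an Init-ed file; only
+// the first AvailabePageSize bytes belong to the caller, the tail holds the CRC):
+//
+//	err := h.Write(func(page []byte) error {
+//		binary.LittleEndian.PutUint64(page[:8], mySeqNum) // hypothetical caller-defined field
+//		return nil
+//	})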
+func (h *Handler) Write(update func(page []byte) error) error { + h.lock.Lock() + defer h.lock.Unlock() + + if err := update(h.page); err != nil { + return err + } + + crc := crc32.ChecksumIEEE(h.page[:PageSize-4]) + binary.LittleEndian.PutUint32(h.page[PageSize-4:], crc) + + newPageNum := 1 + h.currentPage%2 + newOffset := h.getOffset(newPageNum) + + if _, err := h.f.WriteAt(h.page, newOffset); err != nil { + return errs.IO.WithErr(err) + } + + if err := h.f.Sync(); err != nil { + return errs.IO.WithErr(err) + } + + h.switchPage[0] = byte(newPageNum) + if _, err := h.f.WriteAt(h.switchPage, offsetSwitch); err != nil { + return errs.IO.WithErr(err) + } + + if err := h.f.Sync(); err != nil { + return errs.IO.WithErr(err) + } + + h.currentPage = newPageNum + return nil +} + +// ---------------------------------------------------------------------------- + +func (h *Handler) currentOffset() int64 { + return h.getOffset(h.currentPage) +} + +func (h *Handler) getOffset(pageNum int64) int64 { + switch pageNum { + case 1: + return offset1 + case 2: + return offset2 + default: + panic("Invalid page number.") + } +} diff --git a/lib/atomicheader/atomicheader_test.go b/lib/atomicheader/atomicheader_test.go new file mode 100644 index 0000000..d72d71d --- /dev/null +++ b/lib/atomicheader/atomicheader_test.go @@ -0,0 +1,121 @@ +package atomicheader + +import ( + "errors" + "os" + "path/filepath" + "sync" + "testing" + "time" +) + +func NewForTesting(t *testing.T) (*Handler, func()) { + tmpDir := t.TempDir() + + f, err := os.Create(filepath.Join(tmpDir, "h")) + if err != nil { + t.Fatal(err) + } + + if err := Init(f); err != nil { + t.Fatal(err) + } + + h, err := Open(f) + if err != nil { + t.Fatal(err) + } + + return h, func() { + f.Close() + os.RemoveAll(tmpDir) + } +} + +func TestAtomicHeaderSimple(t *testing.T) { + h, cleanup := NewForTesting(t) + defer cleanup() + + err := h.Write(func(page []byte) error { + for i := range page[:AvailabePageSize] { + page[i] = byte(i) % 11 + } + return nil + }) + + if err != nil { + t.Fatal(err) + } + + err = h.Read(func(page []byte) error { + for i := range page[:AvailabePageSize] { + if page[i] != byte(i)%11 { + t.Fatal(i, page[i], byte(i)%11) + } + } + return nil + }) + + if err != nil { + t.Fatal(err) + } +} + +func TestAtomicHeaderThreaded(t *testing.T) { + h, cleanup := NewForTesting(t) + defer cleanup() + + expectedValue := byte(0) + + writeErr := make(chan error, 1) + stop := make(chan struct{}) + wg := sync.WaitGroup{} + wg.Add(1) + + go func() { + defer wg.Done() + for { + select { + case <-stop: + writeErr <- nil + return + default: + } + + err := h.Write(func(page []byte) error { + if page[0] != expectedValue { + return errors.New("Unexpected current value.") + } + + expectedValue++ + page[0] = expectedValue + return nil + }) + if err != nil { + writeErr <- err + return + } + time.Sleep(time.Millisecond / 13) + } + }() + + for i := 0; i < 2000; i++ { + time.Sleep(time.Millisecond) + err := h.Read(func(page []byte) error { + if page[0] != expectedValue { + t.Fatal(page[0], expectedValue) + } + return nil + }) + if err != nil { + t.Fatal(err) + } + } + + close(stop) + wg.Wait() + + if err := <-writeErr; err != nil { + t.Fatal(err) + } +} diff --git a/lib/errs/error.go b/lib/errs/error.go new file mode 100644 index 0000000..589e8ee --- /dev/null +++ b/lib/errs/error.go @@ -0,0 +1,121 @@ +package errs + +import ( + "encoding/binary" + "fmt" + "io" + "runtime/debug" +) + +type Error struct { + msg string + code int64 + collection string + index string + 
stackTrace string + err error // Wrapped error +} + +func NewErr(code int64, msg string) *Error { + return &Error{ + msg: msg, + code: code, + } +} + +func (e *Error) Error() string { + if e.collection != "" || e.index != "" { + return fmt.Sprintf(`[%d] (%s/%s) %s`, e.code, e.collection, e.index, e.msg) + } else { + return fmt.Sprintf("[%d] %s", e.code, e.msg) + } +} + +func (e *Error) Is(rhs error) bool { + e2, ok := rhs.(*Error) + if !ok { + return false + } + return e.code == e2.code +} + +func (e *Error) WithErr(err error) *Error { + if e2, ok := err.(*Error); ok && e2.code == e.code { + return e2 + } + + e2 := e.WithMsg(err.Error()) + e2.err = err + return e2 +} + +func (e *Error) Unwrap() error { + if e.err != nil { + return e.err + } + return e +} + +func (e *Error) WithMsg(msg string, args ...any) *Error { + err := *e + err.msg += ": " + fmt.Sprintf(msg, args...) + if len(err.stackTrace) == 0 { + err.stackTrace = string(debug.Stack()) + } + return &err +} + +func (e *Error) WithCollection(s string) *Error { + err := *e + err.collection = s + return &err +} + +func (e *Error) WithIndex(s string) *Error { + err := *e + err.index = s + return &err +} + +func (e *Error) msgTruncacted() string { + if len(e.msg) > 255 { + return e.msg[:255] + } + return e.msg +} + +func (e *Error) Write(w io.Writer) error { + msg := e.msgTruncacted() + + if err := binary.Write(w, binary.LittleEndian, e.code); err != nil { + return IO.WithErr(err) + } + + if _, err := w.Write([]byte{byte(len(msg))}); err != nil { + return err + } + _, err := w.Write([]byte(msg)) + return err +} + +func (e *Error) Read(r io.Reader) error { + var ( + size uint8 + ) + + if err := binary.Read(r, binary.LittleEndian, &e.code); err != nil { + return IO.WithErr(err) + } + + if err := binary.Read(r, binary.LittleEndian, &size); err != nil { + return IO.WithErr(err) + } + + msgBuf := make([]byte, size) + if _, err := io.ReadFull(r, msgBuf); err != nil { + return IO.WithErr(err) + } + + e.msg = string(msgBuf) + return nil +} diff --git a/lib/errs/error_test.go b/lib/errs/error_test.go new file mode 100644 index 0000000..a89165c --- /dev/null +++ b/lib/errs/error_test.go @@ -0,0 +1,26 @@ +package errs + +import ( + "bytes" + "reflect" + "testing" +) + +func TestError_Simple(t *testing.T) { + e := Archived + + b := &bytes.Buffer{} + + if err := e.Write(b); err != nil { + t.Fatal(err) + } + + e2 := &Error{} + if err := e2.Read(b); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(*e, *e2) { + t.Fatal("not equal") + } +} diff --git a/lib/errs/errors.go b/lib/errs/errors.go new file mode 100644 index 0000000..dcbbe21 --- /dev/null +++ b/lib/errs/errors.go @@ -0,0 +1,21 @@ +package errs + +var ( + Archived = NewErr(100, "archived") + EOFArchived = NewErr(101, "EOF-archived") + IO = NewErr(102, "IO error") + NotFound = NewErr(103, "not found") + Locked = NewErr(104, "locked") + NotAuthorized = NewErr(105, "not authorized") + NotAllowed = NewErr(106, "not allowed") + Stopped = NewErr(107, "stopped") + Timeout = NewErr(108, "timeout") + Duplicate = NewErr(109, "duplicate") + ReadOnly = NewErr(110, "read only") + Encoding = NewErr(111, "encoding") + Closed = NewErr(112, "closed") + InvalidPath = NewErr(200, "invalid path") + Corrupt = NewErr(666, "corrupt") + Fatal = NewErr(1053, "fatal") + Unexpected = NewErr(999, "unexpected") +) diff --git a/lib/errs/fmt.go b/lib/errs/fmt.go new file mode 100644 index 0000000..fa43dfc --- /dev/null +++ b/lib/errs/fmt.go @@ -0,0 +1,22 @@ +package errs + +import "fmt" + +func FmtDetails(err error) 
string { + e, ok := err.(*Error) + if !ok { + return err.Error() + } + + var s string + if e.collection != "" || e.index != "" { + s = fmt.Sprintf(`[%d] (%s/%s) %s`, e.code, e.collection, e.index, e.msg) + } else { + s = fmt.Sprintf("[%d] %s", e.code, e.msg) + } + if len(e.stackTrace) != 0 { + s += "\n\nStack Trace:\n" + e.stackTrace + "\n" + } + + return s +} diff --git a/lib/flock/flock.go b/lib/flock/flock.go new file mode 100644 index 0000000..108694c --- /dev/null +++ b/lib/flock/flock.go @@ -0,0 +1,58 @@ +package flock + +import ( + "errors" + "os" + + "golang.org/x/sys/unix" +) + +// Lock gets an exclusive lock on the file at the given path. If the file +// doesn't exist, it's created. +func Lock(path string) (*os.File, error) { + return lock(path, unix.LOCK_EX) +} + +// TryLock will return a nil file if the file is already locked. +func TryLock(path string) (*os.File, error) { + return lock(path, unix.LOCK_EX|unix.LOCK_NB) +} + +func LockFile(f *os.File) error { + _, err := lockFile(f, unix.LOCK_EX) + return err +} + +// Returns true if the lock was successfully acquired. +func TryLockFile(f *os.File) (bool, error) { + return lockFile(f, unix.LOCK_EX|unix.LOCK_NB) +} + +func lockFile(f *os.File, flags int) (bool, error) { + if err := unix.Flock(int(f.Fd()), flags); err != nil { + if flags&unix.LOCK_NB != 0 && errors.Is(err, unix.EAGAIN) { + return false, nil + } + return false, err + } + return true, nil +} + +func lock(path string, flags int) (*os.File, error) { + perm := os.O_CREATE | os.O_RDWR + f, err := os.OpenFile(path, perm, 0600) + if err != nil { + return nil, err + } + ok, err := lockFile(f, flags) + if err != nil || !ok { + f.Close() + f = nil + } + return f, err +} + +// Unlock releases the lock acquired via the Lock function. +func Unlock(f *os.File) error { + return f.Close() +} diff --git a/lib/httpconn/client.go b/lib/httpconn/client.go new file mode 100644 index 0000000..39e6935 --- /dev/null +++ b/lib/httpconn/client.go @@ -0,0 +1,85 @@ +package httpconn + +import ( + "bufio" + "context" + "crypto/tls" + "errors" + "io" + "git.crumpington.com/public/jldb/lib/errs" + "net" + "net/http" + "net/url" + "time" +) + +var ErrInvalidStatus = errors.New("invalid status") + +func Dial(rawURL string) (net.Conn, error) { + u, err := url.Parse(rawURL) + if err != nil { + return nil, errs.Unexpected.WithErr(err) + } + + switch u.Scheme { + case "https": + return DialHTTPS(u.Host+":443", u.Path) + case "http": + return DialHTTP(u.Host, u.Path) + default: + return nil, errs.Unexpected.WithMsg("Unknown scheme: " + u.Scheme) + } +} + +func DialHTTPS(host, path string) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + d := tls.Dialer{} + conn, err := d.DialContext(ctx, "tcp", host) + cancel() + if err != nil { + return nil, errs.IO.WithErr(err) + } + return finishDialing(conn, host, path) +} + +func DialHTTPSWithIP(ip, host, path string) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + d := tls.Dialer{Config: &tls.Config{ServerName: host}} + conn, err := d.DialContext(ctx, "tcp", ip) + cancel() + if err != nil { + return nil, errs.IO.WithErr(err) + } + return finishDialing(conn, host, path) +} + +func DialHTTP(host, path string) (net.Conn, error) { + conn, err := net.Dial("tcp", host) + if err != nil { + return nil, errs.IO.WithErr(err) + } + return finishDialing(conn, host, path) +} + +func finishDialing(conn net.Conn, host, path string) (net.Conn, error) { + 
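+	// Speak a minimal HTTP CONNECT handshake under a 10-second deadline, then
+	// clear the deadline and hand the raw connection back to the caller.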
conn.SetDeadline(time.Now().Add(10 * time.Second)) + + io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n") + io.WriteString(conn, "Host: "+host+"\n\n") + + // Require successful HTTP response before using the conn. + resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) + if err != nil { + conn.Close() + return nil, errs.IO.WithErr(err) + } + + if resp.Status != "200 OK" { + conn.Close() + return nil, errs.IO.WithMsg("invalid status: %s", resp.Status) + } + + conn.SetDeadline(time.Time{}) + + return conn, nil +} diff --git a/lib/httpconn/conn_test.go b/lib/httpconn/conn_test.go new file mode 100644 index 0000000..5982d65 --- /dev/null +++ b/lib/httpconn/conn_test.go @@ -0,0 +1,42 @@ +package httpconn + +import ( + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "golang.org/x/net/nettest" +) + +func TestNetTest_TestConn(t *testing.T) { + nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) { + + connCh := make(chan net.Conn, 1) + doneCh := make(chan bool) + + mux := http.NewServeMux() + mux.HandleFunc("/connect", func(w http.ResponseWriter, r *http.Request) { + conn, err := Accept(w, r) + if err != nil { + panic(err) + } + connCh <- conn + <-doneCh + }) + + srv := httptest.NewServer(mux) + + c1, err = DialHTTP(strings.TrimPrefix(srv.URL, "http://"), "/connect") + if err != nil { + panic(err) + } + c2 = <-connCh + + return c1, c2, func() { + doneCh <- true + srv.Close() + }, nil + }) +} diff --git a/lib/httpconn/server.go b/lib/httpconn/server.go new file mode 100644 index 0000000..f2a4dd5 --- /dev/null +++ b/lib/httpconn/server.go @@ -0,0 +1,32 @@ +package httpconn + +import ( + "io" + "net" + "net/http" + "time" +) + +func Accept(w http.ResponseWriter, r *http.Request) (net.Conn, error) { + if r.Method != "CONNECT" { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + io.WriteString(w, "405 must CONNECT\n") + return nil, http.ErrNotSupported + } + + hj, ok := w.(http.Hijacker) + if !ok { + return nil, http.ErrNotSupported + } + + conn, _, err := hj.Hijack() + if err != nil { + return nil, err + } + + _, _ = io.WriteString(conn, "HTTP/1.0 200 OK\n\n") + conn.SetDeadline(time.Time{}) + + return conn, nil +} diff --git a/lib/idgen/gen.go b/lib/idgen/gen.go new file mode 100644 index 0000000..07daf07 --- /dev/null +++ b/lib/idgen/gen.go @@ -0,0 +1,32 @@ +package idgen + +import ( + "sync" + "time" +) + +var ( + lock sync.Mutex + ts uint64 = uint64(time.Now().Unix()) + counter uint64 = 1 + counterMax uint64 = 1 << 28 +) + +// Next can generate ~268M ints per second for ~1000 years. 
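+// Each ID packs the current Unix time into the high bits and a per-second
+// counter into the low 28 bits: id = unixSeconds<<28 + counter, with the
+// counter in [1, 1<<28). IDs are strictly increasing, and at most 2^28-1
+// (~268M) IDs can be issued within one second before Next panics. For
+// example, assuming a Unix time of 1700000000, the first two IDs of that
+// second would be 1700000000<<28 + 1 and 1700000000<<28 + 2.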
+func Next() uint64 { + lock.Lock() + defer lock.Unlock() + + tt := uint64(time.Now().Unix()) + if tt > ts { + ts = tt + counter = 1 + } else { + counter++ + if counter == counterMax { + panic("Too many IDs.") + } + } + + return ts<<28 + counter +} diff --git a/lib/idgen/gen_test.go b/lib/idgen/gen_test.go new file mode 100644 index 0000000..3cfb9ca --- /dev/null +++ b/lib/idgen/gen_test.go @@ -0,0 +1,11 @@ +package idgen + +import ( + "testing" +) + +func BenchmarkNext(b *testing.B) { + for i := 0; i < b.N; i++ { + Next() + } +} diff --git a/lib/rep/functions.go b/lib/rep/functions.go new file mode 100644 index 0000000..d9ac3d2 --- /dev/null +++ b/lib/rep/functions.go @@ -0,0 +1,51 @@ +package rep + +import ( + "encoding/binary" + "encoding/json" + "git.crumpington.com/public/jldb/lib/errs" + "net" + "path/filepath" + "time" +) + +// ---------------------------------------------------------------------------- + +func lockFilePath(rootDir string) string { + return filepath.Join(rootDir, "lock") +} + +func walRootDir(rootDir string) string { + return filepath.Join(rootDir, "wal") +} + +func stateFilePath(rootDir string) string { + return filepath.Join(rootDir, "state") +} + +// ---------------------------------------------------------------------------- + +func sendJSON( + item any, + conn net.Conn, + timeout time.Duration, +) error { + + buf := bufPoolGet() + defer bufPoolPut(buf) + + if err := json.NewEncoder(buf).Encode(item); err != nil { + return errs.Unexpected.WithErr(err) + } + + sizeBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(sizeBuf, uint16(buf.Len())) + + conn.SetWriteDeadline(time.Now().Add(timeout)) + buffers := net.Buffers{sizeBuf, buf.Bytes()} + if _, err := buffers.WriteTo(conn); err != nil { + return errs.IO.WithErr(err) + } + + return nil +} diff --git a/lib/rep/http-client.go b/lib/rep/http-client.go new file mode 100644 index 0000000..a5e4fe2 --- /dev/null +++ b/lib/rep/http-client.go @@ -0,0 +1,178 @@ +package rep + +import ( + "encoding/binary" + "encoding/json" + "io" + "git.crumpington.com/public/jldb/lib/errs" + "git.crumpington.com/public/jldb/lib/httpconn" + "git.crumpington.com/public/jldb/lib/wal" + "net" + "sync" + "time" +) + +type client struct { + // Mutex-protected variables. + lock sync.Mutex + closed bool + conn net.Conn + + // The following are constant. 
+ endpoint string + psk []byte + timeout time.Duration + + buf []byte +} + +func newClient(endpoint, psk string, timeout time.Duration) *client { + b := make([]byte, 256) + copy(b, []byte(psk)) + + return &client{ + endpoint: endpoint, + psk: b, + timeout: timeout, + } +} + +func (c *client) GetInfo() (info Info, err error) { + err = c.withConn(cmdGetInfo, func(conn net.Conn) error { + return c.recvJSON(&info, conn, c.timeout) + }) + return info, err +} + +func (c *client) RecvState(recv func(net.Conn) error) error { + return c.withConn(cmdSendState, recv) +} + +func (c *client) StreamWAL(w *wal.WAL) error { + return c.withConn(cmdStreamWAL, func(conn net.Conn) error { + return w.Recv(conn, c.timeout) + }) +} + +func (c *client) Close() { + c.lock.Lock() + defer c.lock.Unlock() + c.closed = true + + if c.conn != nil { + c.conn.Close() + c.conn = nil + } +} + +// ---------------------------------------------------------------------------- + +func (c *client) writeCmd(cmd byte) error { + c.conn.SetWriteDeadline(time.Now().Add(c.timeout)) + if _, err := c.conn.Write([]byte{cmd}); err != nil { + return errs.IO.WithErr(err) + } + return nil +} + +func (c *client) dial() error { + c.conn = nil + + conn, err := httpconn.Dial(c.endpoint) + if err != nil { + return err + } + + conn.SetWriteDeadline(time.Now().Add(c.timeout)) + if _, err := conn.Write(c.psk); err != nil { + conn.Close() + return errs.IO.WithErr(err) + } + + c.conn = conn + return nil +} + +func (c *client) withConn(cmd byte, fn func(net.Conn) error) error { + conn, err := c.getConn(cmd) + if err != nil { + return err + } + + if err := fn(conn); err != nil { + conn.Close() + return err + } + return nil +} + +func (c *client) getConn(cmd byte) (net.Conn, error) { + c.lock.Lock() + defer c.lock.Unlock() + + if c.closed { + return nil, errs.IO.WithErr(io.EOF) + } + + dialed := false + + if c.conn == nil { + if err := c.dial(); err != nil { + return nil, err + } + dialed = true + } + + if err := c.writeCmd(cmd); err != nil { + if dialed { + c.conn = nil + return nil, err + } + + if err := c.dial(); err != nil { + return nil, err + } + + if err := c.writeCmd(cmd); err != nil { + return nil, err + } + } + + return c.conn, nil +} + +func (c *client) recvJSON( + item any, + conn net.Conn, + timeout time.Duration, +) error { + + if cap(c.buf) < 2 { + c.buf = make([]byte, 0, 1024) + } + buf := c.buf[:2] + + conn.SetReadDeadline(time.Now().Add(timeout)) + + if _, err := io.ReadFull(conn, buf); err != nil { + return errs.IO.WithErr(err) + } + + size := binary.LittleEndian.Uint16(buf) + + if cap(buf) < int(size) { + buf = make([]byte, size) + c.buf = buf + } + buf = buf[:size] + + if _, err := io.ReadFull(conn, buf); err != nil { + return errs.IO.WithErr(err) + } + + if err := json.Unmarshal(buf, item); err != nil { + return errs.Unexpected.WithErr(err) + } + + return nil +} diff --git a/lib/rep/http-handler.go b/lib/rep/http-handler.go new file mode 100644 index 0000000..fe513e4 --- /dev/null +++ b/lib/rep/http-handler.go @@ -0,0 +1,79 @@ +package rep + +import ( + "crypto/subtle" + "git.crumpington.com/public/jldb/lib/httpconn" + "log" + "net/http" + "time" +) + +const ( + cmdGetInfo = 10 + cmdSendState = 20 + cmdStreamWAL = 30 +) + +// --------------------------------------------------------------------------- + +func (rep *Replicator) Handle(w http.ResponseWriter, r *http.Request) { + logf := func(pattern string, args ...any) { + log.Printf("[HTTP-HANDLER] "+pattern, args...) 
+ } + + conn, err := httpconn.Accept(w, r) + if err != nil { + logf("Failed to accept connection: %s", err) + return + } + defer conn.Close() + + psk := make([]byte, 256) + + conn.SetReadDeadline(time.Now().Add(rep.conf.NetTimeout)) + if _, err := conn.Read(psk); err != nil { + logf("Failed to read PSK: %v", err) + return + } + + expected := rep.pskBytes + if subtle.ConstantTimeCompare(expected, psk) != 1 { + logf("PSK mismatch.") + return + } + + cmd := make([]byte, 1) + + for { + conn.SetReadDeadline(time.Now().Add(rep.conf.NetTimeout)) + if _, err := conn.Read(cmd); err != nil { + logf("Read failed: %v", err) + return + } + + switch cmd[0] { + + case cmdGetInfo: + if err := sendJSON(rep.Info(), conn, rep.conf.NetTimeout); err != nil { + logf("Failed to send info: %s", err) + return + } + + case cmdSendState: + + if err := rep.sendState(conn); err != nil { + if !rep.stopped() { + logf("Failed to send state: %s", err) + } + return + } + + case cmdStreamWAL: + err := rep.wal.Send(conn, rep.conf.NetTimeout) + if !rep.stopped() { + logf("Failed when sending WAL: %s", err) + } + return + } + } +} diff --git a/lib/rep/info.go b/lib/rep/info.go new file mode 100644 index 0000000..c6c5821 --- /dev/null +++ b/lib/rep/info.go @@ -0,0 +1,9 @@ +package rep + +type Info struct { + AppSeqNum int64 // Page file sequence number. + AppTimestampMS int64 // Page file timestamp. + WALFirstSeqNum int64 // WAL min sequence number. + WALLastSeqNum int64 // WAL max sequence number. + WALLastTimestampMS int64 // WAL timestamp. +} diff --git a/lib/rep/localstate.go b/lib/rep/localstate.go new file mode 100644 index 0000000..e59b340 --- /dev/null +++ b/lib/rep/localstate.go @@ -0,0 +1,20 @@ +package rep + +import ( + "encoding/binary" +) + +type localState struct { + SeqNum int64 + TimestampMS int64 +} + +func (h localState) writeTo(b []byte) { + binary.LittleEndian.PutUint64(b[0:8], uint64(h.SeqNum)) + binary.LittleEndian.PutUint64(b[8:16], uint64(h.TimestampMS)) +} + +func (h *localState) readFrom(b []byte) { + h.SeqNum = int64(binary.LittleEndian.Uint64(b[0:8])) + h.TimestampMS = int64(binary.LittleEndian.Uint64(b[8:16])) +} diff --git a/lib/rep/pools.go b/lib/rep/pools.go new file mode 100644 index 0000000..e539223 --- /dev/null +++ b/lib/rep/pools.go @@ -0,0 +1,21 @@ +package rep + +import ( + "bytes" + "sync" +) + +var bufPool = sync.Pool{ + New: func() any { + return &bytes.Buffer{} + }, +} + +func bufPoolGet() *bytes.Buffer { + return bufPool.Get().(*bytes.Buffer) +} + +func bufPoolPut(b *bytes.Buffer) { + b.Reset() + bufPool.Put(b) +} diff --git a/lib/rep/rep-sendrecv.go b/lib/rep/rep-sendrecv.go new file mode 100644 index 0000000..184b49a --- /dev/null +++ b/lib/rep/rep-sendrecv.go @@ -0,0 +1,41 @@ +package rep + +import ( + "io" + "git.crumpington.com/public/jldb/lib/errs" + "net" + "time" +) + +func (rep *Replicator) sendState(conn net.Conn) error { + state := rep.getState() + + buf := make([]byte, 512) + state.writeTo(buf) + + conn.SetWriteDeadline(time.Now().Add(rep.conf.NetTimeout)) + if _, err := conn.Write(buf); err != nil { + return errs.IO.WithErr(err) + } + conn.SetWriteDeadline(time.Time{}) + + return rep.app.SendState(conn) +} + +func (rep *Replicator) recvState(conn net.Conn) error { + buf := make([]byte, 512) + conn.SetReadDeadline(time.Now().Add(rep.conf.NetTimeout)) + if _, err := io.ReadFull(conn, buf); err != nil { + return errs.IO.WithErr(err) + } + conn.SetReadDeadline(time.Time{}) + + if err := rep.app.RecvState(conn); err != nil { + return err + } + + state := localState{} + 
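+	// Decode the primary's sequence number and timestamp from the fixed
+	// 512-byte header read before RecvState.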
state.readFrom(buf) + + return rep.setState(state) +} diff --git a/lib/rep/replicator-open.go b/lib/rep/replicator-open.go new file mode 100644 index 0000000..0eb6b04 --- /dev/null +++ b/lib/rep/replicator-open.go @@ -0,0 +1,184 @@ +package rep + +import ( + "git.crumpington.com/public/jldb/lib/atomicheader" + "git.crumpington.com/public/jldb/lib/errs" + "git.crumpington.com/public/jldb/lib/flock" + "git.crumpington.com/public/jldb/lib/wal" + "os" + "time" +) + +func (rep *Replicator) loadConfigDefaults() { + conf := rep.conf + + if conf.NetTimeout <= 0 { + conf.NetTimeout = time.Minute + } + if conf.WALSegMinCount <= 0 { + conf.WALSegMinCount = 1024 + } + if conf.WALSegMaxAgeSec <= 0 { + conf.WALSegMaxAgeSec = 3600 + } + if conf.WALSegGCAgeSec <= 0 { + conf.WALSegGCAgeSec = 7 * 86400 + } + + rep.conf = conf + + rep.pskBytes = make([]byte, 256) + copy(rep.pskBytes, []byte(conf.ReplicationPSK)) +} + +func (rep *Replicator) initDirectories() error { + if err := os.MkdirAll(walRootDir(rep.conf.RootDir), 0700); err != nil { + return errs.IO.WithErr(err) + } + return nil +} + +func (rep *Replicator) acquireLock() error { + lockFile, err := flock.TryLock(lockFilePath(rep.conf.RootDir)) + if err != nil { + return errs.IO.WithMsg("locked: %s", lockFilePath(rep.conf.RootDir)) + } + if lockFile == nil { + return errs.Locked + } + rep.lockFile = lockFile + return nil +} + +func (rep *Replicator) loadLocalState() error { + f, err := os.OpenFile(stateFilePath(rep.conf.RootDir), os.O_RDWR|os.O_CREATE, 0600) + if err != nil { + return errs.IO.WithErr(err) + } + + info, err := f.Stat() + if err != nil { + f.Close() + return errs.IO.WithErr(err) + } + + if info.Size() < atomicheader.ReservedBytes { + if err := atomicheader.Init(f); err != nil { + f.Close() + return errs.IO.WithErr(err) + } + } + + rep.stateHandler, err = atomicheader.Open(f) + if err != nil { + f.Close() + return err + } + + rep.stateFile = f + var state localState + + err = rep.stateHandler.Read(func(page []byte) error { + state.readFrom(page) + return nil + }) + if err == nil { + rep.state.Store(&state) + return nil + } + + // Write a clean state. + state = localState{} + rep.state.Store(&state) + return rep.stateHandler.Write(func(page []byte) error { + state.writeTo(page) + return nil + }) +} + +func (rep *Replicator) walConfig() wal.Config { + return wal.Config{ + SegMinCount: rep.conf.WALSegMinCount, + SegMaxAgeSec: rep.conf.WALSegMaxAgeSec, + } +} + +func (rep *Replicator) openWAL() (err error) { + rep.wal, err = wal.Open(walRootDir(rep.conf.RootDir), rep.walConfig()) + if err != nil { + rep.wal, err = wal.Create(walRootDir(rep.conf.RootDir), 1, rep.walConfig()) + if err != nil { + return err + } + } + + return nil +} + +func (rep *Replicator) recvStateIfNecessary() error { + if rep.conf.Primary { + return nil + } + + sInfo := rep.Info() + pInfo, err := rep.client.GetInfo() + if err != nil { + return err + } + + if pInfo.WALFirstSeqNum <= sInfo.WALLastSeqNum { + return nil + } + + // Make a new WAL. + rep.wal.Close() + + if err = rep.client.RecvState(rep.recvState); err != nil { + return err + } + + state := rep.getState() + + rep.wal, err = wal.Create(walRootDir(rep.conf.RootDir), state.SeqNum+1, rep.walConfig()) + return err +} + +// Replays un-acked entries in the WAL. Acks after all records are replayed. 
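+// Replay starts at the sequence number after the last acked local state
+// (state.SeqNum+1) and feeds each record to app.Replay. After a crash this
+// may include records that were already applied, which is why App.Replay is
+// required to be idempotent.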
+func (rep *Replicator) replay() error { + state := rep.getState() + it, err := rep.wal.Iterator(state.SeqNum + 1) + if err != nil { + return err + } + defer it.Close() + + for it.Next(0) { + rec := it.Record() + if err := rep.app.Replay(rec); err != nil { + return err + } + state.SeqNum = rec.SeqNum + state.TimestampMS = rec.TimestampMS + } + + if it.Error() != nil { + return it.Error() + } + + return rep.ack(state.SeqNum, state.TimestampMS) +} + +func (rep *Replicator) startWALGC() { + rep.done.Add(1) + go rep.runWALGC() +} + +func (rep *Replicator) startWALFollower() { + rep.done.Add(1) + go rep.runWALFollower() +} + +func (rep *Replicator) startWALRecvr() { + rep.done.Add(1) + go rep.runWALRecvr() +} diff --git a/lib/rep/replicator-walfollower.go b/lib/rep/replicator-walfollower.go new file mode 100644 index 0000000..669ea90 --- /dev/null +++ b/lib/rep/replicator-walfollower.go @@ -0,0 +1,66 @@ +package rep + +import ( + "log" + "time" +) + +func (rep *Replicator) runWALFollower() { + defer rep.done.Done() + + for { + rep.followOnce() + + select { + case <-rep.stop: + return + default: + time.Sleep(time.Second) + } + } +} + +func (rep *Replicator) followOnce() { + logf := func(pattern string, args ...any) { + log.Printf("[WAL-FOLLOWER] "+pattern, args...) + } + + state := rep.getState() + it, err := rep.wal.Iterator(state.SeqNum + 1) + if err != nil { + logf("Failed to create WAL iterator: %v", err) + return + } + defer it.Close() + + for { + select { + case <-rep.stop: + logf("Stopped") + return + default: + } + + if it.Next(time.Second) { + rec := it.Record() + + if err := rep.app.Apply(rec); err != nil { + logf("App failed to apply change: %v", err) + return + } + + if err := rep.ack(rec.SeqNum, rec.TimestampMS); err != nil { + logf("App failed to update local state: %v", err) + return + } + + select { + case rep.appendNotify <- struct{}{}: + default: + } + } else if it.Error() != nil { + logf("Iteration error: %v", err) + return + } + } +} diff --git a/lib/rep/replicator-walgc.go b/lib/rep/replicator-walgc.go new file mode 100644 index 0000000..2f51da8 --- /dev/null +++ b/lib/rep/replicator-walgc.go @@ -0,0 +1,28 @@ +package rep + +import ( + "log" + "time" +) + +func (rep *Replicator) runWALGC() { + defer rep.done.Done() + + ticker := time.NewTicker(time.Minute) + defer ticker.Stop() + + for { + select { + case <-ticker.C: + state := rep.getState() + before := time.Now().Unix() - rep.conf.WALSegMaxAgeSec + if err := rep.wal.DeleteBefore(before, state.SeqNum); err != nil { + log.Printf("[WAL-GC] failed to delete wal segments: %v", err) + } + // OK + case <-rep.stop: + log.Print("[WAL-GC] Stopped") + return + } + } +} diff --git a/lib/rep/replicator-walrecvr.go b/lib/rep/replicator-walrecvr.go new file mode 100644 index 0000000..a5b5b05 --- /dev/null +++ b/lib/rep/replicator-walrecvr.go @@ -0,0 +1,38 @@ +package rep + +import ( + "log" + "time" +) + +func (rep *Replicator) runWALRecvr() { + go func() { + <-rep.stop + rep.client.Close() + }() + + defer rep.done.Done() + + for { + rep.runWALRecvrOnce() + select { + case <-rep.stop: + log.Print("[WAL-RECVR] Stopped") + return + default: + time.Sleep(time.Second) + } + } +} + +func (rep *Replicator) runWALRecvrOnce() { + logf := func(pattern string, args ...any) { + log.Printf("[WAL-RECVR] "+pattern, args...) 
+ } + + if err := rep.client.StreamWAL(rep.wal); err != nil { + if !rep.stopped() { + logf("Recv failed: %v", err) + } + } +} diff --git a/lib/rep/replicator.go b/lib/rep/replicator.go new file mode 100644 index 0000000..2e4ce6e --- /dev/null +++ b/lib/rep/replicator.go @@ -0,0 +1,235 @@ +package rep + +import ( + "io" + "git.crumpington.com/public/jldb/lib/atomicheader" + "git.crumpington.com/public/jldb/lib/errs" + "git.crumpington.com/public/jldb/lib/wal" + "net" + "os" + "sync" + "sync/atomic" + "time" +) + +type Config struct { + RootDir string + Primary bool + ReplicationPSK string + NetTimeout time.Duration // Default is 1 minute. + + // WAL settings. + WALSegMinCount int64 // Minimum Change sets in a segment. Default is 1024. + WALSegMaxAgeSec int64 // Maximum age of a segment. Default is 1 hour. + WALSegGCAgeSec int64 // Segment age for garbage collection. Default is 7 days. + + // If true, Append won't return until a successful App.Apply. + SynchronousAppend bool + + // Necessary for secondary. + PrimaryEndpoint string +} + +type App struct { + // SendState: The primary may need to send storage state to a secondary node. + SendState func(conn net.Conn) error + + // (1) RecvState: Secondary nodes may need to load state from the primary if the + // WAL is too far behind. + RecvState func(conn net.Conn) error + + // (2) InitStorage: Prepare application storage for possible calls to + // Replay. + InitStorage func() error + + // (3) Replay: write the change to storage. Replay must be idempotent. + Replay func(rec wal.Record) error + + // (4) LoadFromStorage: load the application's state from it's persistent + // storage. + LoadFromStorage func() error + + // (5) Apply: write the change to persistent storage. Apply must be + // idempotent. In normal operation each change is applied exactly once. + Apply func(rec wal.Record) error +} + +type Replicator struct { + app App + conf Config + + lockFile *os.File + pskBytes []byte + wal *wal.WAL + + appendNotify chan struct{} + + // lock protects state. The lock is held when replaying (R), following (R), + // and sending state (W). + stateFile *os.File + state *atomic.Pointer[localState] + stateHandler *atomicheader.Handler + + stop chan struct{} + done *sync.WaitGroup + + client *client // For secondary connection to primary. 
+} + +func Open(app App, conf Config) (*Replicator, error) { + rep := &Replicator{ + app: app, + conf: conf, + state: &atomic.Pointer[localState]{}, + stop: make(chan struct{}), + done: &sync.WaitGroup{}, + appendNotify: make(chan struct{}, 1), + } + + rep.loadConfigDefaults() + + rep.state.Store(&localState{}) + rep.client = newClient(rep.conf.PrimaryEndpoint, rep.conf.ReplicationPSK, rep.conf.NetTimeout) + + if err := rep.initDirectories(); err != nil { + return nil, err + } + + if err := rep.acquireLock(); err != nil { + rep.Close() + return nil, err + } + + if err := rep.loadLocalState(); err != nil { + rep.Close() + return nil, err + } + + if err := rep.openWAL(); err != nil { + rep.Close() + return nil, err + } + + if err := rep.recvStateIfNecessary(); err != nil { + rep.Close() + return nil, err + } + + if err := rep.app.InitStorage(); err != nil { + rep.Close() + return nil, err + } + + if err := rep.replay(); err != nil { + rep.Close() + return nil, err + } + + if err := rep.app.LoadFromStorage(); err != nil { + rep.Close() + return nil, err + } + + rep.startWALGC() + rep.startWALFollower() + + if !rep.conf.Primary { + rep.startWALRecvr() + } + + return rep, nil +} + +func (rep *Replicator) Append(size int64, r io.Reader) (int64, int64, error) { + if !rep.conf.Primary { + return 0, 0, errs.NotAllowed.WithMsg("cannot write to secondary") + } + + seqNum, timestampMS, err := rep.wal.Append(size, r) + if err != nil { + return 0, 0, err + } + if !rep.conf.SynchronousAppend { + return seqNum, timestampMS, nil + } + + <-rep.appendNotify + return seqNum, timestampMS, nil +} + +func (rep *Replicator) Primary() bool { + return rep.conf.Primary +} + +// TODO: Probably remove this. +// The caller may call Ack after Apply to acknowledge that the change has also +// been applied to the caller's application. Alternatively, the caller may use +// follow to apply changes to their application state. 
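+// ack persists the latest applied sequence number and timestamp to the state
+// file and updates the in-memory copy, so replay and the WAL follower resume
+// from the correct position after a restart.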
+func (rep *Replicator) ack(seqNum, timestampMS int64) error { + state := rep.getState() + state.SeqNum = seqNum + state.TimestampMS = timestampMS + return rep.setState(state) +} + +func (rep *Replicator) getState() localState { + return *rep.state.Load() +} + +func (rep *Replicator) setState(state localState) error { + err := rep.stateHandler.Write(func(page []byte) error { + state.writeTo(page) + return nil + }) + if err != nil { + return err + } + + rep.state.Store(&state) + return nil +} + +func (rep *Replicator) Info() Info { + state := rep.getState() + walInfo := rep.wal.Info() + + return Info{ + AppSeqNum: state.SeqNum, + AppTimestampMS: state.TimestampMS, + WALFirstSeqNum: walInfo.FirstSeqNum, + WALLastSeqNum: walInfo.LastSeqNum, + WALLastTimestampMS: walInfo.LastTimestampMS, + } +} + +func (rep *Replicator) Close() error { + if rep.stopped() { + return nil + } + + close(rep.stop) + rep.done.Wait() + + if rep.lockFile != nil { + rep.lockFile.Close() + } + + if rep.wal != nil { + rep.wal.Close() + } + + if rep.client != nil { + rep.client.Close() + } + + return nil +} + +func (rep *Replicator) stopped() bool { + select { + case <-rep.stop: + return true + default: + return false + } +} diff --git a/lib/rep/testapp-harness_test.go b/lib/rep/testapp-harness_test.go new file mode 100644 index 0000000..bfcf4e0 --- /dev/null +++ b/lib/rep/testapp-harness_test.go @@ -0,0 +1,128 @@ +package rep + +import ( + "math/rand" + "net/http" + "net/http/httptest" + "path/filepath" + "reflect" + "strings" + "testing" + "time" +) + +func TestAppHarnessRun(t *testing.T) { + TestAppHarness{}.Run(t) +} + +type TestAppHarness struct { +} + +func (h TestAppHarness) Run(t *testing.T) { + val := reflect.ValueOf(h) + typ := val.Type() + for i := 0; i < typ.NumMethod(); i++ { + method := typ.Method(i) + + if !strings.HasPrefix(method.Name, "Test") { + continue + } + + t.Run(method.Name, func(t *testing.T) { + //t.Parallel() + rootDir := t.TempDir() + + app1 := newApp(t, rand.Int63(), Config{ + Primary: true, + RootDir: filepath.Join(rootDir, "app1"), + ReplicationPSK: "123", + WALSegMinCount: 1, + WALSegMaxAgeSec: 1, + WALSegGCAgeSec: 1, + }) + defer app1.Close() + + mux := http.NewServeMux() + mux.HandleFunc("/rep/", app1.rep.Handle) + testServer := httptest.NewServer(mux) + defer testServer.Close() + + app2 := newApp(t, rand.Int63(), Config{ + Primary: false, + RootDir: filepath.Join(rootDir, "app2"), + ReplicationPSK: "123", + PrimaryEndpoint: testServer.URL + "/rep/", + WALSegMinCount: 1, + WALSegMaxAgeSec: 1, + WALSegGCAgeSec: 1, + }) + + val.MethodByName(method.Name).Call([]reflect.Value{ + reflect.ValueOf(t), + reflect.ValueOf(app1), + reflect.ValueOf(app2), + }) + }) + } +} + +func (TestAppHarness) TestRandomUpdates(t *testing.T, app1, app2 *TestApp) { + go app1.UpdateRandomFor(4 * time.Second) + app2.WaitForEOF() + app1.AssertEqual(t, app2) +} + +/* +func (TestAppHarness) TestRandomUpdatesReplay(t *testing.T, app1, app2 *TestApp) { + app1.UpdateRandomFor(4 * time.Second) + app2.WaitForEOF() + + app1.Close() + app1 = newApp(t, app1.ID, app1.rep.conf) + + app1.AssertEqual(t, app2) + info := app1.rep.Info() + if info.AppSeqNum != 0 { + t.Fatal(info) + } +} + +func (TestAppHarness) TestRandomUpdatesAck(t *testing.T, app1, app2 *TestApp) { + go app1.UpdateRandomFor(4 * time.Second) + app2.WaitForEOF() + app1.AssertEqual(t, app2) + info := app1.rep.Info() + if info.AppSeqNum == 0 || info.AppSeqNum != info.WALLastSeqNum { + t.Fatal(info) + } +} + +func (TestAppHarness) TestWriteThenOpenFollower(t 
*testing.T, app1, app2 *TestApp) { + app2.Close() + app1.UpdateRandomFor(4 * time.Second) + + app2 = newApp(t, app2.ID, app2.rep.conf) + app2.WaitForEOF() + app1.AssertEqual(t, app2) +} + +func (TestAppHarness) TestUpdateOpenFollowerConcurrently(t *testing.T, app1, app2 *TestApp) { + app2.Close() + go app1.UpdateRandomFor(4 * time.Second) + time.Sleep(2 * time.Second) + app2 = newApp(t, app2.ID, app2.rep.conf) + app2.WaitForEOF() + app1.AssertEqual(t, app2) +} + +func (TestAppHarness) TestUpdateCloseOpenFollowerConcurrently(t *testing.T, app1, app2 *TestApp) { + go app1.UpdateRandomFor(4 * time.Second) + + time.Sleep(time.Second) + app2.Close() + time.Sleep(time.Second) + app2 = newApp(t, app2.ID, app2.rep.conf) + app2.WaitForEOF() + app1.AssertEqual(t, app2) +} +*/ diff --git a/lib/rep/testapp_test.go b/lib/rep/testapp_test.go new file mode 100644 index 0000000..6438e12 --- /dev/null +++ b/lib/rep/testapp_test.go @@ -0,0 +1,239 @@ +package rep + +import ( + "bytes" + "encoding/binary" + "encoding/json" + "io" + "git.crumpington.com/public/jldb/lib/wal" + "math/rand" + "net" + "sync" + "testing" + "time" +) + +// ---------------------------------------------------------------------------- + +type TestCmd struct { + Set int64 // 1 for set, 0 for delete + Key int64 + Val int64 +} + +func (c TestCmd) marshal() []byte { + b := make([]byte, 24) + binary.LittleEndian.PutUint64(b, uint64(c.Set)) + binary.LittleEndian.PutUint64(b[8:], uint64(c.Key)) + binary.LittleEndian.PutUint64(b[16:], uint64(c.Val)) + return b +} + +func (c *TestCmd) unmarshal(b []byte) { + c.Set = int64(binary.LittleEndian.Uint64(b)) + c.Key = int64(binary.LittleEndian.Uint64(b[8:])) + c.Val = int64(binary.LittleEndian.Uint64(b[16:])) +} + +func CmdFromRec(rec wal.Record) TestCmd { + cmd := TestCmd{} + + buf, err := io.ReadAll(rec.Reader) + if err != nil { + panic(err) + } + if len(buf) != 24 { + panic(len(buf)) + } + cmd.unmarshal(buf) + return cmd +} + +// ---------------------------------------------------------------------------- + +var storage = map[int64]map[int64]int64{} + +// ---------------------------------------------------------------------------- + +type TestApp struct { + ID int64 + storage map[int64]int64 + + rep *Replicator + + lock sync.Mutex + m map[int64]int64 +} + +func newApp(t *testing.T, id int64, conf Config) *TestApp { + t.Helper() + a := &TestApp{ + ID: id, + m: map[int64]int64{}, + } + + var err error + a.rep, err = Open(App{ + SendState: a.sendState, + RecvState: a.recvState, + InitStorage: a.initStorage, + Replay: a.replay, + LoadFromStorage: a.loadFromStorage, + Apply: a.apply, + }, conf) + if err != nil { + t.Fatal(err) + } + + return a +} + +func (a *TestApp) _set(k, v int64) { + a.lock.Lock() + defer a.lock.Unlock() + a.m[k] = v +} + +func (a *TestApp) _del(k int64) { + a.lock.Lock() + defer a.lock.Unlock() + delete(a.m, k) +} + +func (a *TestApp) Get(k int64) int64 { + a.lock.Lock() + defer a.lock.Unlock() + return a.m[k] +} + +func (app *TestApp) Close() { + app.rep.Close() +} + +func (app *TestApp) Set(k, v int64) error { + cmd := TestCmd{Set: 1, Key: k, Val: v} + if _, _, err := app.rep.Append(24, bytes.NewBuffer(cmd.marshal())); err != nil { + return err + } + app._set(k, v) + return nil +} + +func (app *TestApp) Del(k int64) error { + cmd := TestCmd{Set: 0, Key: k, Val: 0} + if _, _, err := app.rep.Append(24, bytes.NewBuffer(cmd.marshal())); err != nil { + return err + } + app._del(k) + return nil +} + +func (app *TestApp) UpdateRandomFor(dt time.Duration) { + tStart := time.Now() + for 
time.Since(tStart) < dt { + if rand.Float32() < 0.5 { + if err := app.Set(1+rand.Int63n(10), 1+rand.Int63n(10)); err != nil { + panic(err) + } + } else { + if err := app.Del(1 + rand.Int63n(10)); err != nil { + panic(err) + } + } + time.Sleep(time.Millisecond) + } + + app.Set(999, 999) +} + +func (app *TestApp) WaitForEOF() { + for app.Get(999) != 999 { + time.Sleep(time.Millisecond) + } +} + +func (app *TestApp) AssertEqual(t *testing.T, rhs *TestApp) { + app.lock.Lock() + defer app.lock.Unlock() + rhs.lock.Lock() + defer rhs.lock.Unlock() + + if len(app.m) != len(rhs.m) { + t.Fatal(len(app.m), len(rhs.m)) + } + + for k := range app.m { + if app.m[k] != rhs.m[k] { + t.Fatal(k, app.m[k], rhs.m[k]) + } + } +} + +// ---------------------------------------------------------------------------- + +func (app *TestApp) sendState(conn net.Conn) error { + app.lock.Lock() + b, _ := json.Marshal(app.m) + app.lock.Unlock() + + _, err := conn.Write(b) + return err +} + +func (app *TestApp) recvState(conn net.Conn) error { + m := map[int64]int64{} + if err := json.NewDecoder(conn).Decode(&m); err != nil { + return err + } + storage[app.ID] = m + return nil +} + +func (app *TestApp) initStorage() error { + if _, ok := storage[app.ID]; !ok { + storage[app.ID] = map[int64]int64{} + } + app.storage = storage[app.ID] + return nil +} + +func (app *TestApp) replay(rec wal.Record) error { + cmd := CmdFromRec(rec) + if cmd.Set != 0 { + app.storage[cmd.Key] = cmd.Val + } else { + delete(app.storage, cmd.Key) + } + return nil +} + +func (app *TestApp) loadFromStorage() error { + app.m = map[int64]int64{} + for k, v := range app.storage { + app.m[k] = v + } + return nil +} + +func (app *TestApp) apply(rec wal.Record) error { + cmd := CmdFromRec(rec) + if cmd.Set != 0 { + app.storage[cmd.Key] = cmd.Val + } else { + delete(app.storage, cmd.Key) + } + + // For primary, only update storage. + if app.rep.Primary() { + return nil + } + + // For secondary, update the map. + if cmd.Set != 0 { + app._set(cmd.Key, cmd.Val) + } else { + app._del(cmd.Key) + } + + return nil +} diff --git a/lib/testutil/limitwriter.go b/lib/testutil/limitwriter.go new file mode 100644 index 0000000..0537083 --- /dev/null +++ b/lib/testutil/limitwriter.go @@ -0,0 +1,33 @@ +package testutil + +import ( + "io" + "os" +) + +func NewLimitWriter(w io.Writer, limit int) *LimitWriter { + return &LimitWriter{ + w: w, + limit: limit, + } +} + +type LimitWriter struct { + w io.Writer + limit int + written int +} + +func (lw *LimitWriter) Write(buf []byte) (int, error) { + n, err := lw.w.Write(buf) + if err != nil { + return n, err + } + + lw.written += n + if lw.written > lw.limit { + return n, os.ErrClosed + } + + return n, nil +} diff --git a/lib/testutil/testconn.go b/lib/testutil/testconn.go new file mode 100644 index 0000000..f821a6d --- /dev/null +++ b/lib/testutil/testconn.go @@ -0,0 +1,79 @@ +package testutil + +import ( + "net" + "sync" + "time" +) + +type Network struct { + lock sync.Mutex + // Current connections. 
+ cConn net.Conn + sConn net.Conn + + acceptQ chan net.Conn +} + +func NewNetwork() *Network { + return &Network{ + acceptQ: make(chan net.Conn, 1), + } +} + +func (n *Network) Dial() net.Conn { + cc, sc := net.Pipe() + func() { + n.lock.Lock() + defer n.lock.Unlock() + if n.cConn != nil { + n.cConn.Close() + n.cConn = nil + } + select { + case n.acceptQ <- sc: + n.cConn = cc + default: + cc = nil + } + }() + return cc +} + +func (n *Network) Accept() net.Conn { + var sc net.Conn + select { + case sc = <-n.acceptQ: + case <-time.After(time.Second): + return nil + } + + func() { + n.lock.Lock() + defer n.lock.Unlock() + if n.sConn != nil { + n.sConn.Close() + n.sConn = nil + } + n.sConn = sc + }() + return sc +} + +func (n *Network) CloseClient() { + n.lock.Lock() + defer n.lock.Unlock() + if n.cConn != nil { + n.cConn.Close() + n.cConn = nil + } +} + +func (n *Network) CloseServer() { + n.lock.Lock() + defer n.lock.Unlock() + if n.sConn != nil { + n.sConn.Close() + n.sConn = nil + } +} diff --git a/lib/testutil/util.go b/lib/testutil/util.go new file mode 100644 index 0000000..26ce7c9 --- /dev/null +++ b/lib/testutil/util.go @@ -0,0 +1,10 @@ +package testutil + +import "testing" + +func AssertNotNil(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} diff --git a/lib/wal/corrupt_test.go b/lib/wal/corrupt_test.go new file mode 100644 index 0000000..82f5021 --- /dev/null +++ b/lib/wal/corrupt_test.go @@ -0,0 +1,53 @@ +package wal + +import ( + "io" + "git.crumpington.com/public/jldb/lib/errs" + "testing" +) + +func TestCorruptWAL(t *testing.T) { + t.Parallel() + + tmpDir := t.TempDir() + wal, err := Create(tmpDir, 100, Config{ + SegMinCount: 1024, + SegMaxAgeSec: 3600, + }) + if err != nil { + t.Fatal(err) + } + defer wal.Close() + + appendRandomRecords(t, wal, 100) + + f := wal.seg.f + info, err := f.Stat() + if err != nil { + t.Fatal(err) + } + offset := info.Size() / 2 + if _, err := f.WriteAt([]byte{1, 2, 3, 4, 5, 6, 7, 8}, offset); err != nil { + t.Fatal(err) + } + + it, err := wal.Iterator(-1) + if err != nil { + t.Fatal(err) + } + defer it.Close() + + for it.Next(0) { + rec := it.Record() + if _, err := io.ReadAll(rec.Reader); err != nil { + if errs.Corrupt.Is(err) { + return + } + t.Fatal(err) + } + } + + if !errs.Corrupt.Is(it.Error()) { + t.Fatal(it.Error()) + } +} diff --git a/lib/wal/design.go b/lib/wal/design.go new file mode 100644 index 0000000..01afeed --- /dev/null +++ b/lib/wal/design.go @@ -0,0 +1,28 @@ +package wal + +import ( + "time" +) + +type Info struct { + FirstSeqNum int64 + LastSeqNum int64 + LastTimestampMS int64 +} + +type Iterator interface { + // Next will return false if no record is available during the timeout + // period, or if an error is encountered. After Next returns false, the + // caller should check the return value of the Error function. + Next(timeout time.Duration) bool + + // Call Record after Next returns true to get the next record. + Record() Record + + // The caller must call Close on the iterator so clean-up can be performed. + Close() + + // Call Error to see if there was an error during the previous call to Next + // if Next returned false. 
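+	// A typical consumer loops with `for it.Next(timeout) { ... it.Record() ... }`
+	// and then inspects Error: a timeout leaves Error nil, while errs.Closed or
+	// errs.EOFArchived mean no further records will arrive from this iterator.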
+ Error() error +} diff --git a/lib/wal/gc_test.go b/lib/wal/gc_test.go new file mode 100644 index 0000000..3d438d2 --- /dev/null +++ b/lib/wal/gc_test.go @@ -0,0 +1,94 @@ +package wal + +import ( + "math/rand" + "sync" + "testing" + "time" +) + +func TestDeleteBefore(t *testing.T) { + t.Parallel() + firstSeqNum := rand.Int63n(9288389) + tmpDir := t.TempDir() + + wal, err := Create(tmpDir, firstSeqNum, Config{ + SegMinCount: 10, + SegMaxAgeSec: 1, + }) + if err != nil { + t.Fatal(err) + } + defer wal.Close() + + wg := sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + err := writeRandomWithEOF(wal, 8*time.Second) + if err != nil { + panic(err) + } + }() + + wg.Wait() + + info := wal.Info() + if info.FirstSeqNum != firstSeqNum { + t.Fatal(info) + } + + lastSeqNum := info.LastSeqNum + lastTimestampMS := info.LastTimestampMS + + err = wal.DeleteBefore((info.LastTimestampMS/1000)-4, lastSeqNum+100) + if err != nil { + t.Fatal(err) + } + + info = wal.Info() + if info.FirstSeqNum == firstSeqNum || info.LastSeqNum != lastSeqNum || info.LastTimestampMS != lastTimestampMS { + t.Fatal(info) + } + + header := wal.header + if header.FirstSegmentID >= header.LastSegmentID { + t.Fatal(header) + } +} + +func TestDeleteBeforeOnlyOneSegment(t *testing.T) { + t.Parallel() + firstSeqNum := rand.Int63n(9288389) + tmpDir := t.TempDir() + + wal, err := Create(tmpDir, firstSeqNum, Config{ + SegMinCount: 10, + SegMaxAgeSec: 10, + }) + if err != nil { + t.Fatal(err) + } + defer wal.Close() + + if err := writeRandomWithEOF(wal, time.Second); err != nil { + t.Fatal(err) + } + + header := wal.header + if header.FirstSegmentID != header.LastSegmentID { + t.Fatal(header) + } + + lastSeqNum := wal.Info().LastSeqNum + + err = wal.DeleteBefore(time.Now().Unix()+1, lastSeqNum+100) + if err != nil { + t.Fatal(err) + } + + header = wal.header + if header.FirstSegmentID != header.LastSegmentID { + t.Fatal(header) + } +} diff --git a/lib/wal/generic_test.go b/lib/wal/generic_test.go new file mode 100644 index 0000000..4ee6ba2 --- /dev/null +++ b/lib/wal/generic_test.go @@ -0,0 +1,391 @@ +package wal + +import ( + "bytes" + "encoding/binary" + "errors" + "io" + "git.crumpington.com/public/jldb/lib/errs" + "math/rand" + "path/filepath" + "reflect" + "strings" + "testing" + "time" +) + +type waLog interface { + Append(int64, io.Reader) (int64, int64, error) + appendRecord(Record) (int64, int64, error) + Iterator(int64) (Iterator, error) + Close() error +} + +func TestGenericTestHarness_segment(t *testing.T) { + t.Parallel() + (&GenericTestHarness{ + New: func(tmpDir string, firstSeqNum int64) (waLog, error) { + l, err := createSegment(filepath.Join(tmpDir, "x"), 1, firstSeqNum, 12345) + return l, err + }, + }).Run(t) +} + +func TestGenericTestHarness_wal(t *testing.T) { + t.Parallel() + (&GenericTestHarness{ + New: func(tmpDir string, firstSeqNum int64) (waLog, error) { + l, err := Create(tmpDir, firstSeqNum, Config{ + SegMinCount: 1, + SegMaxAgeSec: 1, + }) + return l, err + }, + }).Run(t) +} + +// ---------------------------------------------------------------------------- + +type GenericTestHarness struct { + New func(tmpDir string, firstSeqNum int64) (waLog, error) +} + +func (h *GenericTestHarness) Run(t *testing.T) { + val := reflect.ValueOf(h) + typ := val.Type() + for i := 0; i < typ.NumMethod(); i++ { + method := typ.Method(i) + + if !strings.HasPrefix(method.Name, "Test") { + continue + } + + t.Run(method.Name, func(t *testing.T) { + t.Parallel() + firstSeqNum := rand.Int63n(23423) + + tmpDir := t.TempDir() 
+ + wal, err := h.New(tmpDir, firstSeqNum) + if err != nil { + t.Fatal(err) + } + defer wal.Close() + + val.MethodByName(method.Name).Call([]reflect.Value{ + reflect.ValueOf(t), + reflect.ValueOf(firstSeqNum), + reflect.ValueOf(wal), + }) + }) + } +} + +// ---------------------------------------------------------------------------- + +func (h *GenericTestHarness) TestBasic(t *testing.T, firstSeqNum int64, wal waLog) { + expected := appendRandomRecords(t, wal, 123) + + for i := 0; i < 123; i++ { + it, err := wal.Iterator(firstSeqNum + int64(i)) + if err != nil { + t.Fatal(err) + } + + checkIteratorMatches(t, it, expected[i:]) + + it.Close() + } +} + +func (h *GenericTestHarness) TestAppendNotFound(t *testing.T, firstSeqNum int64, wal waLog) { + recs := appendRandomRecords(t, wal, 123) + lastSeqNum := recs[len(recs)-1].SeqNum + + it, err := wal.Iterator(firstSeqNum) + if err != nil { + t.Fatal(err) + } + it.Close() + + it, err = wal.Iterator(lastSeqNum + 1) + if err != nil { + t.Fatal(err) + } + it.Close() + + if _, err = wal.Iterator(firstSeqNum - 1); !errs.NotFound.Is(err) { + t.Fatal(err) + } + + if _, err = wal.Iterator(lastSeqNum + 2); !errs.NotFound.Is(err) { + t.Fatal(err) + } +} + +func (h *GenericTestHarness) TestNextAfterClose(t *testing.T, firstSeqNum int64, wal waLog) { + appendRandomRecords(t, wal, 123) + + it, err := wal.Iterator(firstSeqNum) + if err != nil { + t.Fatal(err) + } + defer it.Close() + + if !it.Next(0) { + t.Fatal("Should be next") + } + + if err := wal.Close(); err != nil { + t.Fatal(err) + } + + if it.Next(0) { + t.Fatal("Shouldn't be next") + } + + if !errs.Closed.Is(it.Error()) { + t.Fatal(it.Error()) + } +} + +func (h *GenericTestHarness) TestNextTimeout(t *testing.T, firstSeqNum int64, wal waLog) { + recs := appendRandomRecords(t, wal, 123) + + it, err := wal.Iterator(firstSeqNum) + if err != nil { + t.Fatal(err) + } + defer it.Close() + + for range recs { + if !it.Next(0) { + t.Fatal("Expected next") + } + } + + if it.Next(time.Millisecond) { + t.Fatal("Unexpected next") + } +} + +func (h *GenericTestHarness) TestNextNotify(t *testing.T, firstSeqNum int64, wal waLog) { + it, err := wal.Iterator(firstSeqNum) + if err != nil { + t.Fatal(err) + } + defer it.Close() + + recsC := make(chan []RawRecord, 1) + + go func() { + time.Sleep(time.Second) + recsC <- appendRandomRecords(t, wal, 1) + }() + + if !it.Next(time.Hour) { + t.Fatal("expected next") + } + + recs := <-recsC + rec := it.Record() + if rec.SeqNum != recs[0].SeqNum { + t.Fatal(rec) + } +} + +func (h *GenericTestHarness) TestNextArchived(t *testing.T, firstSeqNum int64, wal waLog) { + type archiver interface { + Archive() error + } + + arch, ok := wal.(archiver) + if !ok { + return + } + + recs := appendRandomRecords(t, wal, 10) + + it, err := wal.Iterator(firstSeqNum) + if err != nil { + t.Fatal(err) + } + defer it.Close() + + if err := arch.Archive(); err != nil { + t.Fatal(err) + } + + for i, expected := range recs { + if !it.Next(time.Millisecond) { + t.Fatal(i, "no next") + } + + rec := it.Record() + if rec.SeqNum != expected.SeqNum { + t.Fatal(rec, expected) + } + } + + if it.Next(time.Minute) { + t.Fatal("unexpected next") + } + + if !errs.EOFArchived.Is(it.Error()) { + t.Fatal(it.Error()) + } +} + +func (h *GenericTestHarness) TestWriteReadConcurrent(t *testing.T, firstSeqNum int64, wal waLog) { + N := 1200 + + writeErr := make(chan error, 1) + + dataSize := int64(4) + makeData := func(i int) []byte { + data := make([]byte, 4) + binary.LittleEndian.PutUint32(data, uint32(i)) + return data + } + 
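+	// Writer: append N small records, one per millisecond, reporting any error
+	// on writeErr; the main goroutine concurrently iterates and checks that
+	// each record's payload matches makeData(i).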
+ go func() { + for i := 0; i < N; i++ { + + seqNum, _, err := wal.Append(dataSize, bytes.NewBuffer(makeData(i))) + if err != nil { + writeErr <- err + return + } + + if seqNum != int64(i)+firstSeqNum { + writeErr <- errors.New("Incorrect seq num") + return + } + + time.Sleep(time.Millisecond) + } + + writeErr <- nil + }() + + it, err := wal.Iterator(firstSeqNum) + if err != nil { + t.Fatal(err) + } + defer it.Close() + + for i := 0; i < N; i++ { + if !it.Next(time.Minute) { + t.Fatal("expected next", i, it.Error(), it.Record()) + } + + expectedData := makeData(i) + rec := it.Record() + + data, err := io.ReadAll(rec.Reader) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(data, expectedData) { + t.Fatal(data, expectedData) + } + } + + if err := <-writeErr; err != nil { + t.Fatal(err) + } +} + +func (h *GenericTestHarness) TestAppendAfterClose(t *testing.T, firstSeqNum int64, wal waLog) { + if _, _, err := wal.Append(4, bytes.NewBuffer([]byte{1, 2, 3, 4})); err != nil { + t.Fatal(err) + } + + wal.Close() + + _, _, err := wal.Append(4, bytes.NewBuffer([]byte{1, 2, 3, 4})) + if !errs.Closed.Is(err) { + t.Fatal(err) + } +} + +func (h *GenericTestHarness) TestIterateNegativeOne(t *testing.T, firstSeqNum int64, wal waLog) { + recs := appendRandomRecords(t, wal, 10) + + it1, err := wal.Iterator(firstSeqNum) + if err != nil { + t.Fatal(err) + } + defer it1.Close() + + it2, err := wal.Iterator(-1) + if err != nil { + t.Fatal(err) + } + defer it2.Close() + + if !it1.Next(0) { + t.Fatal(0) + } + if !it2.Next(0) { + t.Fatal(0) + } + + r1 := it1.Record() + r2 := it2.Record() + + if r1.SeqNum != r2.SeqNum || r1.SeqNum != firstSeqNum || r1.SeqNum != recs[0].SeqNum { + t.Fatal(r1.SeqNum, r2.SeqNum, firstSeqNum, recs[0].SeqNum) + } +} + +func (h *GenericTestHarness) TestIteratorAfterClose(t *testing.T, firstSeqNum int64, wal waLog) { + appendRandomRecords(t, wal, 10) + wal.Close() + + if _, err := wal.Iterator(-1); !errs.Closed.Is(err) { + t.Fatal(err) + } +} + +func (h *GenericTestHarness) TestIteratorNextWithError(t *testing.T, firstSeqNum int64, wal waLog) { + appendRandomRecords(t, wal, 10) + it, err := wal.Iterator(-1) + if err != nil { + t.Fatal(err) + } + + wal.Close() + + it.Next(0) + if !errs.Closed.Is(it.Error()) { + t.Fatal(it.Error()) + } + + it.Next(0) + if !errs.Closed.Is(it.Error()) { + t.Fatal(it.Error()) + } +} + +func (h *GenericTestHarness) TestIteratorConcurrentClose(t *testing.T, firstSeqNum int64, wal waLog) { + it, err := wal.Iterator(-1) + if err != nil { + t.Fatal(err) + } + + go func() { + writeRandomWithEOF(wal, 3*time.Second) + wal.Close() + }() + + for it.Next(time.Hour) { + // Skip. + } + + // Error may be Closed or NotFound. 
+ if !errs.Closed.Is(it.Error()) && !errs.NotFound.Is(it.Error()) { + t.Fatal(it.Error()) + } +} diff --git a/lib/wal/io.go b/lib/wal/io.go new file mode 100644 index 0000000..b042973 --- /dev/null +++ b/lib/wal/io.go @@ -0,0 +1,125 @@ +package wal + +import ( + "encoding/binary" + "errors" + "hash/crc32" + "io" + "git.crumpington.com/public/jldb/lib/errs" +) + +func ioErrOrEOF(err error) error { + if err == nil { + return nil + } + if errors.Is(err, io.EOF) { + return err + } + return errs.IO.WithErr(err) + +} + +// ---------------------------------------------------------------------------- + +type readAtReader struct { + f io.ReaderAt + offset int64 +} + +func readerAtToReader(f io.ReaderAt, offset int64) io.Reader { + return &readAtReader{f: f, offset: offset} +} + +func (r *readAtReader) Read(b []byte) (int, error) { + n, err := r.f.ReadAt(b, r.offset) + r.offset += int64(n) + return n, ioErrOrEOF(err) +} + +// ---------------------------------------------------------------------------- + +type writeAtWriter struct { + w io.WriterAt + offset int64 +} + +func writerAtToWriter(w io.WriterAt, offset int64) io.Writer { + return &writeAtWriter{w: w, offset: offset} +} + +func (w *writeAtWriter) Write(b []byte) (int, error) { + n, err := w.w.WriteAt(b, w.offset) + w.offset += int64(n) + return n, ioErrOrEOF(err) +} + +// ---------------------------------------------------------------------------- + +type crcWriter struct { + w io.Writer + crc uint32 +} + +func newCRCWriter(w io.Writer) *crcWriter { + return &crcWriter{w: w} +} + +func (w *crcWriter) Write(b []byte) (int, error) { + n, err := w.w.Write(b) + w.crc = crc32.Update(w.crc, crc32.IEEETable, b[:n]) + return n, ioErrOrEOF(err) +} + +func (w *crcWriter) CRC() uint32 { + return w.crc +} + +// ---------------------------------------------------------------------------- + +type dataReader struct { + r io.Reader + remaining int64 + crc uint32 +} + +func newDataReader(r io.Reader, dataSize int64) *dataReader { + return &dataReader{r: r, remaining: dataSize} +} + +func (r *dataReader) Read(b []byte) (int, error) { + if r.remaining == 0 { + return 0, io.EOF + } + + if int64(len(b)) > r.remaining { + b = b[:r.remaining] + } + + n, err := r.r.Read(b) + r.crc = crc32.Update(r.crc, crc32.IEEETable, b[:n]) + r.remaining -= int64(n) + + if r.remaining == 0 { + if err := r.checkCRC(); err != nil { + return n, err + } + } + + if err != nil && !errors.Is(err, io.EOF) { + return n, errs.IO.WithErr(err) + } + + return n, nil +} + +func (r *dataReader) checkCRC() error { + buf := make([]byte, 4) + if _, err := r.r.Read(buf); err != nil { + return errs.Corrupt.WithErr(err) + } + crc := binary.LittleEndian.Uint32(buf) + if crc != r.crc { + return errs.Corrupt.WithMsg("crc mismatch") + } + return nil +} diff --git a/lib/wal/notify.go b/lib/wal/notify.go new file mode 100644 index 0000000..a68a1e3 --- /dev/null +++ b/lib/wal/notify.go @@ -0,0 +1,79 @@ +package wal + +import "sync" + +type segmentState struct { + Closed bool + Archived bool + FirstSeqNum int64 + LastSeqNum int64 +} + +func newSegmentState(closed bool, header segmentHeader) segmentState { + return segmentState{ + Closed: closed, + Archived: header.ArchivedAt != 0, + FirstSeqNum: header.FirstSeqNum, + LastSeqNum: header.LastSeqNum, + } +} + +type notifyMux struct { + lock sync.Mutex + nextID int64 + recvrs map[int64]chan segmentState +} + +type stateRecvr struct { + // Each recvr will always get the most recent sequence number on C. + // When the segment is closed, a -1 is sent. 
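+	// (Closure is signaled by a segmentState with Closed set, delivered on C
+	// like any other update.)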
+ C chan segmentState + Close func() +} + +func newNotifyMux() *notifyMux { + return ¬ifyMux{ + recvrs: map[int64]chan segmentState{}, + } +} + +func (m *notifyMux) NewRecvr(header segmentHeader) stateRecvr { + state := newSegmentState(false, header) + + m.lock.Lock() + defer m.lock.Unlock() + + m.nextID++ + + recvrID := m.nextID + + recvr := stateRecvr{ + C: make(chan segmentState, 1), + Close: func() { + m.lock.Lock() + defer m.lock.Unlock() + delete(m.recvrs, recvrID) + }, + } + + recvr.C <- state + m.recvrs[recvrID] = recvr.C + + return recvr +} + +func (m *notifyMux) Notify(closed bool, header segmentHeader) { + + state := newSegmentState(closed, header) + + m.lock.Lock() + defer m.lock.Unlock() + + for _, c := range m.recvrs { + select { + case c <- state: + case <-c: + c <- state + } + } +} diff --git a/lib/wal/record.go b/lib/wal/record.go new file mode 100644 index 0000000..baf41c9 --- /dev/null +++ b/lib/wal/record.go @@ -0,0 +1,90 @@ +package wal + +import ( + "encoding/binary" + "hash/crc32" + "io" + "git.crumpington.com/public/jldb/lib/errs" +) + +const recordHeaderSize = 28 + +type Record struct { + SeqNum int64 + TimestampMS int64 + DataSize int64 + Reader io.Reader +} + +func (rec Record) writeHeaderTo(w io.Writer) (int, error) { + buf := make([]byte, recordHeaderSize) + binary.LittleEndian.PutUint64(buf[0:], uint64(rec.SeqNum)) + binary.LittleEndian.PutUint64(buf[8:], uint64(rec.TimestampMS)) + binary.LittleEndian.PutUint64(buf[16:], uint64(rec.DataSize)) + crc := crc32.ChecksumIEEE(buf[:recordHeaderSize-4]) + binary.LittleEndian.PutUint32(buf[24:], crc) + + n, err := w.Write(buf) + if err != nil { + err = errs.IO.WithErr(err) + } + return n, err +} + +func (rec *Record) readHeaderFrom(r io.Reader) error { + buf := make([]byte, recordHeaderSize) + if _, err := io.ReadFull(r, buf); err != nil { + return errs.IO.WithErr(err) + } + + crc := crc32.ChecksumIEEE(buf[:recordHeaderSize-4]) + stored := binary.LittleEndian.Uint32(buf[recordHeaderSize-4:]) + if crc != stored { + return errs.Corrupt.WithMsg("checksum mismatch") + } + + rec.SeqNum = int64(binary.LittleEndian.Uint64(buf[0:])) + rec.TimestampMS = int64(binary.LittleEndian.Uint64(buf[8:])) + rec.DataSize = int64(binary.LittleEndian.Uint64(buf[16:])) + + return nil +} + +func (rec Record) serializedSize() int64 { + return recordHeaderSize + rec.DataSize + 4 // 4 for data CRC32. +} + +func (rec Record) writeTo(w io.Writer) (int64, error) { + nn, err := rec.writeHeaderTo(w) + if err != nil { + return int64(nn), err + } + + n := int64(nn) + + // Write the data. + crcW := newCRCWriter(w) + n2, err := io.CopyN(crcW, rec.Reader, rec.DataSize) + n += n2 + if err != nil { + return n, errs.IO.WithErr(err) + } + + // Write the data crc value. 
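+	// crcW accumulated the CRC32 of the data bytes during the CopyN above.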
+ err = binary.Write(w, binary.LittleEndian, crcW.CRC()) + if err != nil { + return n, errs.IO.WithErr(err) + } + n += 4 + + return n, nil +} + +func (rec *Record) readFrom(r io.Reader) error { + if err := rec.readHeaderFrom(r); err != nil { + return err + } + + rec.Reader = newDataReader(r, rec.DataSize) + return nil +} diff --git a/lib/wal/record_test.go b/lib/wal/record_test.go new file mode 100644 index 0000000..365411c --- /dev/null +++ b/lib/wal/record_test.go @@ -0,0 +1,171 @@ +package wal + +import ( + "bytes" + "io" + "git.crumpington.com/public/jldb/lib/errs" + "git.crumpington.com/public/jldb/lib/testutil" + "math/rand" + "testing" +) + +func NewRecordForTesting() Record { + data := randData() + return Record{ + SeqNum: rand.Int63(), + TimestampMS: rand.Int63(), + DataSize: int64(len(data)), + Reader: bytes.NewBuffer(data), + } +} + +func AssertRecordHeadersEqual(t *testing.T, r1, r2 Record) { + t.Helper() + eq := r1.SeqNum == r2.SeqNum && + r1.TimestampMS == r2.TimestampMS && + r1.DataSize == r2.DataSize + if !eq { + t.Fatal(r1, r2) + } +} + +func TestRecordWriteHeaderToReadHeaderFrom(t *testing.T) { + t.Parallel() + rec1 := NewRecordForTesting() + + b := &bytes.Buffer{} + n, err := rec1.writeHeaderTo(b) + if err != nil { + t.Fatal(err) + } + if n != recordHeaderSize { + t.Fatal(n) + } + + rec2 := Record{} + if err := rec2.readHeaderFrom(b); err != nil { + t.Fatal(err) + } + + AssertRecordHeadersEqual(t, rec1, rec2) +} + +func TestRecordWriteHeaderToEOF(t *testing.T) { + t.Parallel() + rec := NewRecordForTesting() + + for limit := 1; limit < recordHeaderSize; limit++ { + buf := &bytes.Buffer{} + w := testutil.NewLimitWriter(buf, limit) + + n, err := rec.writeHeaderTo(w) + if !errs.IO.Is(err) { + t.Fatal(limit, n, err) + } + } +} + +func TestRecordReadHeaderFromError(t *testing.T) { + t.Parallel() + rec := NewRecordForTesting() + + for limit := 1; limit < recordHeaderSize; limit++ { + b := &bytes.Buffer{} + if _, err := rec.writeHeaderTo(b); err != nil { + t.Fatal(err) + } + r := io.LimitReader(b, int64(limit)) + if err := rec.readFrom(r); !errs.IO.Is(err) { + t.Fatal(err) + } + } +} + +func TestRecordReadHeaderFromCorrupt(t *testing.T) { + t.Parallel() + rec := NewRecordForTesting() + + b := &bytes.Buffer{} + + for i := 0; i < recordHeaderSize; i++ { + if _, err := rec.writeHeaderTo(b); err != nil { + t.Fatal(err) + } + b.Bytes()[i]++ + if err := rec.readHeaderFrom(b); !errs.Corrupt.Is(err) { + t.Fatal(err) + } + } +} + +func TestRecordWriteToReadFrom(t *testing.T) { + t.Parallel() + r1 := NewRecordForTesting() + data := randData() + r1.Reader = bytes.NewBuffer(bytes.Clone(data)) + r1.DataSize = int64(len(data)) + + r2 := Record{} + + b := &bytes.Buffer{} + if _, err := r1.writeTo(b); err != nil { + t.Fatal(err) + } + + if err := r2.readFrom(b); err != nil { + t.Fatal(err) + } + + AssertRecordHeadersEqual(t, r1, r2) + + data2, err := io.ReadAll(r2.Reader) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(data, data2) { + t.Fatal(data, data2) + } +} + +func TestRecordReadFromCorrupt(t *testing.T) { + t.Parallel() + data := randData() + r1 := NewRecordForTesting() + + for i := 0; i < int(r1.serializedSize()); i++ { + r1.Reader = bytes.NewBuffer(data) + r1.DataSize = int64(len(data)) + + buf := &bytes.Buffer{} + r1.writeTo(buf) + buf.Bytes()[i]++ + + r2 := Record{} + if err := r2.readFrom(buf); err != nil { + if !errs.Corrupt.Is(err) { + t.Fatal(i, err) + } + continue // OK. 
+ } + + if _, err := io.ReadAll(r2.Reader); !errs.Corrupt.Is(err) { + t.Fatal(err) + } + } +} + +func TestRecordWriteToError(t *testing.T) { + t.Parallel() + data := randData() + r1 := NewRecordForTesting() + r1.Reader = bytes.NewBuffer(data) + r1.DataSize = int64(len(data)) + + for i := 0; i < int(r1.serializedSize()); i++ { + w := testutil.NewLimitWriter(&bytes.Buffer{}, i) + r1.Reader = bytes.NewBuffer(data) + if _, err := r1.writeTo(w); !errs.IO.Is(err) { + t.Fatal(err) + } + } +} diff --git a/lib/wal/segment-header.go b/lib/wal/segment-header.go new file mode 100644 index 0000000..f4d28a9 --- /dev/null +++ b/lib/wal/segment-header.go @@ -0,0 +1,44 @@ +package wal + +import "encoding/binary" + +type segmentHeader struct { + CreatedAt int64 + ArchivedAt int64 + FirstSeqNum int64 + LastSeqNum int64 // FirstSeqNum - 1 if empty. + LastTimestampMS int64 // 0 if empty. + InsertAt int64 +} + +func (h segmentHeader) WriteTo(b []byte) { + vals := []int64{ + h.CreatedAt, + h.ArchivedAt, + h.FirstSeqNum, + h.LastSeqNum, + h.LastTimestampMS, + h.InsertAt, + } + + for _, val := range vals { + binary.LittleEndian.PutUint64(b[0:8], uint64(val)) + b = b[8:] + } +} + +func (h *segmentHeader) ReadFrom(b []byte) { + ptrs := []*int64{ + &h.CreatedAt, + &h.ArchivedAt, + &h.FirstSeqNum, + &h.LastSeqNum, + &h.LastTimestampMS, + &h.InsertAt, + } + + for _, ptr := range ptrs { + *ptr = int64(binary.LittleEndian.Uint64(b[0:8])) + b = b[8:] + } +} diff --git a/lib/wal/segment-iterator.go b/lib/wal/segment-iterator.go new file mode 100644 index 0000000..6495f2b --- /dev/null +++ b/lib/wal/segment-iterator.go @@ -0,0 +1,165 @@ +package wal + +import ( + "git.crumpington.com/public/jldb/lib/atomicheader" + "git.crumpington.com/public/jldb/lib/errs" + "os" + "time" +) + +type segmentIterator struct { + f *os.File + + recvr stateRecvr + state segmentState + + offset int64 + err error + rec Record + + ticker *time.Ticker // Ticker if timeout has been set. + tickerC <-chan time.Time // Ticker channel if timeout has been set. +} + +func newSegmentIterator( + f *os.File, + fromSeqNum int64, + recvr stateRecvr, +) ( + Iterator, + error, +) { + it := &segmentIterator{ + f: f, + recvr: recvr, + state: <-recvr.C, + } + + if err := it.seekToSeqNum(fromSeqNum); err != nil { + it.Close() + return nil, err + } + + it.rec.SeqNum = fromSeqNum - 1 + + it.ticker = time.NewTicker(time.Second) + it.tickerC = it.ticker.C + + return it, nil +} + +func (it *segmentIterator) seekToSeqNum(fromSeqNum int64) error { + + state := it.state + + // Is the requested sequence number out-of-range? + if fromSeqNum < state.FirstSeqNum || fromSeqNum > state.LastSeqNum+1 { + return errs.NotFound.WithMsg("sequence number not in segment") + } + + // Seek to start. + it.offset = atomicheader.ReservedBytes + + // Seek to first seq num - we're already there. + if fromSeqNum == it.state.FirstSeqNum { + return nil + } + + for { + if err := it.readRecord(); err != nil { + return err + } + + it.offset += it.rec.serializedSize() + + if it.rec.SeqNum == fromSeqNum-1 { + return nil + } + } +} + +func (it *segmentIterator) Close() { + it.f.Close() + it.recvr.Close() +} + +// Next returns true if there's a record available to read via it.Record(). +// +// If Next returns false, the caller should check the error value with +// it.Error(). +func (it *segmentIterator) Next(timeout time.Duration) bool { + if it.err != nil { + return false + } + + // Get new state if available. 
+ select { + case it.state = <-it.recvr.C: + default: + } + + if it.state.Closed { + it.err = errs.Closed + return false + } + + if it.rec.SeqNum < it.state.LastSeqNum { + if it.err = it.readRecord(); it.err != nil { + return false + } + it.offset += it.rec.serializedSize() + return true + } + + if it.state.Archived { + it.err = errs.EOFArchived + return false + } + + if timeout <= 0 { + return false // Nothing to return. + } + + // Wait for new record, or timeout. + it.ticker.Reset(timeout) + + // Get new state if available. + select { + case it.state = <-it.recvr.C: + // OK + case <-it.tickerC: + return false // Timeout, no error. + } + + if it.state.Closed { + it.err = errs.Closed + return false + } + + if it.rec.SeqNum < it.state.LastSeqNum { + if it.err = it.readRecord(); it.err != nil { + return false + } + it.offset += it.rec.serializedSize() + return true + } + + if it.state.Archived { + it.err = errs.EOFArchived + return false + } + + return false +} + +func (it *segmentIterator) Record() Record { + return it.rec +} + +func (it *segmentIterator) Error() error { + return it.err +} + +func (it *segmentIterator) readRecord() error { + return it.rec.readFrom(readerAtToReader(it.f, it.offset)) +} diff --git a/lib/wal/segment.go b/lib/wal/segment.go new file mode 100644 index 0000000..c4bfaa4 --- /dev/null +++ b/lib/wal/segment.go @@ -0,0 +1,250 @@ +package wal + +import ( + "bufio" + "io" + "git.crumpington.com/public/jldb/lib/atomicheader" + "git.crumpington.com/public/jldb/lib/errs" + "os" + "sync" + "time" +) + +type segment struct { + ID int64 + + lock sync.Mutex + + closed bool + header segmentHeader + headWriter *atomicheader.Handler + f *os.File + notifyMux *notifyMux + + // For non-archived segments. + w *bufio.Writer +} + +func createSegment(path string, id, firstSeqNum, timestampMS int64) (*segment, error) { + f, err := os.Create(path) + if err != nil { + return nil, errs.IO.WithErr(err) + } + defer f.Close() + + if err := atomicheader.Init(f); err != nil { + return nil, err + } + + handler, err := atomicheader.Open(f) + if err != nil { + return nil, err + } + + header := segmentHeader{ + CreatedAt: time.Now().Unix(), + FirstSeqNum: firstSeqNum, + LastSeqNum: firstSeqNum - 1, + LastTimestampMS: timestampMS, + InsertAt: atomicheader.ReservedBytes, + } + + err = handler.Write(func(page []byte) error { + header.WriteTo(page) + return nil + }) + + if err != nil { + return nil, err + } + + return openSegment(path, id) +} + +func openSegment(path string, id int64) (*segment, error) { + f, err := os.OpenFile(path, os.O_RDWR, 0600) + if err != nil { + return nil, errs.IO.WithErr(err) + } + + handler, err := atomicheader.Open(f) + if err != nil { + f.Close() + return nil, err + } + + var header segmentHeader + err = handler.Read(func(page []byte) error { + header.ReadFrom(page) + return nil + }) + + if err != nil { + f.Close() + return nil, err + } + + if _, err := f.Seek(header.InsertAt, io.SeekStart); err != nil { + f.Close() + return nil, errs.IO.WithErr(err) + } + + seg := &segment{ + ID: id, + header: header, + headWriter: handler, + f: f, + notifyMux: newNotifyMux(), + } + + if header.ArchivedAt == 0 { + seg.w = bufio.NewWriterSize(f, 1024*1024) + } + + return seg, nil +} + +// Append appends the data from r to the log atomically. If an error is +// returned, the caller should check for errs.Fatal. If a fatal error occurs, +// the segment should no longer be used. 
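+//
+// On success Append returns the assigned sequence number and the record
+// timestamp in milliseconds. A minimal usage sketch, mirroring the tests:
+//
+//	data := []byte("payload")
+//	seqNum, tsMS, err := seg.Append(int64(len(data)), bytes.NewBuffer(data))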
+func (seg *segment) Append(dataSize int64, r io.Reader) (int64, int64, error) { + return seg.appendRecord(Record{ + SeqNum: -1, + TimestampMS: time.Now().UnixMilli(), + DataSize: dataSize, + Reader: r, + }) +} + +func (seg *segment) Header() segmentHeader { + seg.lock.Lock() + defer seg.lock.Unlock() + return seg.header +} + +// appendRecord appends a record in an atomic fashion. Do not use the segment +// after a fatal error. +func (seg *segment) appendRecord(rec Record) (int64, int64, error) { + seg.lock.Lock() + defer seg.lock.Unlock() + + header := seg.header // Copy. + + if seg.closed { + return 0, 0, errs.Closed + } + + if header.ArchivedAt != 0 { + return 0, 0, errs.Archived + } + + if rec.SeqNum == -1 { + rec.SeqNum = header.LastSeqNum + 1 + } else if rec.SeqNum != header.LastSeqNum+1 { + return 0, 0, errs.Unexpected.WithMsg( + "expected sequence number %d but got %d", + header.LastSeqNum+1, + rec.SeqNum) + } + + seg.w.Reset(writerAtToWriter(seg.f, header.InsertAt)) + + n, err := rec.writeTo(seg.w) + if err != nil { + return 0, 0, err + } + + if err := seg.w.Flush(); err != nil { + return 0, 0, ioErrOrEOF(err) + } + + // Write new header to sync. + header.LastSeqNum = rec.SeqNum + header.LastTimestampMS = rec.TimestampMS + header.InsertAt += n + + err = seg.headWriter.Write(func(page []byte) error { + header.WriteTo(page) + return nil + }) + if err != nil { + return 0, 0, err + } + + seg.header = header + seg.notifyMux.Notify(false, header) + + return rec.SeqNum, rec.TimestampMS, nil +} + +// ---------------------------------------------------------------------------- + +func (seg *segment) Archive() error { + seg.lock.Lock() + defer seg.lock.Unlock() + + header := seg.header // Copy + if header.ArchivedAt != 0 { + return nil + } + + header.ArchivedAt = time.Now().Unix() + err := seg.headWriter.Write(func(page []byte) error { + header.WriteTo(page) + return nil + }) + if err != nil { + return err + } + + seg.w = nil // We won't be writing any more. 
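+	// Archived segments are read-only: openSegment only attaches a buffered
+	// writer while ArchivedAt is still zero.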
+ + seg.header = header + seg.notifyMux.Notify(false, header) + return nil +} + +// ---------------------------------------------------------------------------- + +func (seg *segment) Iterator(fromSeqNum int64) (Iterator, error) { + seg.lock.Lock() + defer seg.lock.Unlock() + + if seg.closed { + return nil, errs.Closed + } + + f, err := os.Open(seg.f.Name()) + if err != nil { + return nil, errs.IO.WithErr(err) + } + + header := seg.header + if fromSeqNum == -1 { + fromSeqNum = header.FirstSeqNum + } + + return newSegmentIterator( + f, + fromSeqNum, + seg.notifyMux.NewRecvr(header)) +} + +// ---------------------------------------------------------------------------- + +func (seg *segment) Close() error { + seg.lock.Lock() + defer seg.lock.Unlock() + + if seg.closed { + return nil + } + + seg.closed = true + + header := seg.header + seg.notifyMux.Notify(true, header) + seg.f.Close() + + return nil +} diff --git a/lib/wal/segment_test.go b/lib/wal/segment_test.go new file mode 100644 index 0000000..ab922f7 --- /dev/null +++ b/lib/wal/segment_test.go @@ -0,0 +1,145 @@ +package wal + +import ( + "bytes" + crand "crypto/rand" + "io" + "git.crumpington.com/public/jldb/lib/atomicheader" + "git.crumpington.com/public/jldb/lib/errs" + "path/filepath" + "testing" + "time" +) + +func newSegmentForTesting(t *testing.T) *segment { + tmpDir := t.TempDir() + seg, err := createSegment(filepath.Join(tmpDir, "x"), 1, 100, 200) + if err != nil { + t.Fatal(err) + } + return seg +} + +func TestNewSegmentDirNotFound(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + p := filepath.Join(tmpDir, "notFound", "1245") + + if _, err := createSegment(p, 1, 1234, 5678); !errs.IO.Is(err) { + t.Fatal(err) + } +} + +func TestOpenSegmentNotFound(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + p := filepath.Join(tmpDir, "notFound") + + if _, err := openSegment(p, 1); !errs.IO.Is(err) { + t.Fatal(err) + } +} + +func TestOpenSegmentTruncatedFile(t *testing.T) { + t.Parallel() + seg := newSegmentForTesting(t) + + path := seg.f.Name() + if err := seg.f.Truncate(4); err != nil { + t.Fatal(err) + } + seg.Close() + + if _, err := openSegment(path, 1); !errs.IO.Is(err) { + t.Fatal(err) + } +} + +func TestOpenSegmentCorruptHeader(t *testing.T) { + t.Parallel() + seg := newSegmentForTesting(t) + + path := seg.f.Name() + buf := make([]byte, atomicheader.ReservedBytes) + crand.Read(buf) + + if _, err := seg.f.Seek(0, io.SeekStart); err != nil { + t.Fatal(err) + } + + if _, err := seg.f.Write(buf); err != nil { + t.Fatal(err) + } + + seg.Close() + + if _, err := openSegment(path, 1); !errs.Corrupt.Is(err) { + t.Fatal(err) + } +} + +func TestOpenSegmentCorruptHeader2(t *testing.T) { + t.Parallel() + seg := newSegmentForTesting(t) + + path := seg.f.Name() + buf := make([]byte, 1024) // 2 pages. 
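+	// Overwrite part of the reserved header region with random bytes so that
+	// openSegment below is expected to fail with errs.Corrupt.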
+ crand.Read(buf) + + if _, err := seg.f.Seek(1024, io.SeekStart); err != nil { + t.Fatal(err) + } + + if _, err := seg.f.Write(buf); err != nil { + t.Fatal(err) + } + + seg.Close() + + if _, err := openSegment(path, 1); !errs.Corrupt.Is(err) { + t.Fatal(err) + } +} + +func TestSegmentArchiveTwice(t *testing.T) { + t.Parallel() + seg := newSegmentForTesting(t) + + for i := 0; i < 2; i++ { + if err := seg.Archive(); err != nil { + t.Fatal(err) + } + } +} + +func TestSegmentAppendArchived(t *testing.T) { + t.Parallel() + seg := newSegmentForTesting(t) + + appendRandomRecords(t, seg, 8) + + if err := seg.Archive(); err != nil { + t.Fatal(err) + } + + _, _, err := seg.Append(4, bytes.NewBuffer([]byte{1, 2, 3, 4})) + if !errs.Archived.Is(err) { + t.Fatal(err) + } +} + +func TestSegmentAppendRecordInvalidSeqNum(t *testing.T) { + t.Parallel() + seg := newSegmentForTesting(t) + + appendRandomRecords(t, seg, 8) // 109 is next. + + _, _, err := seg.appendRecord(Record{ + SeqNum: 110, + TimestampMS: time.Now().UnixMilli(), + DataSize: 100, + }) + if !errs.Unexpected.Is(err) { + t.Fatal(err) + } +} diff --git a/lib/wal/test-util_test.go b/lib/wal/test-util_test.go new file mode 100644 index 0000000..0341d22 --- /dev/null +++ b/lib/wal/test-util_test.go @@ -0,0 +1,232 @@ +package wal + +import ( + "bytes" + crand "crypto/rand" + "encoding/base32" + "encoding/binary" + "hash/crc32" + "io" + "math/rand" + "os" + "reflect" + "testing" + "time" +) + +// ---------------------------------------------------------------------------- + +func randString() string { + size := 8 + rand.Intn(92) + buf := make([]byte, size) + if _, err := crand.Read(buf); err != nil { + panic(err) + } + return base32.StdEncoding.EncodeToString(buf) +} + +// ---------------------------------------------------------------------------- + +type RawRecord struct { + Record + Data []byte + DataCRC uint32 +} + +func (rr *RawRecord) ReadFrom(t *testing.T, f *os.File, offset int64) { + t.Helper() + + buf := make([]byte, recordHeaderSize) + if _, err := f.ReadAt(buf, offset); err != nil { + t.Fatal(err) + } + + if err := rr.Record.readHeaderFrom(readerAtToReader(f, offset)); err != nil { + t.Fatal(err) + } + + rr.Data = make([]byte, rr.DataSize+4) // For data and CRC32. + if _, err := f.ReadAt(rr.Data, offset+recordHeaderSize); err != nil { + t.Fatal(err) + } + + storedCRC := binary.LittleEndian.Uint32(rr.Data[rr.DataSize:]) + computedCRC := crc32.ChecksumIEEE(rr.Data[:rr.DataSize]) + + if storedCRC != computedCRC { + t.Fatal(storedCRC, computedCRC) + } + + rr.Data = rr.Data[:rr.DataSize] +} + +// ---------------------------------------------------------------------------- + +func appendRandomRecords(t *testing.T, w waLog, count int64) []RawRecord { + t.Helper() + + recs := make([]RawRecord, count) + + for i := range recs { + rec := RawRecord{ + Data: []byte(randString()), + } + rec.DataSize = int64(len(rec.Data)) + + seqNum, _, err := w.Append(int64(len(rec.Data)), bytes.NewBuffer(rec.Data)) + if err != nil { + t.Fatal(err) + } + + rec.SeqNum = seqNum + + recs[i] = rec + } + + // Check that sequence numbers are sequential. 
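+	// Append assigns sequence numbers itself (records are submitted with
+	// SeqNum == -1), so consecutive appends must yield consecutive numbers.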
+ seqNum := recs[0].SeqNum + for _, rec := range recs { + if rec.SeqNum != seqNum { + t.Fatal(seqNum, rec) + } + seqNum++ + } + + return recs +} + +func checkIteratorMatches(t *testing.T, it Iterator, recs []RawRecord) { + for i, expected := range recs { + if !it.Next(time.Millisecond) { + t.Fatal(i, "no next") + } + + rec := it.Record() + + if rec.SeqNum != expected.SeqNum { + t.Fatal(i, rec.SeqNum, expected.SeqNum) + } + + if rec.DataSize != expected.DataSize { + t.Fatal(i, rec.DataSize, expected.DataSize) + } + + if rec.TimestampMS == 0 { + t.Fatal(rec.TimestampMS) + } + + data := make([]byte, rec.DataSize) + if _, err := io.ReadFull(rec.Reader, data); err != nil { + t.Fatal(err) + } + + if !bytes.Equal(data, expected.Data) { + t.Fatalf("%d %s != %s", i, data, expected.Data) + } + } + + if it.Error() != nil { + t.Fatal(it.Error()) + } + + // Check that iterator is empty. + if it.Next(0) { + t.Fatal("extra", it.Record()) + } +} + +func randData() []byte { + data := make([]byte, 1+rand.Intn(128)) + crand.Read(data) + return data +} + +func writeRandomWithEOF(w waLog, dt time.Duration) error { + tStart := time.Now() + for time.Since(tStart) < dt { + data := randData() + _, _, err := w.Append(int64(len(data)), bytes.NewBuffer(data)) + if err != nil { + return err + } + time.Sleep(time.Millisecond) + } + + _, _, err := w.Append(3, bytes.NewBuffer([]byte("EOF"))) + return err +} + +func waitForEOF(t *testing.T, w *WAL) { + t.Helper() + + h := w.seg.Header() + it, err := w.Iterator(h.FirstSeqNum) + if err != nil { + t.Fatal(err) + } + defer it.Close() + + for it.Next(time.Hour) { + rec := it.Record() + buf := make([]byte, rec.DataSize) + if _, err := io.ReadFull(rec.Reader, buf); err != nil { + t.Fatal(err) + } + if bytes.Equal(buf, []byte("EOF")) { + return + } + } + + t.Fatal("waitForEOF", it.Error()) +} + +func checkWALsEqual(t *testing.T, w1, w2 *WAL) { + t.Helper() + + info1 := w1.Info() + info2 := w2.Info() + + if !reflect.DeepEqual(info1, info2) { + t.Fatal(info1, info2) + } + + it1, err := w1.Iterator(info1.FirstSeqNum) + if err != nil { + t.Fatal(err) + } + defer it1.Close() + + it2, err := w2.Iterator(info2.FirstSeqNum) + if err != nil { + t.Fatal(err) + } + defer it2.Close() + + for { + ok1 := it1.Next(time.Second) + ok2 := it2.Next(time.Second) + if ok1 != ok2 { + t.Fatal(ok1, ok2) + } + + if !ok1 { + return + } + + rec1 := it1.Record() + rec2 := it2.Record() + + data1, err := io.ReadAll(rec1.Reader) + if err != nil { + t.Fatal(err) + } + data2, err := io.ReadAll(rec2.Reader) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(data1, data2) { + t.Fatal(data1, data2) + } + } +} diff --git a/lib/wal/wal-header.go b/lib/wal/wal-header.go new file mode 100644 index 0000000..aa826d8 --- /dev/null +++ b/lib/wal/wal-header.go @@ -0,0 +1,25 @@ +package wal + +import "encoding/binary" + +type walHeader struct { + FirstSegmentID int64 + LastSegmentID int64 +} + +func (h walHeader) WriteTo(b []byte) { + vals := []int64{h.FirstSegmentID, h.LastSegmentID} + + for _, val := range vals { + binary.LittleEndian.PutUint64(b[0:8], uint64(val)) + b = b[8:] + } +} + +func (h *walHeader) ReadFrom(b []byte) { + ptrs := []*int64{&h.FirstSegmentID, &h.LastSegmentID} + for _, ptr := range ptrs { + *ptr = int64(binary.LittleEndian.Uint64(b[0:8])) + b = b[8:] + } +} diff --git a/lib/wal/wal-iterator.go b/lib/wal/wal-iterator.go new file mode 100644 index 0000000..5d0902d --- /dev/null +++ b/lib/wal/wal-iterator.go @@ -0,0 +1,88 @@ +package wal + +import ( + "git.crumpington.com/public/jldb/lib/errs" + 
"time" +) + +type walIterator struct { + // getSeg should return a segment given its ID, or return nil. + getSeg func(id int64) (*segment, error) + seg *segment // Our current segment. + it Iterator // Our current segment iterator. + seqNum int64 + err error +} + +func newWALIterator( + getSeg func(id int64) (*segment, error), + seg *segment, + fromSeqNum int64, +) ( + *walIterator, + error, +) { + segIter, err := seg.Iterator(fromSeqNum) + if err != nil { + return nil, err + } + + return &walIterator{ + getSeg: getSeg, + seg: seg, + it: segIter, + seqNum: fromSeqNum, + }, nil +} + +func (it *walIterator) Next(timeout time.Duration) bool { + if it.err != nil { + return false + } + + if it.it.Next(timeout) { + it.seqNum++ + return true + } + + it.err = it.it.Error() + if !errs.EOFArchived.Is(it.err) { + return false + } + + it.it.Close() + + id := it.seg.ID + 1 + it.seg, it.err = it.getSeg(id) + + if it.err != nil { + return false + } + + if it.seg == nil { + it.err = errs.NotFound // Could be not-found, or closed. + return false + } + + it.it, it.err = it.seg.Iterator(it.seqNum) + if it.err != nil { + return false + } + + return it.Next(timeout) +} + +func (it *walIterator) Record() Record { + return it.it.Record() +} + +func (it *walIterator) Error() error { + return it.err +} + +func (it *walIterator) Close() { + if it.it != nil { + it.it.Close() + } + it.it = nil +} diff --git a/lib/wal/wal-recv.go b/lib/wal/wal-recv.go new file mode 100644 index 0000000..60cf008 --- /dev/null +++ b/lib/wal/wal-recv.go @@ -0,0 +1,60 @@ +package wal + +import ( + "encoding/binary" + "io" + "git.crumpington.com/public/jldb/lib/errs" + "net" + "time" +) + +func (wal *WAL) Recv(conn net.Conn, timeout time.Duration) error { + defer conn.Close() + + var ( + rec Record + msgType = make([]byte, 1) + ) + + // Send sequence number. + seqNum := wal.Info().LastSeqNum + 1 + conn.SetWriteDeadline(time.Now().Add(timeout)) + if err := binary.Write(conn, binary.LittleEndian, seqNum); err != nil { + return errs.IO.WithErr(err) + } + conn.SetWriteDeadline(time.Time{}) + + for { + conn.SetReadDeadline(time.Now().Add(timeout)) + + if _, err := io.ReadFull(conn, msgType); err != nil { + return errs.IO.WithErr(err) + } + + switch msgType[0] { + + case msgTypeHeartbeat: + // Nothing to do. 
+ + case msgTypeError: + e := &errs.Error{} + if err := e.Read(conn); err != nil { + return err + } + + return e + + case msgTypeRec: + if err := rec.readFrom(conn); err != nil { + return err + } + + if _, _, err := wal.appendRecord(rec); err != nil { + return err + } + + default: + return errs.Unexpected.WithMsg("Unknown message type: %d", msgType[0]) + } + } +} diff --git a/lib/wal/wal-send.go b/lib/wal/wal-send.go new file mode 100644 index 0000000..7ab7e70 --- /dev/null +++ b/lib/wal/wal-send.go @@ -0,0 +1,73 @@ +package wal + +import ( + "encoding/binary" + "git.crumpington.com/public/jldb/lib/errs" + "net" + "time" +) + +const ( + msgTypeRec = 8 + msgTypeHeartbeat = 16 + msgTypeError = 32 +) + +func (wal *WAL) Send(conn net.Conn, timeout time.Duration) error { + defer conn.Close() + + var ( + seqNum int64 + heartbeatTimeout = timeout / 8 + ) + + conn.SetReadDeadline(time.Now().Add(timeout)) + if err := binary.Read(conn, binary.LittleEndian, &seqNum); err != nil { + return errs.IO.WithErr(err) + } + conn.SetReadDeadline(time.Time{}) + + it, err := wal.Iterator(seqNum) + if err != nil { + return err + } + defer it.Close() + + for { + if it.Next(heartbeatTimeout) { + rec := it.Record() + + conn.SetWriteDeadline(time.Now().Add(timeout)) + + if _, err := conn.Write([]byte{msgTypeRec}); err != nil { + return errs.IO.WithErr(err) + } + + if _, err := rec.writeTo(conn); err != nil { + return err + } + + continue + } + + if it.Error() != nil { + conn.SetWriteDeadline(time.Now().Add(timeout)) + if _, err := conn.Write([]byte{msgTypeError}); err != nil { + return errs.IO.WithErr(err) + } + + err, ok := it.Error().(*errs.Error) + if !ok { + err = errs.Unexpected.WithErr(err) + } + err.Write(conn) + // w.Flush() + return err + } + + conn.SetWriteDeadline(time.Now().Add(timeout)) + if _, err := conn.Write([]byte{msgTypeHeartbeat}); err != nil { + return errs.IO.WithErr(err) + } + } +} diff --git a/lib/wal/wal-sendrecv_test.go b/lib/wal/wal-sendrecv_test.go new file mode 100644 index 0000000..e4d2100 --- /dev/null +++ b/lib/wal/wal-sendrecv_test.go @@ -0,0 +1,271 @@ +package wal + +import ( + "git.crumpington.com/public/jldb/lib/errs" + "git.crumpington.com/public/jldb/lib/testutil" + "log" + "math/rand" + "reflect" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +func TestSendRecvHarness(t *testing.T) { + t.Parallel() + (&SendRecvTestHarness{}).Run(t) +} + +type SendRecvTestHarness struct{} + +func (h *SendRecvTestHarness) Run(t *testing.T) { + val := reflect.ValueOf(h) + typ := val.Type() + for i := 0; i < typ.NumMethod(); i++ { + method := typ.Method(i) + if !strings.HasPrefix(method.Name, "Test") { + continue + } + + t.Run(method.Name, func(t *testing.T) { + t.Parallel() + + pDir := t.TempDir() + sDir := t.TempDir() + + config := Config{ + SegMinCount: 8, + SegMaxAgeSec: 1, + } + + pWAL, err := Create(pDir, 1, config) + if err != nil { + t.Fatal(err) + } + defer pWAL.Close() + + sWAL, err := Create(sDir, 1, config) + if err != nil { + t.Fatal(err) + } + defer sWAL.Close() + + nw := testutil.NewNetwork() + defer func() { + nw.CloseServer() + nw.CloseClient() + }() + + val.MethodByName(method.Name).Call([]reflect.Value{ + reflect.ValueOf(t), + reflect.ValueOf(pWAL), + reflect.ValueOf(sWAL), + reflect.ValueOf(nw), + }) + }) + } +} + +func (h *SendRecvTestHarness) TestSimple( + t *testing.T, + pWAL *WAL, + sWAL *WAL, + nw *testutil.Network, +) { + wg := sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + if err := writeRandomWithEOF(pWAL, 5*time.Second); err != nil { + 
panic(err) + } + }() + + // Send in the background. + wg.Add(1) + go func() { + defer wg.Done() + conn := nw.Accept() + if err := pWAL.Send(conn, 8*time.Second); err != nil { + log.Printf("Send error: %v", err) + } + }() + + // Recv in the background. + wg.Add(1) + go func() { + defer wg.Done() + conn := nw.Dial() + if err := sWAL.Recv(conn, 8*time.Second); err != nil { + log.Printf("Recv error: %v", err) + } + }() + + waitForEOF(t, sWAL) + + nw.CloseServer() + nw.CloseClient() + wg.Wait() + + checkWALsEqual(t, pWAL, sWAL) +} + +func (h *SendRecvTestHarness) TestWriteThenRead( + t *testing.T, + pWAL *WAL, + sWAL *WAL, + nw *testutil.Network, +) { + wg := sync.WaitGroup{} + + if err := writeRandomWithEOF(pWAL, 2*time.Second); err != nil { + t.Fatal(err) + } + + // Send in the background. + wg.Add(1) + go func() { + defer wg.Done() + conn := nw.Accept() + if err := pWAL.Send(conn, 8*time.Second); err != nil { + log.Printf("Send error: %v", err) + } + }() + + // Recv in the background. + wg.Add(1) + go func() { + defer wg.Done() + conn := nw.Dial() + if err := sWAL.Recv(conn, 8*time.Second); err != nil { + log.Printf("Recv error: %v", err) + } + }() + + waitForEOF(t, sWAL) + + nw.CloseServer() + nw.CloseClient() + wg.Wait() + + checkWALsEqual(t, pWAL, sWAL) +} + +func (h *SendRecvTestHarness) TestNetworkFailures( + t *testing.T, + pWAL *WAL, + sWAL *WAL, + nw *testutil.Network, +) { + recvDone := &atomic.Bool{} + wg := sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + writeRandomWithEOF(pWAL, 10*time.Second) + }() + + // Send in the background. + wg.Add(1) + go func() { + defer wg.Done() + + for { + if recvDone.Load() { + return + } + if conn := nw.Accept(); conn != nil { + pWAL.Send(conn, 8*time.Second) + } + } + }() + + // Recv in the background. + wg.Add(1) + go func() { + defer wg.Done() + for !recvDone.Load() { + if conn := nw.Dial(); conn != nil { + sWAL.Recv(conn, 8*time.Second) + } + } + }() + + wg.Add(1) + failureCount := 0 + go func() { + defer wg.Done() + for { + if recvDone.Load() { + return + } + time.Sleep(time.Millisecond * time.Duration(rand.Intn(100))) + failureCount++ + if rand.Float64() < 0.5 { + nw.CloseClient() + } else { + nw.CloseServer() + } + } + }() + + waitForEOF(t, sWAL) + recvDone.Store(true) + wg.Wait() + + log.Printf("%d network failures.", failureCount) + + if failureCount < 10 { + t.Fatal("Expected more failures.") + } + + checkWALsEqual(t, pWAL, sWAL) +} + +func (h *SendRecvTestHarness) TestSenderClose( + t *testing.T, + pWAL *WAL, + sWAL *WAL, + nw *testutil.Network, +) { + wg := sync.WaitGroup{} + + wg.Add(1) + go func() { + defer wg.Done() + if err := writeRandomWithEOF(pWAL, 5*time.Second); !errs.Closed.Is(err) { + panic(err) + } + }() + + // Close primary after some time. + wg.Add(1) + go func() { + defer wg.Done() + time.Sleep(time.Second) + pWAL.Close() + }() + + // Send in the background. 
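+	// Once pWAL is closed, Send's iterator reports errs.Closed, which is
+	// forwarded to the receiver as an error message.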
+ wg.Add(1) + go func() { + defer wg.Done() + conn := nw.Accept() + if err := pWAL.Send(conn, 8*time.Second); err != nil { + log.Printf("Send error: %v", err) + } + }() + + conn := nw.Dial() + if err := sWAL.Recv(conn, 8*time.Second); !errs.Closed.Is(err) { + t.Fatal(err) + } + + nw.CloseServer() + nw.CloseClient() + wg.Wait() +} diff --git a/lib/wal/wal.go b/lib/wal/wal.go new file mode 100644 index 0000000..adeea7d --- /dev/null +++ b/lib/wal/wal.go @@ -0,0 +1,321 @@ +package wal + +import ( + "io" + "git.crumpington.com/public/jldb/lib/atomicheader" + "git.crumpington.com/public/jldb/lib/errs" + "os" + "path/filepath" + "strconv" + "sync" + "time" +) + +type Config struct { + SegMinCount int64 + SegMaxAgeSec int64 +} + +type WAL struct { + rootDir string + conf Config + + lock sync.Mutex // Protects the fields below. + + closed bool + header walHeader + headerWriter *atomicheader.Handler + f *os.File // WAL header. + segments map[int64]*segment // Used by the iterator. + seg *segment // Current segment. +} + +func Create(rootDir string, firstSeqNum int64, conf Config) (*WAL, error) { + w := &WAL{rootDir: rootDir, conf: conf} + + seg, err := createSegment(w.segmentPath(1), 1, firstSeqNum, 0) + if err != nil { + return nil, err + } + defer seg.Close() + + f, err := os.Create(w.headerPath()) + if err != nil { + return nil, errs.IO.WithErr(err) + } + defer f.Close() + + if err := atomicheader.Init(f); err != nil { + return nil, err + } + + handler, err := atomicheader.Open(f) + if err != nil { + return nil, err + } + + header := walHeader{ + FirstSegmentID: 1, + LastSegmentID: 1, + } + + err = handler.Write(func(page []byte) error { + header.WriteTo(page) + return nil + }) + + if err != nil { + return nil, err + } + + return Open(rootDir, conf) +} + +func Open(rootDir string, conf Config) (*WAL, error) { + w := &WAL{rootDir: rootDir, conf: conf} + + f, err := os.OpenFile(w.headerPath(), os.O_RDWR, 0600) + if err != nil { + return nil, errs.IO.WithErr(err) + } + + handler, err := atomicheader.Open(f) + if err != nil { + f.Close() + return nil, err + } + + var header walHeader + err = handler.Read(func(page []byte) error { + header.ReadFrom(page) + return nil + }) + + if err != nil { + f.Close() + return nil, err + } + + w.header = header + w.headerWriter = handler + w.f = f + w.segments = map[int64]*segment{} + + for segID := header.FirstSegmentID; segID < header.LastSegmentID+1; segID++ { + segID := segID + seg, err := openSegment(w.segmentPath(segID), segID) + if err != nil { + w.Close() + return nil, err + } + w.segments[segID] = seg + } + + w.seg = w.segments[header.LastSegmentID] + if err := w.grow(); err != nil { + w.Close() + return nil, err + } + + return w, nil +} + +func (w *WAL) Close() error { + w.lock.Lock() + defer w.lock.Unlock() + + if w.closed { + return nil + } + w.closed = true + + for _, seg := range w.segments { + seg.Close() + delete(w.segments, seg.ID) + } + + w.f.Close() + return nil +} + +func (w *WAL) Info() (info Info) { + w.lock.Lock() + defer w.lock.Unlock() + + h := w.header + + info.FirstSeqNum = w.segments[h.FirstSegmentID].Header().FirstSeqNum + + lastHeader := w.segments[h.LastSegmentID].Header() + info.LastSeqNum = lastHeader.LastSeqNum + info.LastTimestampMS = lastHeader.LastTimestampMS + + return +} + +func (w *WAL) Append(dataSize int64, r io.Reader) (int64, int64, error) { + return w.appendRecord(Record{ + SeqNum: -1, + TimestampMS: time.Now().UnixMilli(), + DataSize: dataSize, + Reader: r, + }) +} + +func (w *WAL) appendRecord(rec Record) (int64, 
int64, error) { + w.lock.Lock() + defer w.lock.Unlock() + + if w.closed { + return 0, 0, errs.Closed + } + + if err := w.grow(); err != nil { + return 0, 0, err + } + return w.seg.appendRecord(rec) +} + +func (w *WAL) Iterator(fromSeqNum int64) (Iterator, error) { + w.lock.Lock() + defer w.lock.Unlock() + + if w.closed { + return nil, errs.Closed + } + + header := w.header + var seg *segment + + getSeg := func(id int64) (*segment, error) { + w.lock.Lock() + defer w.lock.Unlock() + if w.closed { + return nil, errs.Closed + } + return w.segments[id], nil + } + + if fromSeqNum == -1 { + seg = w.segments[header.FirstSegmentID] + return newWALIterator(getSeg, seg, fromSeqNum) + } + + // Seek to the appropriate segment. + seg = w.segments[header.FirstSegmentID] + for seg != nil { + h := seg.Header() + if fromSeqNum >= h.FirstSeqNum && fromSeqNum <= h.LastSeqNum+1 { + return newWALIterator(getSeg, seg, fromSeqNum) + } + seg = w.segments[seg.ID+1] + } + + return nil, errs.NotFound +} + +func (w *WAL) DeleteBefore(timestamp, keepSeqNum int64) error { + for { + seg, err := w.removeSeg(timestamp, keepSeqNum) + if err != nil || seg == nil { + return err + } + + id := seg.ID + os.RemoveAll(w.segmentPath(id)) + seg.Close() + } +} + +func (w *WAL) removeSeg(timestamp, keepSeqNum int64) (*segment, error) { + w.lock.Lock() + defer w.lock.Unlock() + + header := w.header + if header.FirstSegmentID == header.LastSegmentID { + return nil, nil // Nothing to delete now. + } + + id := header.FirstSegmentID + seg := w.segments[id] + if seg == nil { + return nil, errs.Unexpected.WithMsg("segment %d not found", id) + } + + segHeader := seg.Header() + if seg == w.seg || segHeader.ArchivedAt > timestamp { + return nil, nil // Nothing to delete now. + } + + if segHeader.LastSeqNum >= keepSeqNum { + return nil, nil + } + + header.FirstSegmentID = id + 1 + err := w.headerWriter.Write(func(page []byte) error { + header.WriteTo(page) + return nil + }) + if err != nil { + return nil, err + } + + w.header = header + delete(w.segments, id) + + return seg, nil +} + +func (w *WAL) grow() error { + segHeader := w.seg.Header() + + if segHeader.ArchivedAt == 0 { + if (segHeader.LastSeqNum - segHeader.FirstSeqNum) < w.conf.SegMinCount { + return nil + } + if time.Now().Unix()-segHeader.CreatedAt < w.conf.SegMaxAgeSec { + return nil + } + } + + newSegID := w.seg.ID + 1 + firstSeqNum := segHeader.LastSeqNum + 1 + timestampMS := segHeader.LastTimestampMS + + newSeg, err := createSegment(w.segmentPath(newSegID), newSegID, firstSeqNum, timestampMS) + if err != nil { + return err + } + + walHeader := w.header + walHeader.LastSegmentID = newSegID + + err = w.headerWriter.Write(func(page []byte) error { + walHeader.WriteTo(page) + return nil + }) + + if err != nil { + newSeg.Close() + return err + } + + if err := w.seg.Archive(); err != nil { + newSeg.Close() + return err + } + + w.seg = newSeg + w.segments[newSeg.ID] = newSeg + w.header = walHeader + + return nil +} + +func (w *WAL) headerPath() string { + return filepath.Join(w.rootDir, "header") +} + +func (w *WAL) segmentPath(segID int64) string { + return filepath.Join(w.rootDir, "seg."+strconv.FormatInt(segID, 10)) +} diff --git a/mdb/change/binary.go b/mdb/change/binary.go new file mode 100644 index 0000000..727ce20 --- /dev/null +++ b/mdb/change/binary.go @@ -0,0 +1,25 @@ +package change + +import ( + "encoding/binary" + "io" + "git.crumpington.com/public/jldb/lib/errs" +) + +func writeBin(w io.Writer, data ...any) error { + for _, value := range data { + if err := binary.Write(w, 
binary.LittleEndian, value); err != nil { + return errs.IO.WithErr(err) + } + } + return nil +} + +func readBin(r io.Reader, ptrs ...any) error { + for _, ptr := range ptrs { + if err := binary.Read(r, binary.LittleEndian, ptr); err != nil { + return errs.IO.WithErr(err) + } + } + return nil +} diff --git a/mdb/change/change.go b/mdb/change/change.go new file mode 100644 index 0000000..ff8f07d --- /dev/null +++ b/mdb/change/change.go @@ -0,0 +1,98 @@ +package change + +import ( + "io" + "git.crumpington.com/public/jldb/lib/errs" +) + +// ---------------------------------------------------------------------------- +// Change +// ---------------------------------------------------------------------------- + +// The Change type encodes a change (store / delete) to be applied to a +// pagefile. +type Change struct { + CollectionID uint64 + ItemID uint64 + Store bool + Data []byte + + WritePageIDs []uint64 + ClearPageIDs []uint64 +} + +func (ch Change) writeTo(w io.Writer) error { + dataSize := int64(len(ch.Data)) + if !ch.Store { + dataSize = -1 + } + + err := writeBin(w, + ch.CollectionID, + ch.ItemID, + dataSize, + uint64(len(ch.WritePageIDs)), + uint64(len(ch.ClearPageIDs)), + ch.WritePageIDs, + ch.ClearPageIDs) + if err != nil { + return err + } + + if ch.Store { + if _, err := w.Write(ch.Data); err != nil { + return errs.IO.WithErr(err) + } + } + + return nil +} + +func (ch *Change) readFrom(r io.Reader) error { + var pageCount, clearCount uint64 + var dataSize int64 + + err := readBin(r, + &ch.CollectionID, + &ch.ItemID, + &dataSize, + &pageCount, + &clearCount) + if err != nil { + return err + } + + if uint64(cap(ch.WritePageIDs)) < pageCount { + ch.WritePageIDs = make([]uint64, pageCount) + } + ch.WritePageIDs = ch.WritePageIDs[:pageCount] + + if uint64(cap(ch.ClearPageIDs)) < clearCount { + ch.ClearPageIDs = make([]uint64, clearCount) + } + ch.ClearPageIDs = ch.ClearPageIDs[:clearCount] + + if err = readBin(r, ch.WritePageIDs); err != nil { + return err + } + + if err = readBin(r, ch.ClearPageIDs); err != nil { + return err + } + + ch.Store = dataSize != -1 + + if ch.Store { + if int64(cap(ch.Data)) < dataSize { + ch.Data = make([]byte, dataSize) + } + ch.Data = ch.Data[:dataSize] + if _, err := r.Read(ch.Data); err != nil { + return errs.IO.WithErr(err) + } + } else { + ch.Data = ch.Data[:0] + } + + return nil +} diff --git a/mdb/change/change_test.go b/mdb/change/change_test.go new file mode 100644 index 0000000..a64d34d --- /dev/null +++ b/mdb/change/change_test.go @@ -0,0 +1,67 @@ +package change + +import ( + "bytes" + "reflect" + "testing" +) + +func (lhs Change) AssertEqual(t *testing.T, rhs Change) { + if lhs.CollectionID != rhs.CollectionID { + t.Fatal(lhs.CollectionID, rhs.CollectionID) + } + if lhs.ItemID != rhs.ItemID { + t.Fatal(lhs.ItemID, rhs.ItemID) + } + if lhs.Store != rhs.Store { + t.Fatal(lhs.Store, rhs.Store) + } + + if len(lhs.Data) != len(rhs.Data) { + t.Fatal(len(lhs.Data), len(rhs.Data)) + } + + if len(lhs.Data) != 0 { + if !reflect.DeepEqual(lhs.Data, rhs.Data) { + t.Fatal(lhs.Data, rhs.Data) + } + } + + if len(lhs.WritePageIDs) != len(rhs.WritePageIDs) { + t.Fatal(len(lhs.WritePageIDs), len(rhs.WritePageIDs)) + } + + if len(lhs.WritePageIDs) != 0 { + if !reflect.DeepEqual(lhs.WritePageIDs, rhs.WritePageIDs) { + t.Fatal(lhs.WritePageIDs, rhs.WritePageIDs) + } + } + + if len(lhs.ClearPageIDs) != len(rhs.ClearPageIDs) { + t.Fatal(len(lhs.ClearPageIDs), len(rhs.ClearPageIDs)) + } + + if len(lhs.ClearPageIDs) != 0 { + if !reflect.DeepEqual(lhs.ClearPageIDs, 
rhs.ClearPageIDs) { + t.Fatal(lhs.ClearPageIDs, rhs.ClearPageIDs) + } + } +} + +func TestChangeWriteToReadFrom(t *testing.T) { + out := Change{} + + for i := 0; i < 100; i++ { + in := randChange() + buf := &bytes.Buffer{} + if err := in.writeTo(buf); err != nil { + t.Fatal(err) + } + + if err := out.readFrom(buf); err != nil { + t.Fatal(err) + } + + in.AssertEqual(t, out) + } +} diff --git a/mdb/change/encoding.go b/mdb/change/encoding.go new file mode 100644 index 0000000..80da71a --- /dev/null +++ b/mdb/change/encoding.go @@ -0,0 +1,35 @@ +package change + +import "io" + +func Write(changes []Change, w io.Writer) error { + count := uint64(len(changes)) + if err := writeBin(w, count); err != nil { + return err + } + for _, c := range changes { + if err := c.writeTo(w); err != nil { + return err + } + } + return nil +} + +func Read(changes []Change, r io.Reader) ([]Change, error) { + var count uint64 + if err := readBin(r, &count); err != nil { + return changes, err + } + + if uint64(len(changes)) < count { + changes = make([]Change, count) + } + changes = changes[:count] + + for i := range changes { + if err := changes[i].readFrom(r); err != nil { + return changes, err + } + } + return changes, nil +} diff --git a/mdb/change/encoding_test.go b/mdb/change/encoding_test.go new file mode 100644 index 0000000..e8273bd --- /dev/null +++ b/mdb/change/encoding_test.go @@ -0,0 +1,64 @@ +package change + +import ( + "bytes" + crand "crypto/rand" + "math/rand" + "testing" +) + +func randChange() Change { + c := Change{ + CollectionID: rand.Uint64(), + ItemID: rand.Uint64(), + Store: rand.Float32() < 0.5, + } + + if c.Store { + data := make([]byte, 1+rand.Intn(100)) + crand.Read(data) + c.Data = data + } + + c.WritePageIDs = make([]uint64, rand.Intn(10)) + for i := range c.WritePageIDs { + c.WritePageIDs[i] = rand.Uint64() + } + + c.ClearPageIDs = make([]uint64, rand.Intn(10)) + for i := range c.ClearPageIDs { + c.ClearPageIDs[i] = rand.Uint64() + } + + return c +} + +func randChangeSlice() []Change { + changes := make([]Change, 1+rand.Intn(10)) + for i := range changes { + changes[i] = randChange() + } + return changes +} + +func TestWriteRead(t *testing.T) { + in := randChangeSlice() + var out []Change + + buf := &bytes.Buffer{} + if err := Write(in, buf); err != nil { + t.Fatal(err) + } + + out, err := Read(out, buf) + if err != nil { + t.Fatal(err) + } + + if len(in) != len(out) { + t.Fatal(len(in), len(out)) + } + for i := range in { + in[i].AssertEqual(t, out[i]) + } +} diff --git a/mdb/collection-internal.go b/mdb/collection-internal.go new file mode 100644 index 0000000..8481e83 --- /dev/null +++ b/mdb/collection-internal.go @@ -0,0 +1,24 @@ +package mdb + +type collectionState[T any] struct { + Version uint64 + Indices []indexState[T] +} + +func (c *collectionState[T]) clone(version uint64) *collectionState[T] { + indices := make([]indexState[T], len(c.Indices)) + for i := range indices { + indices[i] = c.Indices[i].clone() + } + return &collectionState[T]{ + Version: version, + Indices: indices, + } +} + +// Add an index returning it's assigned ID. 
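+// Index IDs are positional: the returned ID is the new index's slot in
+// c.Indices.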
+func (c *collectionState[T]) addIndex(idx indexState[T]) uint64 { + id := uint64(len(c.Indices)) + c.Indices = append(c.Indices, idx) + return id +} diff --git a/mdb/collection.go b/mdb/collection.go new file mode 100644 index 0000000..cd4bf63 --- /dev/null +++ b/mdb/collection.go @@ -0,0 +1,347 @@ +package mdb + +import ( + "bytes" + "encoding/json" + "errors" + "hash/crc64" + "git.crumpington.com/public/jldb/lib/errs" + "unsafe" + + "github.com/google/btree" +) + +type Collection[T any] struct { + db *Database + name string + collectionID uint64 + + copy func(*T) *T + sanitize func(*T) + validate func(*T) error + + indices []Index[T] + uniqueIndices []Index[T] + + ByID Index[T] + + buf *bytes.Buffer +} + +type CollectionConfig[T any] struct { + Copy func(*T) *T + Sanitize func(*T) + Validate func(*T) error +} + +func NewCollection[T any](db *Database, name string, conf *CollectionConfig[T]) *Collection[T] { + if conf == nil { + conf = &CollectionConfig[T]{} + } + + if conf.Copy == nil { + conf.Copy = func(from *T) *T { + to := new(T) + *to = *from + return to + } + } + + if conf.Sanitize == nil { + conf.Sanitize = func(*T) {} + } + + if conf.Validate == nil { + conf.Validate = func(*T) error { + return nil + } + } + + c := &Collection[T]{ + db: db, + name: name, + collectionID: crc64.Checksum([]byte(name), crc64Table), + copy: conf.Copy, + sanitize: conf.Sanitize, + validate: conf.Validate, + indices: []Index[T]{}, + uniqueIndices: []Index[T]{}, + buf: &bytes.Buffer{}, + } + + db.addCollection(c.collectionID, c, &collectionState[T]{ + Indices: []indexState[T]{}, + }) + + c.ByID = c.addIndex(indexConfig[T]{ + Name: "ByID", + Unique: true, + Compare: func(lhs, rhs *T) int { + l := c.getID(lhs) + r := c.getID(rhs) + if l < r { + return -1 + } else if l > r { + return 1 + } + return 0 + }, + }) + + return c +} + +func (c Collection[T]) Name() string { + return c.name +} + +type indexConfig[T any] struct { + Name string + Unique bool + + // If an index isn't unique, an additional comparison by ID is added if + // two items are otherwise equal. + Compare func(lhs, rhs *T) int + + // If not nil, indicates if a given item should be in the index. + Include func(item *T) bool +} + +func (c Collection[T]) Get(tx *Snapshot, id uint64) (*T, bool) { + x := new(T) + c.setID(x, id) + return c.ByID.Get(tx, x) +} + +func (c Collection[T]) List(tx *Snapshot, ids []uint64, out []*T) []*T { + if len(ids) == 0 { + return out[:0] + } + + if cap(out) < len(ids) { + out = make([]*T, len(ids)) + } + out = out[:0] + + for _, id := range ids { + item, ok := c.Get(tx, id) + if ok { + out = append(out, item) + } + } + + return out +} + +// AddIndex: Add an index to the collection. 
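+// Compare must define a total order over items; for non-unique indices ties
+// are broken by item ID so equal keys still have a stable ordering.
+//
+// A hypothetical sketch (using the User type from the tests) of a unique
+// index over an Email field:
+//
+//	byEmail := c.addIndex(indexConfig[User]{
+//		Name:   "ByEmail",
+//		Unique: true,
+//		Compare: func(lhs, rhs *User) int {
+//			return strings.Compare(lhs.Email, rhs.Email)
+//		},
+//	})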
+func (c *Collection[T]) addIndex(conf indexConfig[T]) Index[T] { + + var less func(*T, *T) bool + + if conf.Unique { + less = func(lhs, rhs *T) bool { + return conf.Compare(lhs, rhs) == -1 + } + } else { + less = func(lhs, rhs *T) bool { + switch conf.Compare(lhs, rhs) { + case -1: + return true + case 1: + return false + default: + return c.getID(lhs) < c.getID(rhs) + } + } + } + + indexState := indexState[T]{ + BTree: btree.NewG(256, less), + } + + index := Index[T]{ + collectionID: c.collectionID, + name: conf.Name, + indexID: c.getState(c.db.Snapshot()).addIndex(indexState), + include: conf.Include, + copy: c.copy, + } + + c.indices = append(c.indices, index) + if conf.Unique { + c.uniqueIndices = append(c.uniqueIndices, index) + } + + return index +} + +func (c Collection[T]) Insert(tx *Snapshot, userItem *T) error { + if err := c.ensureMutable(tx); err != nil { + return err + } + + item := c.copy(userItem) + c.sanitize(item) + + if err := c.validate(item); err != nil { + return err + } + + for i := range c.uniqueIndices { + if c.uniqueIndices[i].insertConflict(tx, item) { + return ErrDuplicate.WithCollection(c.name).WithIndex(c.uniqueIndices[i].name) + } + } + + tx.store(c.collectionID, c.getID(item), item) + + for i := range c.indices { + c.indices[i].insert(tx, item) + } + + return nil +} + +func (c Collection[T]) Update(tx *Snapshot, userItem *T) error { + if err := c.ensureMutable(tx); err != nil { + return err + } + + item := c.copy(userItem) + c.sanitize(item) + + if err := c.validate(item); err != nil { + return err + } + + old, ok := c.ByID.get(tx, item) + if !ok { + return ErrNotFound + } + + for i := range c.uniqueIndices { + if c.uniqueIndices[i].updateConflict(tx, item) { + return ErrDuplicate.WithCollection(c.name).WithIndex(c.uniqueIndices[i].name) + } + } + + tx.store(c.collectionID, c.getID(item), item) + + for i := range c.indices { + c.indices[i].update(tx, old, item) + } + + return nil +} + +func (c Collection[T]) Upsert(tx *Snapshot, item *T) error { + err := c.Insert(tx, item) + if err == nil { + return nil + } + if errors.Is(err, ErrDuplicate) { + return c.Update(tx, item) + } + return err +} + +func (c Collection[T]) Delete(tx *Snapshot, itemID uint64) error { + if err := c.ensureMutable(tx); err != nil { + return err + } + + return c.deleteItem(tx, itemID) +} + +func (c Collection[T]) getByID(tx *Snapshot, itemID uint64) (*T, bool) { + x := new(T) + c.setID(x, itemID) + return c.ByID.get(tx, x) +} + +func (c Collection[T]) ensureMutable(tx *Snapshot) error { + if !tx.writable() { + return ErrReadOnly + } + + state := c.getState(tx) + if state.Version != tx.version { + tx.collections[c.collectionID] = state.clone(tx.version) + } + + return nil +} + +// For initial data loading. +func (c Collection[T]) insertItem(tx *Snapshot, itemID uint64, data []byte) error { + item := new(T) + if err := json.Unmarshal(data, item); err != nil { + return errs.Encoding.WithErr(err).WithCollection(c.name) + } + + // Check for insert conflict. + for _, index := range c.uniqueIndices { + if index.insertConflict(tx, item) { + return ErrDuplicate + } + } + + // Do the insert. 
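+	// ByID is part of c.indices, so this loop also populates the primary index.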
+ for _, index := range c.indices { + index.insert(tx, item) + } + + return nil +} + +func (c Collection[T]) deleteItem(tx *Snapshot, itemID uint64) error { + item, ok := c.getByID(tx, itemID) + if !ok { + return ErrNotFound + } + + tx.delete(c.collectionID, itemID) + + for i := range c.indices { + c.indices[i].delete(tx, item) + } + + return nil +} + +// upsertItem inserts or updates the item with itemID and the given serialized +// form. It's called by +func (c Collection[T]) upsertItem(tx *Snapshot, itemID uint64, data []byte) error { + item, ok := c.getByID(tx, itemID) + if ok { + tx.delete(c.collectionID, itemID) + + for i := range c.indices { + c.indices[i].delete(tx, item) + } + } + + item = new(T) + if err := json.Unmarshal(data, item); err != nil { + return errs.Encoding.WithErr(err).WithCollection(c.name) + } + + // Do the insert. + for _, index := range c.indices { + index.insert(tx, item) + } + + return nil +} + +func (c Collection[T]) getID(t *T) uint64 { + return *((*uint64)(unsafe.Pointer(t))) +} + +func (c Collection[T]) setID(t *T, id uint64) { + *((*uint64)(unsafe.Pointer(t))) = id +} + +func (c Collection[T]) getState(tx *Snapshot) *collectionState[T] { + return tx.collections[c.collectionID].(*collectionState[T]) +} diff --git a/mdb/crashconsistency_test.go b/mdb/crashconsistency_test.go new file mode 100644 index 0000000..ead50f8 --- /dev/null +++ b/mdb/crashconsistency_test.go @@ -0,0 +1,76 @@ +package mdb + +import ( + "git.crumpington.com/public/jldb/lib/errs" + "log" + "os" + "os/exec" + "testing" + "time" +) + +func TestCrashConsistency(t *testing.T) { + if testing.Short() { + t.Skip("Sipping test in short mode.") + } + + // Build the test binary. + err := exec.Command( + "go", "build", + "-o", "testing/crashconsistency/p", + "testing/crashconsistency/main.go").Run() + if err != nil { + t.Fatal(err) + } + + defer os.RemoveAll("testing/crashconsistency/p") + + rootDir := t.TempDir() + defer os.RemoveAll(rootDir) + + for i := 0; i < 32; i++ { + cmd := exec.Command("testing/crashconsistency/p", rootDir) + if err := cmd.Start(); err != nil { + t.Fatal(err) + } + + time.Sleep(time.Second / 2) + + for { + if err := cmd.Process.Kill(); err != nil { + log.Printf("Kill failed: %v", err) + time.Sleep(time.Second) + continue + } + break + } + + var ( + db DataDB + err error + ) + + for { + + db, err = OpenDataDB(rootDir) + if err == nil { + break + } + if errs.Locked.Is(err) { + log.Printf("Locked.") + time.Sleep(time.Second / 10) + continue + } + t.Fatal(err) + } + + tx := db.Snapshot() + computed := db.ComputeCRC(tx) + stored := db.ReadCRC(tx) + if computed != stored { + t.Fatal(stored, computed) + } + + db.Close() + } +} diff --git a/mdb/crc.go b/mdb/crc.go new file mode 100644 index 0000000..7ac3369 --- /dev/null +++ b/mdb/crc.go @@ -0,0 +1,5 @@ +package mdb + +import "hash/crc64" + +var crc64Table = crc64.MakeTable(crc64.ECMA) diff --git a/mdb/db-primary.go b/mdb/db-primary.go new file mode 100644 index 0000000..df6ba5b --- /dev/null +++ b/mdb/db-primary.go @@ -0,0 +1,67 @@ +package mdb + +/* +func (db *Database) openPrimary() (err error) { + wal, err := cwal.Open(db.walRootDir, cwal.Config{ + SegMinCount: db.conf.WALSegMinCount, + SegMaxAgeSec: db.conf.WALSegMaxAgeSec, + }) + + pFile, err := pfile.Open(db.pageFilePath, + + pFile, err := openPageFileAndReplayWAL(db.rootDir) + if err != nil { + return err + } + defer pFile.Close() + + pfHeader, err := pFile.ReadHeader() + if err != nil { + return err + } + + tx := db.Snapshot() + tx.seqNum = pfHeader.SeqNum + 
tx.updatedAt = pfHeader.UpdatedAt + + pIndex, err := pagefile.NewIndex(pFile) + if err != nil { + return err + } + + err = pFile.IterateAllocated(pIndex, func(cID, iID uint64, data []byte) error { + return db.loadItem(tx, cID, iID, data) + }) + if err != nil { + return err + } + + w, err := cwal.OpenWriter(db.walRootDir, &cwal.WriterConfig{ + SegMinCount: db.conf.WALSegMinCount, + SegMaxAgeSec: db.conf.WALSegMaxAgeSec, + }) + if err != nil { + return err + } + + db.done.Add(1) + go txAggregator{ + Stop: db.stop, + Done: db.done, + ModChan: db.modChan, + W: w, + Index: pIndex, + Snapshot: db.snapshot, + }.Run() + + db.done.Add(1) + go (&fileWriter{ + Stop: db.stop, + Done: db.done, + PageFilePath: db.pageFilePath, + WALRootDir: db.walRootDir, + }).Run() + + return nil +} +*/ diff --git a/mdb/db-rep.go b/mdb/db-rep.go new file mode 100644 index 0000000..fef2ba2 --- /dev/null +++ b/mdb/db-rep.go @@ -0,0 +1,118 @@ +package mdb + +import ( + "git.crumpington.com/public/jldb/lib/errs" + "git.crumpington.com/public/jldb/lib/wal" + "git.crumpington.com/public/jldb/mdb/change" + "git.crumpington.com/public/jldb/mdb/pfile" + "log" + "net" + "os" +) + +func (db *Database) repSendState(conn net.Conn) error { + pf, err := pfile.Open(pageFilePath(db.rootDir)) + if err != nil { + return err + } + defer pf.Close() + return pf.Send(conn, db.conf.NetTimeout) +} + +func (db *Database) repRecvState(conn net.Conn) error { + finalPath := pageFilePath(db.rootDir) + tmpPath := finalPath + ".dl" + if err := pfile.Recv(conn, tmpPath, db.conf.NetTimeout); err != nil { + return err + } + + if err := os.Rename(tmpPath, finalPath); err != nil { + return errs.Unexpected.WithErr(err) + } + + return nil +} + +func (db *Database) repInitStorage() (err error) { + db.pf, err = pfile.Open(pageFilePath(db.rootDir)) + return err +} + +func (db *Database) repReplay(rec wal.Record) (err error) { + db.changes, err = change.Read(db.changes[:0], rec.Reader) + if err != nil { + return err + } + + return db.pf.ApplyChanges(db.changes) +} + +func (db *Database) repLoadFromStorage() (err error) { + db.idx, err = pfile.NewIndex(db.pf) + if err != nil { + return err + } + + tx := db.snapshot.Load() + err = pfile.IterateAllocated(db.pf, db.idx, func(cID, iID uint64, data []byte) error { + return db.loadItem(tx, cID, iID, data) + }) + if err != nil { + return err + } + db.snapshot.Store(tx) + return nil +} + +func (db *Database) loadItem(tx *Snapshot, cID, iID uint64, data []byte) error { + c, ok := db.collections[cID] + if !ok { + log.Printf("Failed to find collection %d for item in page file.", cID) + return nil + } + + return c.insertItem(tx, iID, data) +} + +func (db *Database) repApply(rec wal.Record) (err error) { + db.changes, err = change.Read(db.changes[:0], rec.Reader) + if err != nil { + return err + } + + if err := db.pf.ApplyChanges(db.changes); err != nil { + return err + } + + if db.rep.Primary() { + return nil + } + + // For secondary, we need to also apply changes to memory. + + tx := db.snapshot.Load().begin() + for _, change := range db.changes { + if err = db.applyChange(tx, change); err != nil { + return err + } + } + tx.seqNum = rec.SeqNum + tx.timestampMS = rec.TimestampMS + db.snapshot.Store(tx) + return nil +} + +func (db *Database) applyChange(tx *Snapshot, change change.Change) error { + c, ok := db.collections[change.CollectionID] + if !ok { + return nil + } + + if change.Store { + return c.upsertItem(tx, change.ItemID, change.Data) + } + + // The only error this could return is NotFound. 
We'll ignore that error here. + c.deleteItem(tx, change.ItemID) + return nil +} diff --git a/mdb/db-secondary.go b/mdb/db-secondary.go new file mode 100644 index 0000000..f03b01f --- /dev/null +++ b/mdb/db-secondary.go @@ -0,0 +1,129 @@ +package mdb + +/* +func (db *Database) openSecondary() (err error) { + if db.shouldLoadFromPrimary() { + if err := db.loadFromPrimary(); err != nil { + return err + } + } + + log.Printf("Opening page-file...") + + pFile, err := openPageFileAndReplayWAL(db.rootDir) + if err != nil { + return err + } + defer pFile.Close() + + pfHeader, err := pFile.ReadHeader() + if err != nil { + return err + } + + log.Printf("Building page-file index...") + + pIndex, err := pagefile.NewIndex(pFile) + if err != nil { + return err + } + + tx := db.Snapshot() + tx.seqNum = pfHeader.SeqNum + tx.updatedAt = pfHeader.UpdatedAt + + log.Printf("Loading data into memory...") + + err = pFile.IterateAllocated(pIndex, func(cID, iID uint64, data []byte) error { + return db.loadItem(tx, cID, iID, data) + }) + if err != nil { + return err + } + + log.Printf("Creating writer...") + + w, err := cswal.OpenWriter(db.walRootDir, &cswal.WriterConfig{ + SegMinCount: db.conf.WALSegMinCount, + SegMaxAgeSec: db.conf.WALSegMaxAgeSec, + }) + if err != nil { + return err + } + + db.done.Add(1) + go (&walFollower{ + Stop: db.stop, + Done: db.done, + W: w, + Client: NewClient(db.conf.PrimaryURL, db.conf.ReplicationPSK, db.conf.NetTimeout), + }).Run() + + db.done.Add(1) + go (&follower{ + Stop: db.stop, + Done: db.done, + WALRootDir: db.walRootDir, + SeqNum: pfHeader.SeqNum, + ApplyChanges: db.applyChanges, + }).Run() + + db.done.Add(1) + go (&fileWriter{ + Stop: db.stop, + Done: db.done, + PageFilePath: db.pageFilePath, + WALRootDir: db.walRootDir, + }).Run() + + return nil +} + +func (db *Database) shouldLoadFromPrimary() bool { + if _, err := os.Stat(db.walRootDir); os.IsNotExist(err) { + log.Printf("WAL doesn't exist.") + return true + } + if _, err := os.Stat(db.pageFilePath); os.IsNotExist(err) { + log.Printf("Page-file doesn't exist.") + return true + } + return false +} + +func (db *Database) loadFromPrimary() error { + client := NewClient(db.conf.PrimaryURL, db.conf.ReplicationPSK, db.conf.NetTimeout) + defer client.Disconnect() + + log.Printf("Loading data from primary...") + + if err := os.RemoveAll(db.pageFilePath); err != nil { + log.Printf("Failed to remove page-file: %s", err) + return errs.IO.WithErr(err) // Caller can retry. + } + + if err := os.RemoveAll(db.walRootDir); err != nil { + log.Printf("Failed to remove WAL: %s", err) + return errs.IO.WithErr(err) // Caller can retry. + } + + err := client.DownloadPageFile(db.pageFilePath+".tmp", db.pageFilePath) + if err != nil { + log.Printf("Failed to get page-file from primary: %s", err) + return err // Caller can retry. + } + + pfHeader, err := pagefile.ReadHeader(db.pageFilePath) + if err != nil { + log.Printf("Failed to read page-file sequence number: %s", err) + return err // Caller can retry. + } + + if err = cswal.CreateEx(db.walRootDir, pfHeader.SeqNum+1); err != nil { + log.Printf("Failed to initialize WAL: %s", err) + return err // Caller can retry. 
+ } + + return nil +} +*/ diff --git a/mdb/db-testcases_test.go b/mdb/db-testcases_test.go new file mode 100644 index 0000000..626d4ba --- /dev/null +++ b/mdb/db-testcases_test.go @@ -0,0 +1,852 @@ +package mdb + +import ( + "errors" + "fmt" + "reflect" + "strings" + "testing" +) + +type DBTestCase struct { + Name string + Steps []DBTestStep +} + +type DBTestStep struct { + Name string + Update func(t *testing.T, db TestDB, tx *Snapshot) error + ExpectedUpdateError error + State DBState +} + +type DBState struct { + UsersByID []User + UsersByEmail []User + UsersByName []User + UsersByBlocked []User + DataByID []UserDataItem + DataByName []UserDataItem +} + +var testDBTestCases = []DBTestCase{{ + + Name: "Insert update", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + return db.Users.Insert(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + + Name: "Update", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user, ok := db.Users.ByID.Get(tx, &User{ID: 1}) + if !ok { + return ErrNotFound + } + user.Name = "Bob" + user.Email = "b@c.com" + return db.Users.Update(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Bob", Email: "b@c.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Bob", Email: "b@c.com"}}, + UsersByName: []User{{ID: 1, Name: "Bob", Email: "b@c.com"}}, + }, + }}, +}, { + + Name: "Insert delete", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + return db.Users.Insert(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + + Name: "Delete", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + return db.Users.Delete(tx, 1) + }, + + State: DBState{}, + }}, +}, { + + Name: "Insert duplicate one tx (ID)", + + Steps: []DBTestStep{{ + Name: "Insert with duplicate", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + if err := db.Users.Insert(tx, user); err != nil { + return err + } + user2 := &User{ID: 1, Name: "Bob", Email: "b@c.com"} + return db.Users.Insert(tx, user2) + }, + + ExpectedUpdateError: ErrDuplicate, + + State: DBState{}, + }}, +}, { + + Name: "Insert duplicate one tx (email)", + + Steps: []DBTestStep{{ + Name: "Insert with duplicate", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + if err := db.Users.Insert(tx, user); err != nil { + return err + } + user2 := &User{ID: 2, Name: "Bob", Email: "a@b.com"} + return db.Users.Insert(tx, user2) + }, + + ExpectedUpdateError: ErrDuplicate, + + State: DBState{}, + }}, +}, { + + Name: "Insert duplicate two txs (ID)", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + return db.Users.Insert(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", 
Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + + Name: "Insert duplicate", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Bob", Email: "b@c.com"} + return db.Users.Insert(tx, user) + }, + + ExpectedUpdateError: ErrDuplicate, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }}, +}, { + + Name: "Insert duplicate two txs (email)", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + return db.Users.Insert(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + + Name: "Insert duplicate", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 2, Name: "Bob", Email: "a@b.com"} + return db.Users.Insert(tx, user) + }, + + ExpectedUpdateError: ErrDuplicate, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }}, +}, { + + Name: "Insert read-only snapshot", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + return db.Users.Insert(db.Snapshot(), user) + }, + + ExpectedUpdateError: ErrReadOnly, + }}, +}, { + + Name: "Insert partial index", + + Steps: []DBTestStep{{ + + Name: "Insert Alice", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 5, Name: "Alice", Email: "a@b.com"} + return db.Users.Insert(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + Name: "Insert Bob", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 2, Name: "Bob", Email: "b@c.com", Blocked: true} + return db.Users.Insert(tx, user) + }, + + State: DBState{ + UsersByID: []User{ + {ID: 2, Name: "Bob", Email: "b@c.com", Blocked: true}, + {ID: 5, Name: "Alice", Email: "a@b.com"}, + }, + UsersByEmail: []User{ + {ID: 5, Name: "Alice", Email: "a@b.com"}, + {ID: 2, Name: "Bob", Email: "b@c.com", Blocked: true}, + }, + UsersByName: []User{ + {ID: 5, Name: "Alice", Email: "a@b.com"}, + {ID: 2, Name: "Bob", Email: "b@c.com", Blocked: true}, + }, + UsersByBlocked: []User{ + {ID: 2, Name: "Bob", Email: "b@c.com", Blocked: true}, + }, + }, + }}, +}, { + + Name: "Update not found", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 5, Name: "Alice", Email: "a@b.com"} + return db.Users.Insert(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + Name: "Update", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 4, Name: 
"Alice", Email: "x@y.com"} + return db.Users.Update(tx, user) + }, + + ExpectedUpdateError: ErrNotFound, + + State: DBState{ + UsersByID: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 5, Name: "Alice", Email: "a@b.com"}}, + }, + }}, +}, { + + Name: "Update read-only snapshot", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + return db.Users.Insert(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + + Name: "Update", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user, ok := db.Users.ByID.Get(tx, &User{ID: 1}) + if !ok { + return ErrNotFound + } + user.Name = "Bob" + user.Email = "b@c.com" + return db.Users.Update(db.Snapshot(), user) + }, + + ExpectedUpdateError: ErrReadOnly, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }}, +}, { + + Name: "Insert into two collections", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + if err := db.Users.Insert(tx, user); err != nil { + return err + } + data := &UserDataItem{ID: 1, UserID: user.ID, Name: "Item1", Data: "xyz"} + return db.UserData.Insert(tx, data) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + DataByID: []UserDataItem{{ID: 1, UserID: 1, Name: "Item1", Data: "xyz"}}, + DataByName: []UserDataItem{{ID: 1, UserID: 1, Name: "Item1", Data: "xyz"}}, + }, + }}, +}, { + + Name: "Update into index", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + return db.Users.Insert(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + Name: "Update", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true} + return db.Users.Update(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}}, + UsersByBlocked: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}}, + }, + }}, +}, { + + Name: "Update out of index", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true} + return db.Users.Insert(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com", 
Blocked: true}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}}, + UsersByBlocked: []User{{ID: 1, Name: "Alice", Email: "a@b.com", Blocked: true}}, + }, + }, { + Name: "Update", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + return db.Users.Update(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }}, +}, { + + Name: "Update duplicate one tx", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user1 := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + if err := db.Users.Insert(tx, user1); err != nil { + return err + } + user2 := &User{ID: 2, Name: "Bob", Email: "b@c.com"} + if err := db.Users.Insert(tx, user2); err != nil { + return err + } + + user2.Email = "a@b.com" + return db.Users.Update(tx, user2) + }, + + ExpectedUpdateError: ErrDuplicate, + + State: DBState{}, + }}, +}, { + + Name: "Update duplicate two txs", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user1 := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + if err := db.Users.Insert(tx, user1); err != nil { + return err + } + user2 := &User{ID: 2, Name: "Bob", Email: "b@c.com"} + return db.Users.Insert(tx, user2) + }, + + State: DBState{ + UsersByID: []User{ + {ID: 1, Name: "Alice", Email: "a@b.com"}, + {ID: 2, Name: "Bob", Email: "b@c.com"}, + }, + UsersByEmail: []User{ + {ID: 1, Name: "Alice", Email: "a@b.com"}, + {ID: 2, Name: "Bob", Email: "b@c.com"}, + }, + UsersByName: []User{ + {ID: 1, Name: "Alice", Email: "a@b.com"}, + {ID: 2, Name: "Bob", Email: "b@c.com"}, + }, + }, + }, { + + Name: "Update", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + u, ok := db.Users.ByID.Get(tx, &User{ID: 2}) + if !ok { + return ErrNotFound + } + + u.Email = "a@b.com" + return db.Users.Update(tx, u) + }, + + ExpectedUpdateError: ErrDuplicate, + + State: DBState{ + UsersByID: []User{ + {ID: 1, Name: "Alice", Email: "a@b.com"}, + {ID: 2, Name: "Bob", Email: "b@c.com"}, + }, + UsersByEmail: []User{ + {ID: 1, Name: "Alice", Email: "a@b.com"}, + {ID: 2, Name: "Bob", Email: "b@c.com"}, + }, + UsersByName: []User{ + {ID: 1, Name: "Alice", Email: "a@b.com"}, + {ID: 2, Name: "Bob", Email: "b@c.com"}, + }, + }, + }}, +}, { + + Name: "Delete read only", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + return db.Users.Insert(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + Name: "Delete", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + return db.Users.Delete(db.Snapshot(), 1) + }, + + ExpectedUpdateError: ErrReadOnly, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }}, +}, { + + Name: "Delete not found", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, 
Name: "Alice", Email: "a@b.com"} + return db.Users.Insert(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + Name: "Delete", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + return db.Users.Delete(tx, 2) + }, + + ExpectedUpdateError: ErrNotFound, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }}, +}, { + + Name: "Index general", + + Steps: []DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + user := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + return db.Users.Insert(tx, user) + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + Name: "Get found", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + expected := &User{ID: 1, Name: "Alice", Email: "a@b.com"} + + u, ok := db.Users.ByID.Get(tx, &User{ID: 1}) + if !ok { + return ErrNotFound + } + if !reflect.DeepEqual(u, expected) { + return errors.New("Not equal (id)") + } + + u, ok = db.Users.ByEmail.Get(tx, &User{Email: "a@b.com"}) + if !ok { + return ErrNotFound + } + if !reflect.DeepEqual(u, expected) { + return errors.New("Not equal (email)") + } + + return nil + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + Name: "Get not found", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + if _, ok := db.Users.ByID.Get(tx, &User{ID: 2}); ok { + return errors.New("Found (id)") + } + + if _, ok := db.Users.ByEmail.Get(tx, &User{Email: "x@b.com"}); ok { + return errors.New("Found (email)") + } + + return nil + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + Name: "Has (true)", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + if ok := db.Users.ByID.Has(tx, &User{ID: 1}); !ok { + return errors.New("Not found (id)") + } + + if ok := db.Users.ByEmail.Has(tx, &User{Email: "a@b.com"}); !ok { + return errors.New("Not found (email)") + } + + return nil + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }, { + Name: "Has (false)", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + if ok := db.Users.ByID.Has(tx, &User{ID: 2}); ok { + return errors.New("Found (id)") + } + + if ok := db.Users.ByEmail.Has(tx, &User{Email: "x@b.com"}); ok { + return errors.New("Found (email)") + } + + return nil + }, + + State: DBState{ + UsersByID: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByEmail: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + UsersByName: []User{{ID: 1, Name: "Alice", Email: "a@b.com"}}, + }, + }}, +}, { + Name: "Mutate while iterating", + + Steps: 
[]DBTestStep{{ + + Name: "Insert", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) error { + for i := 0; i < 4; i++ { + user := &User{ + ID: uint64(i) + 1, + Name: fmt.Sprintf("User%d", i), + Email: fmt.Sprintf("user.%d@x.com", i), + } + if err := db.Users.Insert(tx, user); err != nil { + return err + } + } + return nil + }, + + State: DBState{ + UsersByID: []User{ + {ID: 1, Name: "User0", Email: "user.0@x.com"}, + {ID: 2, Name: "User1", Email: "user.1@x.com"}, + {ID: 3, Name: "User2", Email: "user.2@x.com"}, + {ID: 4, Name: "User3", Email: "user.3@x.com"}, + }, + UsersByEmail: []User{ + {ID: 1, Name: "User0", Email: "user.0@x.com"}, + {ID: 2, Name: "User1", Email: "user.1@x.com"}, + {ID: 3, Name: "User2", Email: "user.2@x.com"}, + {ID: 4, Name: "User3", Email: "user.3@x.com"}, + }, + UsersByName: []User{ + {ID: 1, Name: "User0", Email: "user.0@x.com"}, + {ID: 2, Name: "User1", Email: "user.1@x.com"}, + {ID: 3, Name: "User2", Email: "user.2@x.com"}, + {ID: 4, Name: "User3", Email: "user.3@x.com"}, + }, + }, + }, { + + Name: "Modify while iterating", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) (err error) { + + first := true + pivot := User{Name: "User1"} + db.Users.ByName.AscendAfter(tx, &pivot, func(u *User) bool { + u.Name += "Mod" + if err = db.Users.Update(tx, u); err != nil { + return false + } + if first { + first = false + return true + } + + prev, ok := db.Users.ByID.Get(tx, &User{ID: u.ID - 1}) + if !ok { + err = errors.New("Previous user not found") + return false + } + + if !strings.HasSuffix(prev.Name, "Mod") { + err = errors.New("Incorrect user name: " + prev.Name) + return false + } + + return true + }) + return nil + }, + + State: DBState{ + UsersByID: []User{ + {ID: 1, Name: "User0", Email: "user.0@x.com"}, + {ID: 2, Name: "User1Mod", Email: "user.1@x.com"}, + {ID: 3, Name: "User2Mod", Email: "user.2@x.com"}, + {ID: 4, Name: "User3Mod", Email: "user.3@x.com"}, + }, + UsersByEmail: []User{ + {ID: 1, Name: "User0", Email: "user.0@x.com"}, + {ID: 2, Name: "User1Mod", Email: "user.1@x.com"}, + {ID: 3, Name: "User2Mod", Email: "user.2@x.com"}, + {ID: 4, Name: "User3Mod", Email: "user.3@x.com"}, + }, + UsersByName: []User{ + {ID: 1, Name: "User0", Email: "user.0@x.com"}, + {ID: 2, Name: "User1Mod", Email: "user.1@x.com"}, + {ID: 3, Name: "User2Mod", Email: "user.2@x.com"}, + {ID: 4, Name: "User3Mod", Email: "user.3@x.com"}, + }, + }, + }, { + + Name: "Iterate after modifying", + + Update: func(t *testing.T, db TestDB, tx *Snapshot) (err error) { + + u := &User{ID: 5, Name: "User4Mod", Email: "user.4@x.com"} + if err := db.Users.Insert(tx, u); err != nil { + return err + } + + first := true + db.Users.ByName.DescendAfter(tx, &User{Name: "User5Mod"}, func(u *User) bool { + u.Name = strings.TrimSuffix(u.Name, "Mod") + if err = db.Users.Update(tx, u); err != nil { + return false + } + if first { + first = false + return true + } + + prev, ok := db.Users.ByID.Get(tx, &User{ID: u.ID + 1}) + if !ok { + err = errors.New("Previous user not found") + return false + } + + if strings.HasSuffix(prev.Name, "Mod") { + err = errors.New("Incorrect user name: " + prev.Name) + return false + } + + return true + }) + return nil + }, + + State: DBState{ + UsersByID: []User{ + {ID: 1, Name: "User0", Email: "user.0@x.com"}, + {ID: 2, Name: "User1", Email: "user.1@x.com"}, + {ID: 3, Name: "User2", Email: "user.2@x.com"}, + {ID: 4, Name: "User3", Email: "user.3@x.com"}, + {ID: 5, Name: "User4", Email: "user.4@x.com"}, + }, + UsersByEmail: []User{ + {ID: 1, Name: "User0", Email: 
"user.0@x.com"}, + {ID: 2, Name: "User1", Email: "user.1@x.com"}, + {ID: 3, Name: "User2", Email: "user.2@x.com"}, + {ID: 4, Name: "User3", Email: "user.3@x.com"}, + {ID: 5, Name: "User4", Email: "user.4@x.com"}, + }, + UsersByName: []User{ + {ID: 1, Name: "User0", Email: "user.0@x.com"}, + {ID: 2, Name: "User1", Email: "user.1@x.com"}, + {ID: 3, Name: "User2", Email: "user.2@x.com"}, + {ID: 4, Name: "User3", Email: "user.3@x.com"}, + {ID: 5, Name: "User4", Email: "user.4@x.com"}, + }, + }, + }}, +}} diff --git a/mdb/db-testlist_test.go b/mdb/db-testlist_test.go new file mode 100644 index 0000000..008093a --- /dev/null +++ b/mdb/db-testlist_test.go @@ -0,0 +1,138 @@ +package mdb + +import ( + "fmt" + "reflect" + "testing" +) + +func TestDBList(t *testing.T) { + db := NewTestDBPrimary(t, t.TempDir()) + + var ( + user1 = User{ + ID: NewID(), + Name: "User1", + Email: "user1@gmail.com", + } + + user2 = User{ + ID: NewID(), + Name: "User2", + Email: "user2@gmail.com", + } + + user3 = User{ + ID: NewID(), + Name: "User3", + Email: "user3@gmail.com", + } + user1Data = make([]UserDataItem, 10) + user2Data = make([]UserDataItem, 4) + user3Data = make([]UserDataItem, 8) + ) + + err := db.Update(func(tx *Snapshot) error { + if err := db.Users.Insert(tx, &user1); err != nil { + return err + } + + if err := db.Users.Insert(tx, &user2); err != nil { + return err + } + + for i := range user1Data { + user1Data[i] = UserDataItem{ + ID: NewID(), + UserID: user1.ID, + Name: fmt.Sprintf("Name1: %d", i), + Data: fmt.Sprintf("Data: %d", i), + } + + if err := db.UserData.Insert(tx, &user1Data[i]); err != nil { + return err + } + } + + for i := range user2Data { + user2Data[i] = UserDataItem{ + ID: NewID(), + UserID: user2.ID, + Name: fmt.Sprintf("Name2: %d", i), + Data: fmt.Sprintf("Data: %d", i), + } + + if err := db.UserData.Insert(tx, &user2Data[i]); err != nil { + return err + } + } + + for i := range user3Data { + user3Data[i] = UserDataItem{ + ID: NewID(), + UserID: user3.ID, + Name: fmt.Sprintf("Name3: %d", i), + Data: fmt.Sprintf("Data: %d", i), + } + + if err := db.UserData.Insert(tx, &user3Data[i]); err != nil { + return err + } + } + + return nil + }) + + if err != nil { + t.Fatal(err) + } + + type TestCase struct { + Name string + Args ListArgs[UserDataItem] + Expected []UserDataItem + } + + cases := []TestCase{ + { + Name: "User1 all", + Args: ListArgs[UserDataItem]{ + After: &UserDataItem{ + UserID: user1.ID, + }, + While: func(item *UserDataItem) bool { + return item.UserID == user1.ID + }, + }, + Expected: user1Data, + }, { + Name: "User1 limited", + Args: ListArgs[UserDataItem]{ + After: &UserDataItem{ + UserID: user1.ID, + }, + While: func(item *UserDataItem) bool { + return item.UserID == user1.ID + }, + Limit: 4, + }, + Expected: user1Data[:4], + }, + } + + for _, tc := range cases { + t.Run(tc.Name, func(t *testing.T) { + tx := db.Snapshot() + l := db.UserData.ByName.List(tx, tc.Args, nil) + if len(l) != len(tc.Expected) { + t.Fatal(tc.Name, l) + } + + for i := range l { + if !reflect.DeepEqual(*l[i], tc.Expected[i]) { + t.Fatal(tc.Name, l) + } + } + }) + } +} diff --git a/mdb/db-testrunner_test.go b/mdb/db-testrunner_test.go new file mode 100644 index 0000000..a812c8e --- /dev/null +++ b/mdb/db-testrunner_test.go @@ -0,0 +1,164 @@ +package mdb + +import ( + "errors" + "net/http" + "net/http/httptest" + "reflect" + "testing" + "time" +) + +func TestDBRunTests(t *testing.T) { + t.Helper() + for _, testCase := range testDBTestCases { + testCase := testCase + t.Run(testCase.Name, func(t 
*testing.T) { + t.Parallel() + testRunner_testCase(t, testCase) + }) + } +} + +func testRunner_testCase(t *testing.T, testCase DBTestCase) { + rootDir := t.TempDir() + db := NewTestDBPrimary(t, rootDir) + + mux := http.NewServeMux() + mux.HandleFunc("/rep/", db.Handle) + testServer := httptest.NewServer(mux) + defer testServer.Close() + + rootDir2 := t.TempDir() + db2 := NewTestSecondaryDB(t, rootDir2, testServer.URL+"/rep/") + defer db2.Close() + + snapshots := make([]*Snapshot, 0, len(testCase.Steps)) + + // Run each step and it's associated check function. + for _, step := range testCase.Steps { + t.Run(step.Name, func(t *testing.T) { + err := db.Update(func(tx *Snapshot) error { + return step.Update(t, db, tx) + }) + + if !errors.Is(err, step.ExpectedUpdateError) { + t.Fatal(err, step.ExpectedUpdateError) + } + + snapshot := db.Snapshot() + snapshots = append(snapshots, snapshot) + + testRunner_checkState(t, db, snapshot, step.State) + }) + } + + // Run each step's check function again with stored snapshot. + for i, step := range testCase.Steps { + snapshot := snapshots[i] + t.Run(step.Name+"-checkSnapshot", func(t *testing.T) { + testRunner_checkState(t, db, snapshot, step.State) + }) + } + + pInfo := db.Info() + + for { + info := db2.Info() + if info.SeqNum == pInfo.SeqNum { + break + } + time.Sleep(time.Millisecond) + } + + // TODO: Why is this necessary? + time.Sleep(time.Second) + finalStep := testCase.Steps[len(testCase.Steps)-1] + + secondarySnapshot := db2.Snapshot() + t.Run("Check secondary", func(t *testing.T) { + testRunner_checkState(t, db2, secondarySnapshot, finalStep.State) + }) + + if err := db.Close(); err != nil { + t.Fatal(err) + } + + db = NewTestDBPrimary(t, rootDir) + snapshot := db.Snapshot() + + // Run the final step's check function again with a newly loaded db. 
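
For orientation, every entry in the testDBTestCases table defined earlier follows the same shape: a named sequence of steps, each with an Update closure, an optional expected error, and the full index state expected afterwards. A minimal additional case might look like the following sketch (hypothetical, not part of this commit); appending it to testDBTestCases would be enough for TestDBRunTests to pick it up.

    // Hypothetical extra test case (sketch only): insert a blocked user and check
    // that it shows up in the partial ByBlocked index as well as the full indexes.
    var exampleBlockedUserCase = DBTestCase{
        Name: "Insert blocked user",

        Steps: []DBTestStep{{
            Name: "Insert",

            Update: func(t *testing.T, db TestDB, tx *Snapshot) error {
                user := &User{ID: 7, Name: "Carol", Email: "c@d.com", Blocked: true}
                return db.Users.Insert(tx, user)
            },

            State: DBState{
                UsersByID:      []User{{ID: 7, Name: "Carol", Email: "c@d.com", Blocked: true}},
                UsersByEmail:   []User{{ID: 7, Name: "Carol", Email: "c@d.com", Blocked: true}},
                UsersByName:    []User{{ID: 7, Name: "Carol", Email: "c@d.com", Blocked: true}},
                UsersByBlocked: []User{{ID: 7, Name: "Carol", Email: "c@d.com", Blocked: true}},
            },
        }},
    }
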
+ t.Run("Check after reload", func(t *testing.T) { + testRunner_checkState(t, db, snapshot, finalStep.State) + }) + + t.Run("Check that primary and secondary are equal", func(t *testing.T) { + db.AssertEqual(t, db2.Database) + }) + + db.Close() +} + +func testRunner_checkState( + t *testing.T, + db TestDB, + tx *Snapshot, + state DBState, +) { + t.Helper() + checkSlicesEqual(t, "UsersByID", db.Users.ByID.Dump(tx), state.UsersByID) + checkSlicesEqual(t, "UsersByEmail", db.Users.ByEmail.Dump(tx), state.UsersByEmail) + checkSlicesEqual(t, "UsersByName", db.Users.ByName.Dump(tx), state.UsersByName) + checkSlicesEqual(t, "UsersByBlocked", db.Users.ByBlocked.Dump(tx), state.UsersByBlocked) + checkSlicesEqual(t, "DataByID", db.UserData.ByID.Dump(tx), state.DataByID) + checkSlicesEqual(t, "DataByName", db.UserData.ByName.Dump(tx), state.DataByName) + + checkMinMaxEqual(t, "UsersByID", tx, db.Users.ByID, state.UsersByID) + checkMinMaxEqual(t, "UsersByEmail", tx, db.Users.ByEmail, state.UsersByEmail) + checkMinMaxEqual(t, "UsersByName", tx, db.Users.ByName, state.UsersByName) + checkMinMaxEqual(t, "UsersByBlocked", tx, db.Users.ByBlocked, state.UsersByBlocked) + checkMinMaxEqual(t, "DataByID", tx, db.UserData.ByID, state.DataByID) + checkMinMaxEqual(t, "DataByName", tx, db.UserData.ByName, state.DataByName) +} + +func checkSlicesEqual[T any](t *testing.T, name string, actual, expected []T) { + t.Helper() + if len(actual) != len(expected) { + t.Fatal(name, len(actual), len(expected)) + } + + for i := range actual { + if !reflect.DeepEqual(actual[i], expected[i]) { + t.Fatal(name, actual[i], expected[i]) + } + } +} + +func checkMinMaxEqual[T any](t *testing.T, name string, tx *Snapshot, index Index[T], expected []T) { + if len(expected) == 0 { + if min, ok := index.Min(tx); ok { + t.Fatal(min) + } + if max, ok := index.Max(tx); ok { + t.Fatal(max) + } + return + } + + min, ok := index.Min(tx) + if !ok { + t.Fatal("No min") + } + max, ok := index.Max(tx) + if !ok { + t.Fatal("No max") + } + + if !reflect.DeepEqual(*min, expected[0]) { + t.Fatal(min, expected[0]) + } + + if !reflect.DeepEqual(*max, expected[len(expected)-1]) { + t.Fatal(max) + } +} diff --git a/mdb/db-txaggregator.go b/mdb/db-txaggregator.go new file mode 100644 index 0000000..958bd1c --- /dev/null +++ b/mdb/db-txaggregator.go @@ -0,0 +1,105 @@ +package mdb + +import ( + "bytes" + "git.crumpington.com/public/jldb/mdb/change" +) + +type txMod struct { + Update func(tx *Snapshot) error + Resp chan error +} + +func (db *Database) runTXAggreagtor() { + defer db.done.Done() + + var ( + tx *Snapshot + mod txMod + seqNum int64 + timestampMS int64 + err error + buf = &bytes.Buffer{} + toNotify = make([]chan error, 0, db.conf.MaxConcurrentUpdates) + ) + +READ_FIRST: + + toNotify = toNotify[:0] + + select { + case mod = <-db.modChan: + goto BEGIN + case <-db.stop: + goto END + } + +BEGIN: + + tx = db.snapshot.Load().begin() + goto APPLY_MOD + +CLONE: + + tx = tx.clone() + goto APPLY_MOD + +APPLY_MOD: + + if err = mod.Update(tx); err != nil { + mod.Resp <- err + goto ROLLBACK + } + + toNotify = append(toNotify, mod.Resp) + goto NEXT + +ROLLBACK: + + if len(toNotify) == 0 { + goto READ_FIRST + } + + tx = tx.rollback() + goto NEXT + +NEXT: + + select { + case mod = <-db.modChan: + goto CLONE + default: + goto WRITE + } + +WRITE: + + db.idx.StageChanges(tx.changes) + + buf.Reset() + if err = change.Write(tx.changes, buf); err != nil { + db.idx.UnstageChanges(tx.changes) + } + + if err == nil { + seqNum, timestampMS, err = 
db.rep.Append(int64(buf.Len()), buf) + } + + if err != nil { + db.idx.UnstageChanges(tx.changes) + } else { + db.idx.ApplyChanges(tx.changes) + tx.seqNum = seqNum + tx.timestampMS = timestampMS + tx.setReadOnly() + db.snapshot.Store(tx) + } + + for i := range toNotify { + toNotify[i] <- err + } + + goto READ_FIRST + +END: +} diff --git a/mdb/db-userdata_test.go b/mdb/db-userdata_test.go new file mode 100644 index 0000000..ce946fb --- /dev/null +++ b/mdb/db-userdata_test.go @@ -0,0 +1,36 @@ +package mdb + +import ( + "cmp" + "strings" +) + +type UserDataItem struct { + ID uint64 + UserID uint64 + Name string + Data string +} + +type UserData struct { + *Collection[UserDataItem] + ByName Index[UserDataItem] // Unique index on (Token). +} + +func NewUserDataCollection(db *Database) UserData { + userData := UserData{} + + userData.Collection = NewCollection[UserDataItem](db, "UserData", nil) + + userData.ByName = NewUniqueIndex( + userData.Collection, + "ByName", + func(lhs, rhs *UserDataItem) int { + if x := cmp.Compare(lhs.UserID, rhs.UserID); x != 0 { + return x + } + return strings.Compare(lhs.Name, rhs.Name) + }) + + return userData +} diff --git a/mdb/db-users_test.go b/mdb/db-users_test.go new file mode 100644 index 0000000..63dbcc3 --- /dev/null +++ b/mdb/db-users_test.go @@ -0,0 +1,50 @@ +package mdb + +import "strings" + +type User struct { + ID uint64 + Name string + Email string + Admin bool + Blocked bool +} + +type Users struct { + *Collection[User] + ByEmail Index[User] // Unique index on (Email). + ByName Index[User] // Index on (Name). + ByBlocked Index[User] // Partial index on (Blocked,Email). +} + +func NewUserCollection(db *Database) Users { + users := Users{} + + users.Collection = NewCollection[User](db, "Users", nil) + + users.ByEmail = NewUniqueIndex( + users.Collection, + "ByEmail", + func(lhs, rhs *User) int { + return strings.Compare(lhs.Email, rhs.Email) + }) + + users.ByName = NewIndex( + users.Collection, + "ByName", + func(lhs, rhs *User) int { + return strings.Compare(lhs.Name, rhs.Name) + }) + + users.ByBlocked = NewPartialIndex( + users.Collection, + "ByBlocked", + func(lhs, rhs *User) int { + return strings.Compare(lhs.Email, rhs.Email) + }, + func(item *User) bool { + return item.Blocked + }) + + return users +} diff --git a/mdb/db.go b/mdb/db.go new file mode 100644 index 0000000..17fe15e --- /dev/null +++ b/mdb/db.go @@ -0,0 +1,184 @@ +package mdb + +import ( + "fmt" + "git.crumpington.com/public/jldb/lib/errs" + "git.crumpington.com/public/jldb/lib/rep" + "git.crumpington.com/public/jldb/mdb/change" + "git.crumpington.com/public/jldb/mdb/pfile" + "net/http" + "os" + "sync" + "sync/atomic" + "time" +) + +type Config struct { + RootDir string + Primary bool + ReplicationPSK string + NetTimeout time.Duration // Default is 1 minute. + + // WAL settings. + WALSegMinCount int64 // Minimum Change sets in a segment. Default is 1024. + WALSegMaxAgeSec int64 // Maximum age of a segment. Default is 1 hour. + WALSegGCAgeSec int64 // Segment age for garbage collection. Default is 7 days. + + // Necessary for secondary. + PrimaryEndpoint string + + // MaxConcurrentUpdates restricts the number of concurently running updates, + // and also limits the maximum number of changes that may be aggregated in + // the WAL. + // + // Default is 32. 
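
As a rough usage sketch (not part of this commit), a primary node is configured with the Config fields above, wired to its collections, and then mutated only through Update. The function name openAndInsert and the field values are illustrative; TestDB, Users, and UserData are the test types defined elsewhere in this commit.

    // Sketch only: construct a primary database, attach the test collections, and
    // run one atomic update. Update hands the closure to the tx aggregator, which
    // batches it into a WAL append; on a secondary it fails with a read-only error.
    func openAndInsert(rootDir string) error {
        db := New(Config{
            RootDir:        rootDir,
            Primary:        true,
            ReplicationPSK: "change-me", // illustrative value
        })

        testDB := TestDB{
            Database: db,
            Users:    NewUserCollection(db),
            UserData: NewUserDataCollection(db),
        }

        if err := testDB.Open(); err != nil {
            return err
        }
        defer testDB.Close()

        return testDB.Update(func(tx *Snapshot) error {
            user := &User{ID: NewID(), Name: "Carol", Email: "carol@example.com"}
            if err := testDB.Users.Insert(tx, user); err != nil {
                return err
            }
            item := &UserDataItem{ID: NewID(), UserID: user.ID, Name: "notes", Data: "hello"}
            return testDB.UserData.Insert(tx, item)
        })
    }
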
+ MaxConcurrentUpdates int +} + +func (c Config) repConfig() rep.Config { + return rep.Config{ + RootDir: repDirPath(c.RootDir), + Primary: c.Primary, + ReplicationPSK: c.ReplicationPSK, + NetTimeout: c.NetTimeout, + WALSegMinCount: c.WALSegMinCount, + WALSegMaxAgeSec: c.WALSegMaxAgeSec, + WALSegGCAgeSec: c.WALSegGCAgeSec, + PrimaryEndpoint: c.PrimaryEndpoint, + } +} + +type Database struct { + rep *rep.Replicator + rootDir string + conf Config + + pf *pfile.File + idx *pfile.Index + + changes []change.Change + + // The Snapshot stored here is read-only. It will be replaced as needed by + // the txAggregator (primary), or the follower (secondary). + snapshot *atomic.Pointer[Snapshot] + collections map[uint64]collection + + stop chan struct{} + done *sync.WaitGroup + + txModPool chan txMod + + modChan chan txMod +} + +func New(conf Config) *Database { + if conf.MaxConcurrentUpdates <= 0 { + conf.MaxConcurrentUpdates = 32 + } + + db := &Database{ + rootDir: conf.RootDir, + conf: conf, + snapshot: &atomic.Pointer[Snapshot]{}, + collections: map[uint64]collection{}, + stop: make(chan struct{}), + done: &sync.WaitGroup{}, + txModPool: make(chan txMod, conf.MaxConcurrentUpdates), + modChan: make(chan txMod), + } + + db.snapshot.Store(newSnapshot()) + + for i := 0; i < conf.MaxConcurrentUpdates; i++ { + db.txModPool <- txMod{Resp: make(chan error, 1)} + } + + return db +} + +func (db *Database) Open() (err error) { + if err := os.MkdirAll(db.rootDir, 0700); err != nil { + return errs.IO.WithErr(err) + } + + db.rep, err = rep.Open( + rep.App{ + SendState: db.repSendState, + RecvState: db.repRecvState, + InitStorage: db.repInitStorage, + Replay: db.repReplay, + LoadFromStorage: db.repLoadFromStorage, + Apply: db.repApply, + }, + db.conf.repConfig()) + if err != nil { + return err + } + + if db.conf.Primary { + db.done.Add(1) + go db.runTXAggreagtor() + } + return nil +} + +func (db *Database) Close() error { + select { + case <-db.stop: + return nil + default: + } + + close(db.stop) + db.rep.Close() + db.done.Wait() + + db.snapshot = nil + db.collections = nil + + return nil +} + +func (db *Database) Snapshot() *Snapshot { + return db.snapshot.Load() +} + +func (db *Database) Update(update func(tx *Snapshot) error) error { + if !db.conf.Primary { + return errs.ReadOnly.WithMsg("cannot update secondary directly") + } + + mod := <-db.txModPool + mod.Update = update + + db.modChan <- mod + err := <-mod.Resp + db.txModPool <- mod + + return err +} + +func (db *Database) Info() Info { + tx := db.Snapshot() + repInfo := db.rep.Info() + + return Info{ + SeqNum: tx.seqNum, + TimestampMS: tx.timestampMS, + WALFirstSeqNum: repInfo.WALFirstSeqNum, + WALLastSeqNum: repInfo.WALLastSeqNum, + WALLastTimestampMS: repInfo.WALLastTimestampMS, + } +} + +func (db *Database) addCollection(id uint64, c collection, collectionState any) { + if _, ok := db.collections[id]; ok { + panic(fmt.Sprintf("Collection %s uses duplicate ID %d.", c.Name(), id)) + } + db.collections[id] = c + db.snapshot.Load().addCollection(id, collectionState) +} + +func (db *Database) Handle(w http.ResponseWriter, r *http.Request) { + db.rep.Handle(w, r) +} diff --git a/mdb/db_test.go b/mdb/db_test.go new file mode 100644 index 0000000..0b97931 --- /dev/null +++ b/mdb/db_test.go @@ -0,0 +1,54 @@ +package mdb + +import ( + "testing" + "time" +) + +type TestDB struct { + *Database + Users Users + UserData UserData +} + +func NewTestDBPrimary(t *testing.T, rootDir string) TestDB { + db := New(Config{ + RootDir: rootDir, + Primary: true, + NetTimeout: 
8 * time.Second, + ReplicationPSK: "123", + }) + + testDB := TestDB{ + Database: db, + Users: NewUserCollection(db), + UserData: NewUserDataCollection(db), + } + + if err := testDB.Open(); err != nil { + t.Fatal(err) + } + + return testDB +} + +func NewTestSecondaryDB(t *testing.T, rootDir, primaryURL string) TestDB { + db := New(Config{ + RootDir: rootDir, + PrimaryEndpoint: primaryURL, + NetTimeout: 8 * time.Second, + ReplicationPSK: "123", + }) + + testDB := TestDB{ + Database: db, + Users: NewUserCollection(db), + UserData: NewUserDataCollection(db), + } + + if err := testDB.Open(); err != nil { + t.Fatal(err) + } + + return testDB +} diff --git a/mdb/equality_test.go b/mdb/equality_test.go new file mode 100644 index 0000000..4e240b9 --- /dev/null +++ b/mdb/equality_test.go @@ -0,0 +1,59 @@ +package mdb + +import ( + "fmt" + "reflect" + "testing" +) + +func (i Index[T]) AssertEqual(t *testing.T, tx1, tx2 *Snapshot) { + t.Helper() + + state1 := i.getState(tx1) + state2 := i.getState(tx2) + + if state1.BTree.Len() != state2.BTree.Len() { + t.Fatalf("(%s) Unequal lengths: %d != %d", + i.name, + state1.BTree.Len(), + state2.BTree.Len()) + } + + errStr := "" + i.Ascend(tx1, func(item1 *T) bool { + item2, ok := i.Get(tx2, item1) + if !ok { + errStr = fmt.Sprintf("Indices don't match. %v not found.", item1) + return false + } + if !reflect.DeepEqual(item1, item2) { + errStr = fmt.Sprintf("%v != %v", item1, item2) + return false + } + return true + }) + + if errStr != "" { + t.Fatal(errStr) + } +} + +func (c *Collection[T]) AssertEqual(t *testing.T, tx1, tx2 *Snapshot) { + t.Helper() + c.ByID.AssertEqual(t, tx1, tx2) + + for _, idx := range c.indices { + idx.AssertEqual(t, tx1, tx2) + } +} + +func (db *Database) AssertEqual(t *testing.T, db2 *Database) { + tx1 := db.Snapshot() + tx2 := db.Snapshot() + for _, c := range db.collections { + cc := c.(interface { + AssertEqual(t *testing.T, tx1, tx2 *Snapshot) + }) + cc.AssertEqual(t, tx1, tx2) + } +} diff --git a/mdb/errors.go b/mdb/errors.go new file mode 100644 index 0000000..343379d --- /dev/null +++ b/mdb/errors.go @@ -0,0 +1,11 @@ +package mdb + +import ( + "git.crumpington.com/public/jldb/lib/errs" +) + +var ( + ErrNotFound = errs.NotFound + ErrReadOnly = errs.ReadOnly + ErrDuplicate = errs.Duplicate +) diff --git a/mdb/filewriter.go b/mdb/filewriter.go new file mode 100644 index 0000000..31a260c --- /dev/null +++ b/mdb/filewriter.go @@ -0,0 +1,100 @@ +package mdb + +/* +// The fileWriter writes changes from the WAL to the data file. It's run by the +// primary, and, for the primary, is the only way the pagefile is modified. 
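
The test runner above wires replication over HTTP by mounting the primary's Handle method and pointing a secondary at that endpoint. A hypothetical non-test version of the same wiring might look like this sketch; the function names and addresses are illustrative, a net/http import is assumed, and both sides must share the same ReplicationPSK.

    // Sketch only: expose the primary's replication endpoint and open a secondary
    // that follows it. Mirrors the httptest wiring used by testRunner_testCase.
    func serveReplication(primary *Database, addr string) error {
        mux := http.NewServeMux()
        mux.HandleFunc("/rep/", primary.Handle)
        return http.ListenAndServe(addr, mux)
    }

    func openSecondary(rootDir, primaryAddr, psk string) (*Database, error) {
        db := New(Config{
            RootDir:         rootDir,
            Primary:         false,
            ReplicationPSK:  psk, // must match the primary's PSK
            PrimaryEndpoint: "http://" + primaryAddr + "/rep/",
        })
        if err := db.Open(); err != nil {
            return nil, err
        }
        return db, nil
    }
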
+type fileWriter struct { + Stop chan struct{} + Done *sync.WaitGroup + PageFilePath string + WALRootDir string +} + +func (w *fileWriter) Run() { + defer w.Done.Done() + for { + w.runOnce() + select { + case <-w.Stop: + return + default: + time.Sleep(time.Second) + } + } +} + +func (w *fileWriter) runOnce() { + f, err := pagefile.Open(w.PageFilePath) + if err != nil { + w.logf("Failed to open page file: %v", err) + return + } + defer f.Close() + + header, err := w.readHeader(f) + if err != nil { + w.logf("Failed to get header from page file: %v", err) + return + } + + it, err := cswal.NewIterator(w.WALRootDir, header.SeqNum+1) + if err != nil { + w.logf("Failed to get WAL iterator: %v", err) + return + } + defer it.Close() + + for { + hasNext := it.Next(time.Second) + + select { + case <-w.Stop: + return + default: + } + + if !hasNext { + if it.Error() != nil { + w.logf("Iteration error: %v", it.Error()) + return + } + continue + } + + rec := it.Record() + if err := w.applyChanges(f, rec); err != nil { + w.logf("Failed to apply changes: %v", err) + return + } + } +} + +func (w *fileWriter) readHeader(f *pagefile.File) (pagefile.Header, error) { + defer f.RLock()() + return f.ReadHeader() +} + +func (w *fileWriter) applyChanges(f *pagefile.File, rec *cswal.Record) error { + defer f.WLock()() + + if err := f.ApplyChanges(rec.Changes); err != nil { + w.logf("Failed to apply changes to page file: %v", err) + return err + } + + header := pagefile.Header{ + SeqNum: rec.SeqNum, + UpdatedAt: rec.CreatedAt, + } + + if err := f.WriteHeader(header); err != nil { + w.logf("Failed to write page file header: %v", err) + return err + } + return nil +} + +func (w *fileWriter) logf(pattern string, args ...interface{}) { + log.Printf("[FILE-WRITER] "+pattern, args...) +} +*/ diff --git a/mdb/follower.go b/mdb/follower.go new file mode 100644 index 0000000..088562d --- /dev/null +++ b/mdb/follower.go @@ -0,0 +1,68 @@ +package mdb + +/* +type follower struct { + Stop chan struct{} + Done *sync.WaitGroup + WALRootDir string + SeqNum uint64 // Current max applied sequence number. + ApplyChanges func(rec *cswal.Record) error + + seqNum uint64 // Current max applied sequence number. +} + +func (f *follower) Run() { + defer f.Done.Done() + + f.seqNum = f.SeqNum + + for { + f.runOnce() + select { + case <-f.Stop: + return + default: + // Something went wrong. + time.Sleep(time.Second) + } + } +} + +func (f *follower) runOnce() { + it, err := cswal.NewIterator(f.WALRootDir, f.seqNum+1) + if err != nil { + f.logf("Failed to get WAL iterator: %v", errs.FmtDetails(err)) + return + } + defer it.Close() + + for { + hasNext := it.Next(time.Second) + + select { + case <-f.Stop: + return + default: + } + + if !hasNext { + if it.Error() != nil { + f.logf("Iteration error: %v", errs.FmtDetails(it.Error())) + return + } + continue + } + + rec := it.Record() + if err := f.ApplyChanges(rec); err != nil { + f.logf("Failed to apply changes: %s", errs.FmtDetails(err)) + return + } + f.seqNum = rec.SeqNum + } +} + +func (f *follower) logf(pattern string, args ...interface{}) { + log.Printf("[FOLLOWER] "+pattern, args...) 
+} +*/ diff --git a/mdb/functions.go b/mdb/functions.go new file mode 100644 index 0000000..25233c0 --- /dev/null +++ b/mdb/functions.go @@ -0,0 +1,13 @@ +package mdb + +import ( + "path/filepath" +) + +func pageFilePath(rootDir string) string { + return filepath.Join(rootDir, "pagefile") +} + +func repDirPath(rootDir string) string { + return filepath.Join(rootDir, "rep") +} diff --git a/mdb/functions_test.go b/mdb/functions_test.go new file mode 100644 index 0000000..003538e --- /dev/null +++ b/mdb/functions_test.go @@ -0,0 +1,12 @@ +package mdb + +import ( + "testing" +) + +func TestPageFilePath(t *testing.T) { + pageFilePath := pageFilePath("/tmp") + if pageFilePath != "/tmp/pagefile" { + t.Fatal(pageFilePath) + } +} diff --git a/mdb/id.go b/mdb/id.go new file mode 100644 index 0000000..a0160f6 --- /dev/null +++ b/mdb/id.go @@ -0,0 +1,8 @@ +package mdb + +import "git.crumpington.com/public/jldb/lib/idgen" + +// Safely generate a new ID. +func NewID() uint64 { + return idgen.Next() +} diff --git a/mdb/index-internal.go b/mdb/index-internal.go new file mode 100644 index 0000000..dd6f8e2 --- /dev/null +++ b/mdb/index-internal.go @@ -0,0 +1,11 @@ +package mdb + +import "github.com/google/btree" + +type indexState[T any] struct { + BTree *btree.BTreeG[*T] +} + +func (i indexState[T]) clone() indexState[T] { + return indexState[T]{BTree: i.BTree.Clone()} +} diff --git a/mdb/index.go b/mdb/index.go new file mode 100644 index 0000000..ed769ee --- /dev/null +++ b/mdb/index.go @@ -0,0 +1,236 @@ +package mdb + +import ( + "unsafe" + + "github.com/google/btree" +) + +func NewIndex[T any]( + c *Collection[T], + name string, + compare func(lhs, rhs *T) int, +) Index[T] { + return c.addIndex(indexConfig[T]{ + Name: name, + Unique: false, + Compare: compare, + Include: nil, + }) +} + +func NewPartialIndex[T any]( + c *Collection[T], + name string, + compare func(lhs, rhs *T) int, + include func(*T) bool, +) Index[T] { + return c.addIndex(indexConfig[T]{ + Name: name, + Unique: false, + Compare: compare, + Include: include, + }) +} + +func NewUniqueIndex[T any]( + c *Collection[T], + name string, + compare func(lhs, rhs *T) int, +) Index[T] { + return c.addIndex(indexConfig[T]{ + Name: name, + Unique: true, + Compare: compare, + Include: nil, + }) +} + +func NewUniquePartialIndex[T any]( + c *Collection[T], + name string, + compare func(lhs, rhs *T) int, + include func(*T) bool, +) Index[T] { + return c.addIndex(indexConfig[T]{ + Name: name, + Unique: true, + Compare: compare, + Include: include, + }) +} + +// ---------------------------------------------------------------------------- + +type Index[T any] struct { + name string + collectionID uint64 + indexID uint64 + include func(*T) bool + copy func(*T) *T +} + +func (i Index[T]) Get(tx *Snapshot, in *T) (item *T, ok bool) { + tPtr, ok := i.get(tx, in) + if !ok { + return item, false + } + return i.copy(tPtr), true +} + +func (i Index[T]) get(tx *Snapshot, in *T) (*T, bool) { + return i.btree(tx).Get(in) +} + +func (i Index[T]) Has(tx *Snapshot, in *T) bool { + return i.btree(tx).Has(in) +} + +func (i Index[T]) Min(tx *Snapshot) (item *T, ok bool) { + tPtr, ok := i.btree(tx).Min() + if !ok { + return item, false + } + return i.copy(tPtr), true +} + +func (i Index[T]) Max(tx *Snapshot) (item *T, ok bool) { + tPtr, ok := i.btree(tx).Max() + if !ok { + return item, false + } + return i.copy(tPtr), true +} + +func (i Index[T]) Ascend(tx *Snapshot, each func(*T) bool) { + i.btreeForIter(tx).Ascend(func(t *T) bool { + return each(i.copy(t)) + }) +} + 
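
The Index methods above (Get, Has, Min, Max, Ascend, and the variants that follow) form the read-side API: they operate on an immutable Snapshot and return copies of stored items. A small hypothetical lookup helper, reusing the Users test collection and assuming a fmt import, could look like this sketch.

    // Sketch only: read-side lookups against the Users test collection. Snapshots
    // are read-only; any mutation must go through Database.Update instead.
    func dumpUsers(db TestDB, email string) {
        tx := db.Snapshot()

        // Point lookup on the unique ByEmail index.
        if u, ok := db.Users.ByEmail.Get(tx, &User{Email: email}); ok {
            fmt.Printf("found %d: %s\n", u.ID, u.Name)
        }

        // Existence check on the primary-key index.
        if db.Users.ByID.Has(tx, &User{ID: 1}) {
            fmt.Println("user 1 exists")
        }

        // Ordered iteration over the ByName index; return false to stop early.
        db.Users.ByName.Ascend(tx, func(u *User) bool {
            fmt.Println(u.Name)
            return true
        })
    }
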
+func (i Index[T]) AscendAfter(tx *Snapshot, after *T, each func(*T) bool) { + i.btreeForIter(tx).AscendGreaterOrEqual(after, func(t *T) bool { + return each(i.copy(t)) + }) +} + +func (i Index[T]) Descend(tx *Snapshot, each func(*T) bool) { + i.btreeForIter(tx).Descend(func(t *T) bool { + return each(i.copy(t)) + }) +} + +func (i Index[T]) DescendAfter(tx *Snapshot, after *T, each func(*T) bool) { + i.btreeForIter(tx).DescendLessOrEqual(after, func(t *T) bool { + return each(i.copy(t)) + }) +} + +type ListArgs[T any] struct { + Desc bool // True for descending order, otherwise ascending. + After *T // If after is given, iterate after (and including) the value. + While func(*T) bool // Continue iterating until While is false. + Limit int // Maximum number of items to return. 0 => All. +} + +func (i Index[T]) List(tx *Snapshot, args ListArgs[T], out []*T) []*T { + if args.Limit < 0 { + return nil + } + + if args.While == nil { + args.While = func(*T) bool { return true } + } + + size := args.Limit + if size == 0 { + size = 32 // Why not? + } + + items := out[:0] + + each := func(item *T) bool { + if !args.While(item) { + return false + } + items = append(items, item) + return args.Limit == 0 || len(items) < args.Limit + } + + if args.Desc { + if args.After != nil { + i.DescendAfter(tx, args.After, each) + } else { + i.Descend(tx, each) + } + } else { + if args.After != nil { + i.AscendAfter(tx, args.After, each) + } else { + i.Ascend(tx, each) + } + } + + return items +} + +// ---------------------------------------------------------------------------- + +func (i Index[T]) insertConflict(tx *Snapshot, item *T) bool { + return i.btree(tx).Has(item) +} + +func (i Index[T]) updateConflict(tx *Snapshot, item *T) bool { + current, ok := i.btree(tx).Get(item) + return ok && i.getID(current) != i.getID(item) +} + +// This should only be called after insertConflict. Additionally, the caller +// should ensure that the index has been properly cloned for write before +// writing. +func (i Index[T]) insert(tx *Snapshot, item *T) { + if i.include != nil && !i.include(item) { + return + } + + i.btree(tx).ReplaceOrInsert(item) +} + +func (i Index[T]) update(tx *Snapshot, old, new *T) { + bt := i.btree(tx) + bt.Delete(old) + + // The insert call will also check the include function if available. + i.insert(tx, new) +} + +func (i Index[T]) delete(tx *Snapshot, item *T) { + i.btree(tx).Delete(item) +} + +// ---------------------------------------------------------------------------- + +func (i Index[T]) getState(tx *Snapshot) indexState[T] { + return tx.collections[i.collectionID].(*collectionState[T]).Indices[i.indexID] +} + +// Get the current btree for get/has/update/delete, etc. +func (i Index[T]) btree(tx *Snapshot) *btree.BTreeG[*T] { + return i.getState(tx).BTree +} + +func (i Index[T]) btreeForIter(tx *Snapshot) *btree.BTreeG[*T] { + cState := tx.collections[i.collectionID].(*collectionState[T]) + bt := cState.Indices[i.indexID].BTree + + // If snapshot and index are writable, return a clone. 
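
ListArgs and List above wrap the Ascend/Descend iterators into a bounded query: After positions the cursor, While decides when iteration stops, and Limit caps the result. A hypothetical helper in the spirit of the list test earlier in this commit:

    // Sketch only: return up to limit UserDataItem rows belonging to userID,
    // ordered by name, using the composite (UserID, Name) ByName index.
    func listUserItems(db TestDB, tx *Snapshot, userID uint64, limit int) []*UserDataItem {
        return db.UserData.ByName.List(tx, ListArgs[UserDataItem]{
            After: &UserDataItem{UserID: userID},
            While: func(item *UserDataItem) bool { return item.UserID == userID },
            Limit: limit,
        }, nil)
    }
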
+ if tx.writable() && cState.Version == tx.version { + bt = bt.Clone() + } + + return bt +} + +func (i Index[T]) getID(t *T) uint64 { + return *((*uint64)(unsafe.Pointer(t))) +} diff --git a/mdb/index_test.go b/mdb/index_test.go new file mode 100644 index 0000000..df7576c --- /dev/null +++ b/mdb/index_test.go @@ -0,0 +1,9 @@ +package mdb + +func (i Index[T]) Dump(tx *Snapshot) (l []T) { + i.Ascend(tx, func(t *T) bool { + l = append(l, *t) + return true + }) + return l +} diff --git a/mdb/info.go b/mdb/info.go new file mode 100644 index 0000000..7130984 --- /dev/null +++ b/mdb/info.go @@ -0,0 +1,11 @@ +package mdb + +type Info struct { + SeqNum int64 // In-memory sequence number. + TimestampMS int64 // In-memory timestamp. + FileSeqNum int64 // Page file sequence number. + FileTimestampMS int64 // Page file timestamp. + WALFirstSeqNum int64 // WAL min sequence number. + WALLastSeqNum int64 // WAL max sequence number. + WALLastTimestampMS int64 // WAL timestamp. +} diff --git a/mdb/pfile/alloclist.go b/mdb/pfile/alloclist.go new file mode 100644 index 0000000..7ee4b09 --- /dev/null +++ b/mdb/pfile/alloclist.go @@ -0,0 +1,57 @@ +package pfile + +import "slices" + +type allocList map[[2]uint64][]uint64 + +func newAllocList() *allocList { + al := allocList(map[[2]uint64][]uint64{}) + return &al +} + +func (al allocList) Create(collectionID, itemID, page uint64) { + key := al.key(collectionID, itemID) + al[key] = []uint64{page} +} + +// Push is used to add pages to the storage when loading. It will append +// pages to the appropriate list, or return false if the list isn't found. +func (al allocList) Push(collectionID, itemID, page uint64) bool { + key := al.key(collectionID, itemID) + if _, ok := al[key]; !ok { + return false + } + al[key] = append(al[key], page) + return true +} + +func (al allocList) Store(collectionID, itemID uint64, pages []uint64) { + key := al.key(collectionID, itemID) + al[key] = slices.Clone(pages) +} + +func (al allocList) Remove(collectionID, itemID uint64) []uint64 { + key := al.key(collectionID, itemID) + pages := al[key] + delete(al, key) + return pages +} + +func (al allocList) Iterate( + each func(collectionID, itemID uint64, pages []uint64) error, +) error { + for key, pages := range al { + if err := each(key[0], key[1], pages); err != nil { + return err + } + } + return nil +} + +func (al allocList) Len() int { + return len(al) +} + +func (al allocList) key(collectionID, itemID uint64) [2]uint64 { + return [2]uint64{collectionID, itemID} +} diff --git a/mdb/pfile/alloclist_test.go b/mdb/pfile/alloclist_test.go new file mode 100644 index 0000000..0ad0893 --- /dev/null +++ b/mdb/pfile/alloclist_test.go @@ -0,0 +1,172 @@ +package pfile + +import ( + "errors" + "reflect" + "testing" +) + +func (al allocList) Assert(t *testing.T, state map[[2]uint64][]uint64) { + t.Helper() + + if len(al) != len(state) { + t.Fatalf("Expected %d items, but found %d.", len(state), len(al)) + } + + for key, expected := range state { + val, ok := al[key] + if !ok { + t.Fatalf("Expected to find key %v.", key) + } + if !reflect.DeepEqual(val, expected) { + t.Fatalf("For %v, expected %v but got %v.", key, expected, val) + } + } +} + +func (al *allocList) With(collectionID, itemID uint64, pages ...uint64) *allocList { + al.Store(collectionID, itemID, pages) + return al +} + +func (al *allocList) Equals(rhs *allocList) bool { + + if len(*rhs) != len(*al) { + return false + } + + for key, val := range *rhs { + actual := (*al)[key] + if !reflect.DeepEqual(val, actual) { + return false + } + } 
+ return true +} + +func TestAllocList(t *testing.T) { + const ( + CREATE = "CREATE" + PUSH = "PUSH" + STORE = "STORE" + REMOVE = "REMOVE" + ) + + type TestCase struct { + Name string + Action string + Key [2]uint64 + Page uint64 + Pages []uint64 // For STORE command. + Expected *allocList + ExpectedLen int + } + + testCases := []TestCase{{ + Name: "Create something", + Action: CREATE, + Key: [2]uint64{1, 1}, + Page: 1, + Expected: newAllocList().With(1, 1, 1), + ExpectedLen: 1, + }, { + Name: "Push onto something", + Action: PUSH, + Key: [2]uint64{1, 1}, + Page: 2, + Expected: newAllocList().With(1, 1, 1, 2), + ExpectedLen: 1, + }, { + Name: "Push onto something again", + Action: PUSH, + Key: [2]uint64{1, 1}, + Page: 3, + Expected: newAllocList().With(1, 1, 1, 2, 3), + ExpectedLen: 1, + }, { + Name: "Store something", + Action: STORE, + Key: [2]uint64{2, 2}, + Pages: []uint64{4, 5, 6}, + Expected: newAllocList().With(1, 1, 1, 2, 3).With(2, 2, 4, 5, 6), + ExpectedLen: 2, + }, { + Name: "Remove something", + Action: REMOVE, + Key: [2]uint64{1, 1}, + Expected: newAllocList().With(2, 2, 4, 5, 6), + ExpectedLen: 1, + }} + + al := newAllocList() + + for _, tc := range testCases { + switch tc.Action { + case CREATE: + al.Create(tc.Key[0], tc.Key[1], tc.Page) + + case PUSH: + al.Push(tc.Key[0], tc.Key[1], tc.Page) + + case STORE: + al.Store(tc.Key[0], tc.Key[1], tc.Pages) + + case REMOVE: + al.Remove(tc.Key[0], tc.Key[1]) + + default: + t.Fatalf("Unknown action: %s", tc.Action) + } + + if !al.Equals(tc.Expected) { + t.Fatal(tc.Name, al, tc.Expected) + } + + if al.Len() != tc.ExpectedLen { + t.Fatal(tc.Name, al.Len(), tc.ExpectedLen) + } + } +} + +func TestAllocListIterate_eachError(t *testing.T) { + al := newAllocList().With(1, 1, 2, 3, 4, 5) + myErr := errors.New("xxx") + err := al.Iterate(func(collectionID, itemID uint64, pageIDs []uint64) error { + return myErr + }) + if err != myErr { + t.Fatal(err) + } +} + +func TestAllocListIterate(t *testing.T) { + al := newAllocList().With(1, 1, 2, 3, 4, 5).With(2, 2, 6, 7) + expected := map[uint64][]uint64{ + 1: {2, 3, 4, 5}, + 2: {6, 7}, + } + + err := al.Iterate(func(collectionID, itemID uint64, pageIDs []uint64) error { + e, ok := expected[collectionID] + if !ok { + t.Fatalf("Not found: %d", collectionID) + } + if !reflect.DeepEqual(e, pageIDs) { + t.Fatalf("%v != %v", pageIDs, e) + } + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +func TestAllocListPushNoHead(t *testing.T) { + al := newAllocList().With(1, 1, 2, 3, 4, 5).With(2, 2, 6, 7) + if !al.Push(1, 1, 8) { + t.Fatal("Failed to push onto head page") + } + if al.Push(1, 2, 9) { + t.Fatal("Pushed with no head.") + } +} diff --git a/mdb/pfile/change_test.go b/mdb/pfile/change_test.go new file mode 100644 index 0000000..3cd4502 --- /dev/null +++ b/mdb/pfile/change_test.go @@ -0,0 +1,58 @@ +package pfile + +import ( + crand "crypto/rand" + "git.crumpington.com/public/jldb/mdb/change" + "math/rand" +) + +func randomChangeList() (changes []change.Change) { + count := 1 + rand.Intn(8) + for i := 0; i < count; i++ { + change := change.Change{ + CollectionID: 1 + uint64(rand.Int63n(10)), + ItemID: 1 + uint64(rand.Int63n(10)), + } + + if rand.Float32() < 0.95 { + change.Data = randBytes(1 + rand.Intn(pageDataSize*4)) + change.Store = true + } + + changes = append(changes, change) + } + + return changes +} + +type changeListBuilder []change.Change + +func (b *changeListBuilder) Clear() *changeListBuilder { + *b = (*b)[:0] + return b +} + +func (b *changeListBuilder) Store(cID, iID, dataSize 
uint64) *changeListBuilder { + data := make([]byte, dataSize) + crand.Read(data) + *b = append(*b, change.Change{ + CollectionID: cID, + ItemID: iID, + Store: true, + Data: data, + }) + return b +} + +func (b *changeListBuilder) Delete(cID, iID uint64) *changeListBuilder { + *b = append(*b, change.Change{ + CollectionID: cID, + ItemID: iID, + Store: false, + }) + return b +} + +func (b *changeListBuilder) Build() []change.Change { + return *b +} diff --git a/mdb/pfile/freelist.go b/mdb/pfile/freelist.go new file mode 100644 index 0000000..3070e2d --- /dev/null +++ b/mdb/pfile/freelist.go @@ -0,0 +1,67 @@ +package pfile + +import "container/heap" + +// ---------------------------------------------------------------------------- +// The intHeap is used to store the free list. +// ---------------------------------------------------------------------------- + +type intHeap []uint64 + +func (h intHeap) Len() int { return len(h) } +func (h intHeap) Less(i, j int) bool { return h[i] < h[j] } +func (h intHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } + +func (h *intHeap) Push(x any) { + // Push and Pop use pointer receivers because they modify the slice's length, + // not just its contents. + *h = append(*h, x.(uint64)) +} + +func (h *intHeap) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} + +// ---------------------------------------------------------------------------- +// Free list +// ---------------------------------------------------------------------------- + +type freeList struct { + h intHeap + nextPage uint64 +} + +// newFreeList creates a new free list that will return available pages from +// smallest to largest. If there are no available pages, it will return new +// pages starting from nextPage. +func newFreeList(pageCount uint64) *freeList { + return &freeList{ + h: []uint64{}, + nextPage: pageCount, + } +} + +func (f *freeList) Push(pages ...uint64) { + for _, page := range pages { + heap.Push(&f.h, page) + } +} + +func (f *freeList) Pop(count int, out []uint64) []uint64 { + out = out[:0] + + for len(out) < count && len(f.h) > 0 { + out = append(out, heap.Pop(&f.h).(uint64)) + } + + for len(out) < count { + out = append(out, f.nextPage) + f.nextPage++ + } + + return out +} diff --git a/mdb/pfile/freelist_test.go b/mdb/pfile/freelist_test.go new file mode 100644 index 0000000..2ff1839 --- /dev/null +++ b/mdb/pfile/freelist_test.go @@ -0,0 +1,90 @@ +package pfile + +import ( + "math/rand" + "reflect" + "testing" +) + +func (fl *freeList) Assert(t *testing.T, pageIDs ...uint64) { + t.Helper() + + if len(fl.h) != len(pageIDs) { + t.Fatalf("FreeList: Expected %d pages but got %d.\n%v != %v", + len(pageIDs), len(fl.h), fl.h, pageIDs) + } + + containsPageID := func(pageID uint64) bool { + for _, v := range fl.h { + if v == pageID { + return true + } + } + return false + } + + for _, pageID := range pageIDs { + if !containsPageID(pageID) { + t.Fatalf("Page not free: %d", pageID) + } + } +} + +func TestFreeList(t *testing.T) { + t.Parallel() + + p0 := uint64(1 + rand.Int63()) + + type TestCase struct { + Name string + Put []uint64 + Alloc int + Expected []uint64 + } + + testCases := []TestCase{ + { + Name: "Alloc first page", + Put: []uint64{}, + Alloc: 1, + Expected: []uint64{p0}, + }, { + Name: "Alloc second page", + Put: []uint64{}, + Alloc: 1, + Expected: []uint64{p0 + 1}, + }, { + Name: "Put second page", + Put: []uint64{p0 + 1}, + Alloc: 0, + Expected: []uint64{}, + }, { + Name: "Alloc 2 pages", + Put: []uint64{}, + Alloc: 2, + Expected: 
[]uint64{p0 + 1, p0 + 2}, + }, { + Name: "Put back and alloc pages", + Put: []uint64{p0}, + Alloc: 3, + Expected: []uint64{p0, p0 + 3, p0 + 4}, + }, { + Name: "Put back large and alloc", + Put: []uint64{p0, p0 + 2, p0 + 4, p0 + 442}, + Alloc: 4, + Expected: []uint64{p0, p0 + 2, p0 + 4, p0 + 442}, + }, + } + + fl := newFreeList(p0) + + var pages []uint64 + + for _, tc := range testCases { + fl.Push(tc.Put...) + pages = fl.Pop(tc.Alloc, pages) + if !reflect.DeepEqual(pages, tc.Expected) { + t.Fatal(tc.Name, pages, tc.Expected) + } + } +} diff --git a/mdb/pfile/header.go b/mdb/pfile/header.go new file mode 100644 index 0000000..48e2e79 --- /dev/null +++ b/mdb/pfile/header.go @@ -0,0 +1 @@ +package pfile diff --git a/mdb/pfile/index.go b/mdb/pfile/index.go new file mode 100644 index 0000000..74d4caa --- /dev/null +++ b/mdb/pfile/index.go @@ -0,0 +1,105 @@ +package pfile + +import ( + "git.crumpington.com/public/jldb/lib/errs" + "git.crumpington.com/public/jldb/mdb/change" +) + +type Index struct { + fList *freeList + aList allocList + seen map[[2]uint64]struct{} + mask []bool +} + +func NewIndex(f *File) (*Index, error) { + idx := &Index{ + fList: newFreeList(0), + aList: *newAllocList(), + seen: map[[2]uint64]struct{}{}, + mask: []bool{}, + } + + err := f.iterate(func(pageID uint64, page dataPage) error { + header := page.Header() + switch header.PageType { + case pageTypeHead: + idx.aList.Create(header.CollectionID, header.ItemID, pageID) + case pageTypeData: + if !idx.aList.Push(header.CollectionID, header.ItemID, pageID) { + return errs.Corrupt.WithMsg("encountered data page with no corresponding head page") + } + case pageTypeFree: + idx.fList.Push(pageID) + } + return nil + }) + + return idx, err +} + +func (idx *Index) StageChanges(changes []change.Change) { + clear(idx.seen) + if cap(idx.mask) < len(changes) { + idx.mask = make([]bool, len(changes)) + } + idx.mask = idx.mask[:len(changes)] + + for i := len(changes) - 1; i >= 0; i-- { + key := [2]uint64{changes[i].CollectionID, changes[i].ItemID} + if _, ok := idx.seen[key]; ok { + idx.mask[i] = false + continue + } + + idx.seen[key] = struct{}{} + idx.mask[i] = true + } + + for i, active := range idx.mask { + if !active { + continue + } + + if changes[i].Store { + count := idx.getPageCountForData(len(changes[i].Data)) + changes[i].WritePageIDs = idx.fList.Pop(count, changes[i].WritePageIDs) + } + + if pages := idx.aList.Remove(changes[i].CollectionID, changes[i].ItemID); pages != nil { + changes[i].ClearPageIDs = pages + } + } +} + +func (idx *Index) UnstageChanges(changes []change.Change) { + for i := range changes { + if len(changes[i].WritePageIDs) > 0 { + idx.fList.Push(changes[i].WritePageIDs...) + changes[i].WritePageIDs = changes[i].WritePageIDs[:0] + } + if len(changes[i].ClearPageIDs) > 0 { + idx.aList.Store(changes[i].CollectionID, changes[i].ItemID, changes[i].ClearPageIDs) + changes[i].ClearPageIDs = changes[i].ClearPageIDs[:0] + } + } +} + +func (idx *Index) ApplyChanges(changes []change.Change) { + for i := range changes { + if len(changes[i].WritePageIDs) > 0 { + idx.aList.Store(changes[i].CollectionID, changes[i].ItemID, changes[i].WritePageIDs) + } + if len(changes[i].ClearPageIDs) > 0 { + idx.fList.Push(changes[i].ClearPageIDs...) 
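
The StageChanges/UnstageChanges/ApplyChanges trio is designed to bracket a durable write, much like the tx aggregator's WRITE phase on the mdb side: stage to reserve pages, attempt the append, then either apply or roll the reservation back. A hypothetical wrapper (sketch only; persist stands in for the WAL append) makes the intended call order explicit.

    // Sketch only: commit a change set against the page-file index. persist is a
    // placeholder for the durable write (e.g. the replicator's WAL append).
    func commitChanges(idx *Index, changes []change.Change, persist func([]change.Change) error) error {
        idx.StageChanges(changes) // reserve WritePageIDs, detach old pages into ClearPageIDs
        if err := persist(changes); err != nil {
            idx.UnstageChanges(changes) // return reserved pages, restore the alloc list
            return err
        }
        idx.ApplyChanges(changes) // point the alloc list at new pages, free the old ones
        return nil
    }
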
+ } + } +} + +func (idx *Index) getPageCountForData(dataSize int) int { + count := dataSize / pageDataSize + if dataSize%pageDataSize != 0 { + count++ + } + return count +} diff --git a/mdb/pfile/index_test.go b/mdb/pfile/index_test.go new file mode 100644 index 0000000..0c699b7 --- /dev/null +++ b/mdb/pfile/index_test.go @@ -0,0 +1,139 @@ +package pfile + +import ( + "testing" +) + +type IndexState struct { + FreeList []uint64 + AllocList map[[2]uint64][]uint64 +} + +func (idx *Index) Assert(t *testing.T, state IndexState) { + t.Helper() + + idx.fList.Assert(t, state.FreeList...) + idx.aList.Assert(t, state.AllocList) +} + +func TestIndex(t *testing.T) { + pf, idx := newForTesting(t) + defer pf.Close() + + idx.Assert(t, IndexState{ + FreeList: []uint64{}, + AllocList: map[[2]uint64][]uint64{}, + }) + + p0 := uint64(0) + + l := (&changeListBuilder{}). + Store(1, 1, pageDataSize+1). + Build() + + idx.StageChanges(l) + + idx.Assert(t, IndexState{ + FreeList: []uint64{}, + AllocList: map[[2]uint64][]uint64{}, + }) + + // Unstage a change: free-list gets pages back. + idx.UnstageChanges(l) + + idx.Assert(t, IndexState{ + FreeList: []uint64{p0, p0 + 1}, + AllocList: map[[2]uint64][]uint64{}, + }) + + // Stage a change: free-list entries are used again. + l = (*changeListBuilder)(&l). + Clear(). + Store(1, 1, pageDataSize+1). + Store(2, 2, pageDataSize-1). + Store(3, 3, pageDataSize). + Build() + + idx.StageChanges(l) + + idx.Assert(t, IndexState{ + FreeList: []uint64{}, + AllocList: map[[2]uint64][]uint64{}, + }) + + // Apply changes: alloc-list is updated. + idx.ApplyChanges(l) + + idx.Assert(t, IndexState{ + FreeList: []uint64{}, + AllocList: map[[2]uint64][]uint64{ + {1, 1}: {p0, p0 + 1}, + {2, 2}: {p0 + 2}, + {3, 3}: {p0 + 3}, + }, + }) + + // Clear some things. + l = (*changeListBuilder)(&l). + Clear(). + Store(1, 1, pageDataSize). + Delete(2, 2). + Build() + + idx.StageChanges(l) + + idx.Assert(t, IndexState{ + FreeList: []uint64{}, + AllocList: map[[2]uint64][]uint64{ + {3, 3}: {p0 + 3}, + }, + }) + + // Ustaging will push the staged page p0+4 into the free list. + idx.UnstageChanges(l) + + idx.Assert(t, IndexState{ + FreeList: []uint64{p0 + 4}, + AllocList: map[[2]uint64][]uint64{ + {1, 1}: {p0, p0 + 1}, + {2, 2}: {p0 + 2}, + {3, 3}: {p0 + 3}, + }, + }) + + idx.StageChanges(l) + + idx.Assert(t, IndexState{ + FreeList: []uint64{}, + AllocList: map[[2]uint64][]uint64{ + {3, 3}: {p0 + 3}, + }, + }) + + idx.ApplyChanges(l) + + idx.Assert(t, IndexState{ + FreeList: []uint64{p0, p0 + 1, p0 + 2}, + AllocList: map[[2]uint64][]uint64{ + {1, 1}: {p0 + 4}, + {3, 3}: {p0 + 3}, + }, + }) + + // Duplicate updates. + l = (*changeListBuilder)(&l). + Clear(). + Store(2, 2, pageDataSize). + Store(3, 3, pageDataSize+1). + Store(3, 3, pageDataSize). 
+ Build() + + idx.StageChanges(l) + + idx.Assert(t, IndexState{ + FreeList: []uint64{p0 + 2}, + AllocList: map[[2]uint64][]uint64{ + {1, 1}: {p0 + 4}, + }, + }) +} diff --git a/mdb/pfile/iterate.go b/mdb/pfile/iterate.go new file mode 100644 index 0000000..bcf3b77 --- /dev/null +++ b/mdb/pfile/iterate.go @@ -0,0 +1,18 @@ +package pfile + +import "bytes" + +func IterateAllocated( + pf *File, + idx *Index, + each func(collectionID, itemID uint64, data []byte) error, +) error { + buf := &bytes.Buffer{} + return idx.aList.Iterate(func(collectionID, itemID uint64, pages []uint64) error { + buf.Reset() + if err := pf.readData(pages[0], buf); err != nil { + return err + } + return each(collectionID, itemID, buf.Bytes()) + }) +} diff --git a/mdb/pfile/main_test.go b/mdb/pfile/main_test.go new file mode 100644 index 0000000..a673804 --- /dev/null +++ b/mdb/pfile/main_test.go @@ -0,0 +1,64 @@ +package pfile + +import ( + "bytes" + crand "crypto/rand" + "git.crumpington.com/public/jldb/lib/wal" + "git.crumpington.com/public/jldb/mdb/change" + "path/filepath" + "testing" +) + +func newForTesting(t *testing.T) (*File, *Index) { + t.Helper() + + tmpDir := t.TempDir() + filePath := filepath.Join(tmpDir, "pagefile") + + pf, err := Open(filePath) + if err != nil { + t.Fatal(err) + } + + idx, err := NewIndex(pf) + if err != nil { + t.Fatal(err) + } + + return pf, idx +} + +func randBytes(size int) []byte { + buf := make([]byte, size) + if _, err := crand.Read(buf); err != nil { + panic(err) + } + return buf +} + +func changesToRec(changes []change.Change) wal.Record { + buf := &bytes.Buffer{} + if err := change.Write(changes, buf); err != nil { + panic(err) + } + return wal.Record{ + DataSize: int64(buf.Len()), + Reader: buf, + } +} + +func TestChangesToRec(t *testing.T) { + changes := []change.Change{ + { + CollectionID: 2, + ItemID: 3, + Store: true, + Data: []byte{2, 3, 4}, + WritePageIDs: []uint64{0, 1}, + ClearPageIDs: []uint64{2, 3}, + }, + } + rec := changesToRec(changes) + c2 := []change.Change{} + c2, _ = change.Read(c2, rec.Reader) +} diff --git a/mdb/pfile/page.go b/mdb/pfile/page.go new file mode 100644 index 0000000..bc02d4f --- /dev/null +++ b/mdb/pfile/page.go @@ -0,0 +1,70 @@ +package pfile + +import ( + "hash/crc32" + "git.crumpington.com/public/jldb/lib/errs" + "unsafe" +) + +// ---------------------------------------------------------------------------- + +const ( + pageSize = 512 + pageHeaderSize = 40 + pageDataSize = pageSize - pageHeaderSize + + pageTypeFree = 0 + pageTypeHead = 1 + pageTypeData = 2 +) + +var emptyPage = func() dataPage { + p := newDataPage() + h := p.Header() + h.CRC = p.ComputeCRC() + return p +}() + +// ---------------------------------------------------------------------------- + +type pageHeader struct { + CRC uint32 // IEEE CRC-32 checksum. + PageType uint32 // One of the PageType* constants. 
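+	// CollectionID and ItemID identify the item this page belongs to.
+	// DataSize is the total length of the item's data; NextPage is the ID of
+	// the next page in the chain, or 0 for the final page. The struct must
+	// stay exactly pageHeaderSize (40) bytes, since Header() reinterprets the
+	// page's leading bytes in place.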
+ CollectionID uint64 // + ItemID uint64 + DataSize uint64 + NextPage uint64 +} + +// ---------------------------------------------------------------------------- + +type dataPage []byte + +func newDataPage() dataPage { + p := dataPage(make([]byte, pageSize)) + return p +} + +func (p dataPage) Header() *pageHeader { + return (*pageHeader)(unsafe.Pointer(&p[0])) +} + +func (p dataPage) ComputeCRC() uint32 { + return crc32.ChecksumIEEE(p[4:]) +} + +func (p dataPage) Data() []byte { + return p[pageHeaderSize:] +} + +func (p dataPage) Write(data []byte) int { + return copy(p[pageHeaderSize:], data) +} + +func (p dataPage) Validate() error { + header := p.Header() + if header.CRC != p.ComputeCRC() { + return errs.Corrupt.WithMsg("CRC mismatch on data page.") + } + return nil +} diff --git a/mdb/pfile/page_test.go b/mdb/pfile/page_test.go new file mode 100644 index 0000000..0edd52d --- /dev/null +++ b/mdb/pfile/page_test.go @@ -0,0 +1,103 @@ +package pfile + +import ( + "bytes" + crand "crypto/rand" + "git.crumpington.com/public/jldb/lib/errs" + "math/rand" + "testing" +) + +func randomPage(t *testing.T) dataPage { + p := newDataPage() + h := p.Header() + + x := rand.Float32() + if x > 0.66 { + h.PageType = pageTypeFree + h.DataSize = 0 + } else if x < 0.33 { + h.PageType = pageTypeHead + h.DataSize = rand.Uint64() + } else { + h.PageType = pageTypeData + h.DataSize = rand.Uint64() + } + + h.CollectionID = rand.Uint64() + h.ItemID = rand.Uint64() + + dataSize := h.DataSize + if h.DataSize > pageDataSize { + dataSize = pageDataSize + } + + if _, err := crand.Read(p.Data()[:dataSize]); err != nil { + t.Fatal(err) + } + + h.CRC = p.ComputeCRC() + + return p +} + +// ---------------------------------------------------------------------------- + +func TestPageValidate(t *testing.T) { + for i := 0; i < 100; i++ { + p := randomPage(t) + + // Should be valid initially. + if err := p.Validate(); err != nil { + t.Fatal(err) + } + + for i := 0; i < pageSize; i++ { + p[i]++ + if err := p.Validate(); !errs.Corrupt.Is(err) { + t.Fatal(err) + } + p[i]-- + } + + // Should be valid initially. 
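+		// (Every mutated byte was restored above, so the original CRC matches again.)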
+		if err := p.Validate(); err != nil {
+			t.Fatal(err)
+		}
+	}
+}
+
+func TestPageEmptyIsValid(t *testing.T) {
+	if err := emptyPage.Validate(); err != nil {
+		t.Fatal(err)
+	}
+}
+
+func TestPageWrite(t *testing.T) {
+	for i := 0; i < 100; i++ {
+		page := newDataPage()
+		h := page.Header()
+		h.PageType = pageTypeData
+		h.CollectionID = rand.Uint64()
+		h.ItemID = rand.Uint64()
+		h.DataSize = uint64(1 + rand.Int63n(2*pageDataSize))
+
+		data := make([]byte, h.DataSize)
+		crand.Read(data)
+
+		n := page.Write(data)
+		h.CRC = page.ComputeCRC()
+
+		if n > pageDataSize || n < 1 {
+			t.Fatal(n)
+		}
+
+		if !bytes.Equal(data[:n], page.Data()[:n]) {
+			t.Fatal(data[:n], page.Data()[:n])
+		}
+
+		if err := page.Validate(); err != nil {
+			t.Fatal(err)
+		}
+	}
+} diff --git a/mdb/pfile/pagefile.go b/mdb/pfile/pagefile.go new file mode 100644 index 0000000..b607da3 --- /dev/null +++ b/mdb/pfile/pagefile.go @@ -0,0 +1,307 @@ +package pfile
+
+import (
+	"bufio"
+	"bytes"
+	"compress/gzip"
+	"encoding/binary"
+	"io"
+	"git.crumpington.com/public/jldb/lib/errs"
+	"git.crumpington.com/public/jldb/mdb/change"
+	"net"
+	"os"
+	"sync"
+	"time"
+)
+
+type File struct {
+	lock sync.RWMutex
+	f    *os.File
+	page dataPage
+}
+
+func Open(path string) (*File, error) {
+	f, err := os.OpenFile(path, os.O_RDWR|os.O_CREATE, 0600)
+	if err != nil {
+		return nil, errs.IO.WithErr(err)
+	}
+
+	pf := &File{f: f}
+	pf.page = newDataPage()
+
+	return pf, nil
+}
+
+func (pf *File) Close() error {
+	pf.lock.Lock()
+	defer pf.lock.Unlock()
+
+	if err := pf.f.Close(); err != nil {
+		return errs.IO.WithErr(err)
+	}
+	return nil
+}
+
+// ----------------------------------------------------------------------------
+// Writing
+// ----------------------------------------------------------------------------
+
+func (pf *File) ApplyChanges(changes []change.Change) error {
+	pf.lock.Lock()
+	defer pf.lock.Unlock()
+
+	return pf.applyChanges(changes)
+}
+
+func (pf *File) applyChanges(changes []change.Change) error {
+	for _, change := range changes {
+		if len(change.WritePageIDs) > 0 {
+			if err := pf.writeChangePages(change); err != nil {
+				return err
+			}
+		}
+
+		for _, id := range change.ClearPageIDs {
+			if err := pf.writePage(emptyPage, id); err != nil {
+				return err
+			}
+		}
+	}
+
+	if err := pf.f.Sync(); err != nil {
+		return errs.IO.WithErr(err)
+	}
+
+	return nil
+}
+
+func (pf *File) writeChangePages(change change.Change) error {
+	page := pf.page
+
+	header := page.Header()
+
+	header.PageType = pageTypeHead
+	header.CollectionID = change.CollectionID
+	header.ItemID = change.ItemID
+	header.DataSize = uint64(len(change.Data))
+
+	pageIDs := change.WritePageIDs
+	data := change.Data
+
+	// Loop on the remaining data (not change.Data), so that surplus page IDs
+	// are detected by the check below instead of being written with stale data.
+	for len(data) > 0 && len(pageIDs) > 0 {
+		pageID := pageIDs[0]
+		pageIDs = pageIDs[1:]
+
+		if len(pageIDs) > 0 {
+			header.NextPage = pageIDs[0]
+		} else {
+			header.NextPage = 0
+		}
+
+		n := page.Write(data)
+		data = data[n:]
+
+		page.Header().CRC = page.ComputeCRC()
+
+		if err := pf.writePage(page, pageID); err != nil {
+			return err
+		}
+
+		// All pages after the first have pageTypeData.
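+		// The same header struct is reused for every page in the chain:
+		// NextPage links each page to the one that follows (0 marks the last
+		// page), and the CRC is recomputed after each page's data is copied in.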
+ header.PageType = pageTypeData + } + + if len(pageIDs) > 0 { + return errs.Unexpected.WithMsg("Too many pages provided for given data.") + } + + if len(data) > 0 { + return errs.Unexpected.WithMsg("Not enough pages for given data.") + } + + return nil +} + +func (pf *File) writePage(page dataPage, id uint64) error { + if _, err := pf.f.WriteAt(page, int64(id*pageSize)); err != nil { + return errs.IO.WithErr(err) + } + return nil +} + +// ---------------------------------------------------------------------------- +// Reading +// ---------------------------------------------------------------------------- + +func (pf *File) iterate(each func(pageID uint64, page dataPage) error) error { + pf.lock.RLock() + defer pf.lock.RUnlock() + + page := pf.page + + fi, err := pf.f.Stat() + if err != nil { + return errs.IO.WithErr(err) + } + + fileSize := fi.Size() + if fileSize%pageSize != 0 { + return errs.Corrupt.WithMsg("File size isn't a multiple of page size.") + } + + maxPage := uint64(fileSize / pageSize) + if _, err := pf.f.Seek(0, io.SeekStart); err != nil { + return errs.IO.WithErr(err) + } + + r := bufio.NewReaderSize(pf.f, 1024*1024) + + for pageID := uint64(0); pageID < maxPage; pageID++ { + if _, err := r.Read(page); err != nil { + return errs.IO.WithErr(err) + } + + if err := page.Validate(); err != nil { + return err + } + if err := each(pageID, page); err != nil { + return err + } + } + + return nil +} + +func (pf *File) readData(id uint64, buf *bytes.Buffer) error { + page := pf.page + + // The head page. + if err := pf.readPage(page, id); err != nil { + return err + } + + remaining := int(page.Header().DataSize) + + for { + data := page.Data() + if len(data) > remaining { + data = data[:remaining] + } + + buf.Write(data) + remaining -= len(data) + + if page.Header().NextPage == 0 { + break + } + + if err := pf.readPage(page, page.Header().NextPage); err != nil { + return err + } + } + + if remaining != 0 { + return errs.Corrupt.WithMsg("Incorrect data size. 
%d remaining.", remaining) + } + + return nil +} + +func (pf *File) readPage(p dataPage, id uint64) error { + if _, err := pf.f.ReadAt(p, int64(id*pageSize)); err != nil { + return errs.IO.WithErr(err) + } + return p.Validate() +} + +// ---------------------------------------------------------------------------- +// Send / Recv +// ---------------------------------------------------------------------------- + +func (pf *File) Send(conn net.Conn, timeout time.Duration) error { + pf.lock.RLock() + defer pf.lock.RUnlock() + + if _, err := pf.f.Seek(0, io.SeekStart); err != nil { + return errs.IO.WithErr(err) + } + + fi, err := pf.f.Stat() + if err != nil { + return errs.IO.WithErr(err) + } + + remaining := fi.Size() + + conn.SetWriteDeadline(time.Now().Add(timeout)) + if err := binary.Write(conn, binary.LittleEndian, remaining); err != nil { + return err + } + + buf := make([]byte, 1024*1024) + w, err := gzip.NewWriterLevel(conn, 3) + if err != nil { + return errs.Unexpected.WithErr(err) + } + defer w.Close() + + for remaining > 0 { + n, err := pf.f.Read(buf) + if err != nil { + return errs.IO.WithErr(err) + } + + conn.SetWriteDeadline(time.Now().Add(timeout)) + if _, err := w.Write(buf[:n]); err != nil { + return errs.IO.WithErr(err) + } + + remaining -= int64(n) + w.Flush() + } + + return nil +} + +func Recv(conn net.Conn, filePath string, timeout time.Duration) error { + defer conn.Close() + + f, err := os.Create(filePath) + if err != nil { + return errs.IO.WithErr(err) + } + defer f.Close() + + remaining := uint64(0) + if err := binary.Read(conn, binary.LittleEndian, &remaining); err != nil { + return err + } + + r, err := gzip.NewReader(conn) + if err != nil { + return errs.Unexpected.WithErr(err) + } + defer r.Close() + + buf := make([]byte, 1024*1024) + for remaining > 0 { + + conn.SetReadDeadline(time.Now().Add(timeout)) + + n, err := io.ReadFull(r, buf) + if err != nil && n == 0 { + return errs.IO.WithErr(err) + } + remaining -= uint64(n) + + if _, err := f.Write(buf[:n]); err != nil { + return errs.IO.WithErr(err) + } + } + + if err := f.Sync(); err != nil { + return errs.IO.WithErr(err) + } + + return nil +} diff --git a/mdb/pfile/pagefile_test.go b/mdb/pfile/pagefile_test.go new file mode 100644 index 0000000..c7fd2ec --- /dev/null +++ b/mdb/pfile/pagefile_test.go @@ -0,0 +1,94 @@ +package pfile + +import ( + "bytes" + "os" + "path/filepath" + "testing" +) + +type FileState struct { + SeqNum uint64 + Data map[[2]uint64][]byte +} + +func (pf *File) Assert(t *testing.T, state pFileState) { + t.Helper() + + pf.lock.RLock() + defer pf.lock.RUnlock() + + idx, err := NewIndex(pf) + if err != nil { + t.Fatal(err) + } + + data := map[[2]uint64][]byte{} + err = IterateAllocated(pf, idx, func(cID, iID uint64, fileData []byte) error { + data[[2]uint64{cID, iID}] = bytes.Clone(fileData) + return nil + }) + + if err != nil { + t.Fatal(err) + } + + if len(data) != len(state.Data) { + t.Fatalf("Expected %d items but got %d.", len(state.Data), len(data)) + } + + for key, expected := range state.Data { + val, ok := data[key] + if !ok { + t.Fatalf("No data found for key %v.", key) + } + if !bytes.Equal(val, expected) { + t.Fatalf("Incorrect data for key %v.", key) + } + } +} + +func TestFileStateUpdateRandom(t *testing.T) { + t.Parallel() + tmpDir := t.TempDir() + walDir := filepath.Join(tmpDir, "wal") + pageFilePath := filepath.Join(tmpDir, "pagefile") + + if err := os.MkdirAll(walDir, 0700); err != nil { + t.Fatal(err) + } + + pf, err := Open(pageFilePath) + if err != nil { + t.Fatal(err) + } + + 
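+	// NewIndex rebuilds the free list and alloc list by scanning every page
+	// in the file; for a freshly created file both start out empty.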
idx, err := NewIndex(pf) + if err != nil { + t.Fatal(err) + } + + state := pFileState{ + Data: map[[2]uint64][]byte{}, + } + + for i := uint64(1); i < 256; i++ { + changes := randomChangeList() + idx.StageChanges(changes) + + if err := pf.ApplyChanges(changes); err != nil { + t.Fatal(err) + } + idx.ApplyChanges(changes) + + for _, ch := range changes { + if !ch.Store { + delete(state.Data, [2]uint64{ch.CollectionID, ch.ItemID}) + } else { + state.Data[[2]uint64{ch.CollectionID, ch.ItemID}] = ch.Data + } + } + + pf.Assert(t, state) + } +} diff --git a/mdb/pfile/record_test.go b/mdb/pfile/record_test.go new file mode 100644 index 0000000..24d0ec5 --- /dev/null +++ b/mdb/pfile/record_test.go @@ -0,0 +1,57 @@ +package pfile + +import ( + "bytes" + "git.crumpington.com/public/jldb/lib/wal" + "git.crumpington.com/public/jldb/mdb/change" +) + +// ---------------------------------------------------------------------------- + +type pFileState struct { + Data map[[2]uint64][]byte +} + +// ---------------------------------------------------------------------------- + +type recBuilder struct { + changes []change.Change + rec wal.Record +} + +func NewRecBuilder(seqNum, timestamp int64) *recBuilder { + return &recBuilder{ + rec: wal.Record{ + SeqNum: seqNum, + TimestampMS: timestamp, + }, + changes: []change.Change{}, + } +} + +func (b *recBuilder) Store(cID, iID uint64, data string) *recBuilder { + b.changes = append(b.changes, change.Change{ + CollectionID: cID, + ItemID: iID, + Store: true, + Data: []byte(data), + }) + return b +} + +func (b *recBuilder) Delete(cID, iID uint64) *recBuilder { + b.changes = append(b.changes, change.Change{ + CollectionID: cID, + ItemID: iID, + Store: false, + }) + return b +} + +func (b *recBuilder) Record() wal.Record { + buf := &bytes.Buffer{} + change.Write(b.changes, buf) + b.rec.DataSize = int64(buf.Len()) + b.rec.Reader = buf + return b.rec +} diff --git a/mdb/pfile/sendrecv_test.go b/mdb/pfile/sendrecv_test.go new file mode 100644 index 0000000..51d0fec --- /dev/null +++ b/mdb/pfile/sendrecv_test.go @@ -0,0 +1,62 @@ +package pfile + +/* +func TestSendRecv(t *testing.T) { + tmpDir := t.TempDir() + filePath1 := filepath.Join(tmpDir, "1") + filePath2 := filepath.Join(tmpDir, "2") + defer os.RemoveAll(tmpDir) + + f1, err := os.Create(filePath1) + if err != nil { + t.Fatal(err) + } + + size := rand.Int63n(1024 * 1024 * 128) + buf := make([]byte, size) + crand.Read(buf) + if _, err := f1.Write(buf); err != nil { + t.Fatal(err) + } + + if err := f1.Close(); err != nil { + t.Fatal(err) + } + + c1, c2 := net.Pipe() + errChan := make(chan error) + + go func() { + err := Send(filePath1, c1, time.Second) + if err != nil { + log.Printf("Send error: %v", err) + } + errChan <- err + }() + + go func() { + err := Recv(filePath2, c2, time.Second) + if err != nil { + log.Printf("Recv error: %v", err) + } + errChan <- err + }() + + if err := <-errChan; err != nil { + t.Fatal(err) + } + + if err := <-errChan; err != nil { + t.Fatal(err) + } + + buf2, err := os.ReadFile(filePath2) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(buf, buf2) { + t.Fatal("Not equal.") + } +} +*/ diff --git a/mdb/snapshot.go b/mdb/snapshot.go new file mode 100644 index 0000000..8476f84 --- /dev/null +++ b/mdb/snapshot.go @@ -0,0 +1,106 @@ +package mdb + +import ( + "bytes" + "encoding/json" + "git.crumpington.com/public/jldb/mdb/change" + "sync/atomic" +) + +type Snapshot struct { + parent atomic.Pointer[Snapshot] + // The Snapshot's version is incremented each time it's cloned. 
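+	// clone always sets version to the parent's version plus one, so a
+	// writable snapshot's version is strictly greater than its parent's.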
+ version uint64 + + // The snapshot's seqNum is set when it becomes active (read-only).] + seqNum int64 + timestampMS int64 + + collections map[uint64]any // Map from collection ID to *collectionState[T]. + changes []change.Change +} + +func newSnapshot() *Snapshot { + return &Snapshot{ + collections: map[uint64]any{}, + changes: []change.Change{}, + } +} + +func (s *Snapshot) addCollection(id uint64, c any) { + s.collections[id] = c +} + +func (s *Snapshot) writable() bool { + return s.parent.Load() != nil +} + +func (s *Snapshot) setReadOnly() { + s.parent.Store(nil) + s.changes = s.changes[:0] +} + +func (s *Snapshot) store(cID, iID uint64, item any) { + change := s.appendChange(cID, iID) + change.Store = true + + buf := bytes.NewBuffer(change.Data[:0]) + if err := json.NewEncoder(buf).Encode(item); err != nil { + panic(err) + } + change.Data = buf.Bytes() +} + +func (s *Snapshot) delete(cID, iID uint64) { + change := s.appendChange(cID, iID) + change.Store = false +} + +func (s *Snapshot) appendChange(cID, iID uint64) *change.Change { + if len(s.changes) == cap(s.changes) { + s.changes = append(s.changes, change.Change{}) + } else { + s.changes = s.changes[:len(s.changes)+1] + } + + change := &s.changes[len(s.changes)-1] + change.CollectionID = cID + change.ItemID = iID + change.Store = false + change.ClearPageIDs = change.ClearPageIDs[:0] + change.WritePageIDs = change.WritePageIDs[:0] + change.Data = change.Data[:0] + return change +} + +func (s *Snapshot) begin() *Snapshot { + c := s.clone() + c.changes = c.changes[:0] + return c +} + +func (s *Snapshot) clone() *Snapshot { + collections := make(map[uint64]any, len(s.collections)) + for k, v := range s.collections { + collections[k] = v + } + + c := &Snapshot{ + version: s.version + 1, + collections: collections, + changes: s.changes[:], + } + c.parent.Store(s) + return c +} + +func (s *Snapshot) rollback() *Snapshot { + parent := s.parent.Load() + if parent == nil { + return nil + } + + // Don't throw away allocated changes. 
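+	// Handing the child's (possibly grown) backing array back to the parent
+	// lets appendChange reuse the already-allocated Change structs and their
+	// Data/page-ID slices on the next transaction.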
+ parent.changes = s.changes[:len(parent.changes)] + return parent +} diff --git a/mdb/snapshotisolation_test.go b/mdb/snapshotisolation_test.go new file mode 100644 index 0000000..622d1ce --- /dev/null +++ b/mdb/snapshotisolation_test.go @@ -0,0 +1,41 @@ +package mdb + +import ( + "log" + "os" + "sync/atomic" + "testing" + "time" +) + +func TestDBIsolation(t *testing.T) { + if testing.Short() { + t.Skip("Skipping test in short mode.") + } + rootDir := t.TempDir() + defer os.RemoveAll(rootDir) + + db, err := OpenDataDB(rootDir) + if err != nil { + t.Fatal(err) + } + + done := &atomic.Bool{} + go func() { + defer done.Store(true) + db.ModifyFor(8 * time.Second) + }() + + count := 0 + for !done.Load() { + count++ + tx := db.Snapshot() + computed := db.ComputeCRC(tx) + stored := db.ReadCRC(tx) + if computed != stored { + t.Fatal(stored, computed) + } + } + + log.Printf("Read: %d", count) +} diff --git a/mdb/testdb_test.go b/mdb/testdb_test.go new file mode 100644 index 0000000..41156ec --- /dev/null +++ b/mdb/testdb_test.go @@ -0,0 +1,151 @@ +package mdb + +import ( + "crypto/rand" + "errors" + "hash/crc32" + "log" + mrand "math/rand" + "runtime" + "slices" + "sync" + "sync/atomic" + "time" +) + +type DataItem struct { + ID uint64 + Data []byte +} + +type DataCollection struct { + *Collection[DataItem] +} + +func NewDataCollection(db *Database) DataCollection { + return DataCollection{ + Collection: NewCollection(db, "Data", &CollectionConfig[DataItem]{ + Copy: func(in *DataItem) *DataItem { + out := &DataItem{} + *out = *in + out.Data = slices.Clone(in.Data) + return out + }, + }), + } +} + +type CRCItem struct { + ID uint64 // Always 1 + CRC32 uint32 +} + +type CRCCollection struct { + *Collection[CRCItem] +} + +func NewCRCCollection(db *Database) CRCCollection { + return CRCCollection{ + Collection: NewCollection[CRCItem](db, "CRC", nil), + } +} + +type DataDB struct { + *Database + Datas DataCollection + CRCs CRCCollection +} + +func OpenDataDB(rootDir string) (DataDB, error) { + db := New(Config{ + RootDir: rootDir, + Primary: true, + }) + + testdb := DataDB{ + Database: db, + Datas: NewDataCollection(db), + CRCs: NewCRCCollection(db), + } + + return testdb, testdb.Open() +} + +func (db DataDB) ModifyFor(dt time.Duration) { + wg := sync.WaitGroup{} + var count int64 + for i := 0; i < runtime.NumCPU(); i++ { + wg.Add(1) + go func() { + defer wg.Done() + + t0 := time.Now() + for time.Since(t0) < dt { + atomic.AddInt64(&count, 1) + db.modifyOnce() + } + }() + } + + wg.Wait() + log.Printf("Modified: %d", count) +} + +func (db DataDB) modifyOnce() { + isErr := mrand.Float64() < 0.1 + err := db.Update(func(tx *Snapshot) error { + h := crc32.NewIEEE() + for dataID := uint64(1); dataID < 10; dataID++ { + d := DataItem{ + ID: dataID, + Data: make([]byte, 256), + } + + rand.Read(d.Data) + h.Write(d.Data) + if err := db.Datas.Upsert(tx, &d); err != nil { + return err + } + } + + crc := CRCItem{ + ID: 1, + } + + if !isErr { + crc.CRC32 = h.Sum32() + return db.CRCs.Upsert(tx, &crc) + } + + crc.CRC32 = 1 + if err := db.CRCs.Upsert(tx, &crc); err != nil { + return err + } + + return errors.New("ERROR") + }) + + if isErr != (err != nil) { + panic(err) + } +} + +func (db DataDB) ComputeCRC(tx *Snapshot) uint32 { + h := crc32.NewIEEE() + for dataID := uint64(1); dataID < 10; dataID++ { + d, ok := db.Datas.ByID.Get(tx, &DataItem{ID: dataID}) + if !ok { + continue + } + h.Write(d.Data) + } + return h.Sum32() +} + +func (db DataDB) ReadCRC(tx *Snapshot) uint32 { + r, ok := db.CRCs.ByID.Get(tx, &CRCItem{ID: 
1}) + if !ok { + return 0 + } + return r.CRC32 +} diff --git a/mdb/testing/crashconsistency/main.go b/mdb/testing/crashconsistency/main.go new file mode 100644 index 0000000..42de62c --- /dev/null +++ b/mdb/testing/crashconsistency/main.go @@ -0,0 +1,162 @@ +package main + +import ( + "crypto/rand" + "errors" + "hash/crc32" + "git.crumpington.com/public/jldb/mdb" + "log" + mrand "math/rand" + "os" + "runtime" + "slices" + "sync" + "sync/atomic" + "time" +) + +type DataItem struct { + ID uint64 + Data []byte +} + +type DataCollection struct { + *mdb.Collection[DataItem] +} + +func NewDataCollection(db *mdb.Database) DataCollection { + return DataCollection{ + Collection: mdb.NewCollection(db, "Data", &mdb.CollectionConfig[DataItem]{ + Copy: func(in *DataItem) *DataItem { + out := new(DataItem) + *out = *in + out.Data = slices.Clone(in.Data) + return out + }, + }), + } +} + +type CRCItem struct { + ID uint64 // Always 1 + CRC32 uint32 +} + +type CRCCollection struct { + *mdb.Collection[CRCItem] +} + +func NewCRCCollection(db *mdb.Database) CRCCollection { + return CRCCollection{ + Collection: mdb.NewCollection[CRCItem](db, "CRC", nil), + } +} + +type DataDB struct { + *mdb.Database + Datas DataCollection + CRCs CRCCollection +} + +func OpenDataDB(rootDir string) (DataDB, error) { + db := mdb.New(mdb.Config{RootDir: rootDir, Primary: true}) + testdb := DataDB{ + Database: db, + Datas: NewDataCollection(db), + CRCs: NewCRCCollection(db), + } + + if err := db.Open(); err != nil { + return testdb, err + } + + return testdb, nil +} + +func (db DataDB) ModifyFor(dt time.Duration) { + wg := sync.WaitGroup{} + var count int64 + for i := 0; i < runtime.NumCPU(); i++ { + wg.Add(1) + go func() { + defer wg.Done() + + t0 := time.Now() + for time.Since(t0) < dt { + atomic.AddInt64(&count, 1) + db.modifyOnce() + } + }() + } + + wg.Wait() + log.Printf("Modified: %d", count) +} + +func (db DataDB) modifyOnce() { + isErr := mrand.Float64() < 0.1 + err := db.Update(func(tx *mdb.Snapshot) error { + h := crc32.NewIEEE() + for dataID := uint64(1); dataID < 10; dataID++ { + d := DataItem{ + ID: dataID, + Data: make([]byte, 256), + } + + rand.Read(d.Data) + h.Write(d.Data) + if err := db.Datas.Upsert(tx, &d); err != nil { + return err + } + } + + crc := CRCItem{ + ID: 1, + } + + if !isErr { + crc.CRC32 = h.Sum32() + return db.CRCs.Upsert(tx, &crc) + } + + crc.CRC32 = 1 + if err := db.CRCs.Upsert(tx, &crc); err != nil { + return err + } + + return errors.New("ERROR") + }) + + if isErr != (err != nil) { + panic(err) + } +} + +func (db DataDB) ComputeCRC(tx *mdb.Snapshot) uint32 { + h := crc32.NewIEEE() + for dataID := uint64(1); dataID < 10; dataID++ { + d, ok := db.Datas.ByID.Get(tx, &DataItem{ID: dataID}) + if !ok { + continue + } + h.Write(d.Data) + } + return h.Sum32() +} + +func (db DataDB) ReadCRC(tx *mdb.Snapshot) uint32 { + r, ok := db.CRCs.ByID.Get(tx, &CRCItem{ID: 1}) + if !ok { + return 0 + } + return r.CRC32 +} + +func main() { + db, err := OpenDataDB(os.Args[1]) + if err != nil { + log.Fatal(err) + } + + db.ModifyFor(time.Minute) +} diff --git a/mdb/txaggregator.go b/mdb/txaggregator.go new file mode 100644 index 0000000..e073e1f --- /dev/null +++ b/mdb/txaggregator.go @@ -0,0 +1,92 @@ +package mdb + +/* +type txAggregator struct { + Stop chan struct{} + Done *sync.WaitGroup + ModChan chan txMod + W *cswal.Writer + Index *pagefile.Index + Snapshot *atomic.Pointer[Snapshot] +} + +func (p txAggregator) Run() { + defer p.Done.Done() + defer p.W.Close() + + var ( + tx *Snapshot + mod txMod + rec 
cswal.Record + err error + toNotify = make([]chan error, 0, 1024) + ) + +READ_FIRST: + + toNotify = toNotify[:0] + + select { + case mod = <-p.ModChan: + goto BEGIN + case <-p.Stop: + goto END + } + +BEGIN: + + tx = p.Snapshot.Load().begin() + goto APPLY_MOD + +CLONE: + + tx = tx.clone() + goto APPLY_MOD + +APPLY_MOD: + + if err = mod.Update(tx); err != nil { + mod.Resp <- err + goto ROLLBACK + } + + toNotify = append(toNotify, mod.Resp) + goto NEXT + +ROLLBACK: + + if len(toNotify) == 0 { + goto READ_FIRST + } + + tx = tx.rollback() + goto NEXT + +NEXT: + + select { + case mod = <-p.ModChan: + goto CLONE + default: + goto WRITE + } + +WRITE: + + rec, err = writeChangesToWAL(tx.changes, p.Index, p.W) + if err == nil { + tx.seqNum = rec.SeqNum + tx.updatedAt = rec.CreatedAt + tx.setReadOnly() + p.Snapshot.Store(tx) + } + + for i := range toNotify { + toNotify[i] <- err + } + + goto READ_FIRST + +END: +} +*/ diff --git a/mdb/types.go b/mdb/types.go new file mode 100644 index 0000000..452f9ae --- /dev/null +++ b/mdb/types.go @@ -0,0 +1,8 @@ +package mdb + +type collection interface { + Name() string + insertItem(tx *Snapshot, itemID uint64, data []byte) error + upsertItem(tx *Snapshot, itemID uint64, data []byte) error + deleteItem(tx *Snapshot, itemID uint64) error +} diff --git a/mdb/walfollower.go b/mdb/walfollower.go new file mode 100644 index 0000000..f0c642e --- /dev/null +++ b/mdb/walfollower.go @@ -0,0 +1,35 @@ +package mdb + +/* +type walFollower struct { + Stop chan struct{} + Done *sync.WaitGroup + W *cswal.Writer + Client *Client +} + +func (f *walFollower) Run() { + go func() { + <-f.Stop + f.Client.Close() + }() + + defer f.Done.Done() + + for { + f.runOnce() + select { + case <-f.Stop: + return + default: + time.Sleep(time.Second) + } + } +} + +func (f *walFollower) runOnce() { + if err := f.Client.StreamWAL(f.W); err != nil { + log.Printf("[WAL-FOLLOWER] Recv failed: %s", err) + } +} +*/
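
For reference, a minimal sketch of how the pfile pieces fit together, mirroring TestFileStateUpdateRandom above. The import paths and the example change list are assumptions for illustration; error handling is omitted.

package main

import (
	"git.crumpington.com/public/jldb/mdb/change"
	"git.crumpington.com/public/jldb/mdb/pfile"
)

func main() {
	// Open (or create) the page file and rebuild its index by scanning pages.
	pf, _ := pfile.Open("/tmp/pagefile")
	defer pf.Close()

	idx, _ := pfile.NewIndex(pf)

	// Describe a write: store one item's bytes under (collectionID, itemID).
	changes := []change.Change{{
		CollectionID: 1,
		ItemID:       1,
		Store:        true,
		Data:         []byte("hello"),
	}}

	// Staging assigns free pages to WritePageIDs and collects pages to clear.
	idx.StageChanges(changes)

	// Persist the pages, then commit the assignments in the index.
	if err := pf.ApplyChanges(changes); err != nil {
		idx.UnstageChanges(changes) // return the staged page IDs to the free list
		return
	}
	idx.ApplyChanges(changes)
}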