wip
This commit is contained in:
parent
d0587cc585
commit
c5419d662e
2
flock/README.md
Normal file
2
flock/README.md
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# flock
|
||||||
|
|
69
flock/flock.go
Normal file
69
flock/flock.go
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
// The flock package provides a file-system mediated locking mechanism on linux
|
||||||
|
// using the `flock` system call.
|
||||||
|
|
||||||
|
package flock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Lock gets an exclusive lock on the file at the given path. If the file
|
||||||
|
// doesn't exist, it's created.
|
||||||
|
func Lock(path string) (*os.File, error) {
|
||||||
|
return lock(path, syscall.LOCK_EX)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TryLock will return a nil file if the file is already locked.
|
||||||
|
func TryLock(path string) (*os.File, error) {
|
||||||
|
return lock(path, syscall.LOCK_EX|syscall.LOCK_NB)
|
||||||
|
}
|
||||||
|
|
||||||
|
func LockFile(f *os.File) error {
|
||||||
|
_, err := lockFile(f, syscall.LOCK_EX)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns true if the lock was successfully acquired.
|
||||||
|
func TryLockFile(f *os.File) (bool, error) {
|
||||||
|
return lockFile(f, syscall.LOCK_EX|syscall.LOCK_NB)
|
||||||
|
}
|
||||||
|
|
||||||
|
func lockFile(f *os.File, flags int) (bool, error) {
|
||||||
|
|
||||||
|
if err := flock(int(f.Fd()), flags); err != nil {
|
||||||
|
if flags&syscall.LOCK_NB != 0 && errors.Is(err, syscall.EAGAIN) {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func flock(fd int, how int) error {
|
||||||
|
_, _, e1 := syscall.Syscall(syscall.SYS_FLOCK, uintptr(fd), uintptr(how), 0)
|
||||||
|
if e1 != 0 {
|
||||||
|
return syscall.Errno(e1)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func lock(path string, flags int) (*os.File, error) {
|
||||||
|
perm := os.O_CREATE | os.O_RDWR
|
||||||
|
f, err := os.OpenFile(path, perm, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
ok, err := lockFile(f, flags)
|
||||||
|
if err != nil || !ok {
|
||||||
|
f.Close()
|
||||||
|
f = nil
|
||||||
|
}
|
||||||
|
return f, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unlock releases the lock acquired via the Lock function.
|
||||||
|
func Unlock(f *os.File) error {
|
||||||
|
return f.Close()
|
||||||
|
}
|
66
flock/flock_test.go
Normal file
66
flock/flock_test.go
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
package flock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_Lock_basic(t *testing.T) {
|
||||||
|
ch := make(chan int, 1)
|
||||||
|
f, err := Lock("/tmp/fsutil-test-lock")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
ch <- 10
|
||||||
|
Unlock(f)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case x := <-ch:
|
||||||
|
t.Fatal(x)
|
||||||
|
default:
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
f2, _ := Lock("/tmp/fsutil-test-lock")
|
||||||
|
defer Unlock(f2)
|
||||||
|
select {
|
||||||
|
case i := <-ch:
|
||||||
|
if i != 10 {
|
||||||
|
t.Fatal(i)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
t.Fatal("No value available.")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_Lock_badPath(t *testing.T) {
|
||||||
|
_, err := Lock("./dne/file.lock")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTryLock(t *testing.T) {
|
||||||
|
lockPath := "/tmp/fsutil-test-lock"
|
||||||
|
f, err := TryLock(lockPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%#v", err)
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f2, err := TryLock(lockPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if f2 != nil {
|
||||||
|
t.Fatal(f2)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := Unlock(f); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
3
flock/go.mod
Normal file
3
flock/go.mod
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
module git.crumpington.com/lib/flock
|
||||||
|
|
||||||
|
go 1.23.0
|
0
flock/go.sum
Normal file
0
flock/go.sum
Normal file
2
httpconn/README.md
Normal file
2
httpconn/README.md
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# httpconn
|
||||||
|
|
104
httpconn/client.go
Normal file
104
httpconn/client.go
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
package httpconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrUnknownScheme = errors.New("uknown scheme")
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dialer struct {
|
||||||
|
timeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDialer() *Dialer {
|
||||||
|
return &Dialer{timeout: 10 * time.Second}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dialer) SetTimeout(timeout time.Duration) {
|
||||||
|
d.timeout = timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dialer) Dial(rawURL string) (net.Conn, error) {
|
||||||
|
u, err := url.Parse(rawURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch u.Scheme {
|
||||||
|
case "https":
|
||||||
|
return d.DialHTTPS(u.Host+":443", u.Path)
|
||||||
|
case "http":
|
||||||
|
return d.DialHTTP(u.Host, u.Path)
|
||||||
|
default:
|
||||||
|
return nil, ErrUnknownScheme
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dialer) DialHTTPS(host, path string) (net.Conn, error) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), d.timeout)
|
||||||
|
dd := tls.Dialer{}
|
||||||
|
conn, err := dd.DialContext(ctx, "tcp", host)
|
||||||
|
cancel()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d.finishDialing(conn, host, path)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dialer) DialHTTP(host, path string) (net.Conn, error) {
|
||||||
|
conn, err := net.DialTimeout("tcp", host, d.timeout)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return d.finishDialing(conn, host, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Dialer) finishDialing(conn net.Conn, host, path string) (net.Conn, error) {
|
||||||
|
conn.SetDeadline(time.Now().Add(d.timeout))
|
||||||
|
|
||||||
|
if _, err := io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if _, err := io.WriteString(conn, "Host: "+host+"\n\n"); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Require successful HTTP response before using the conn.
|
||||||
|
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Status != "200 OK" {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.SetDeadline(time.Time{})
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Dial(rawURL string) (net.Conn, error) {
|
||||||
|
return NewDialer().Dial(rawURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialHTTPS(host, path string) (net.Conn, error) {
|
||||||
|
return NewDialer().DialHTTPS(host, path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func DialHTTP(host, path string) (net.Conn, error) {
|
||||||
|
return NewDialer().DialHTTP(host, path)
|
||||||
|
}
|
42
httpconn/conn_test.go
Normal file
42
httpconn/conn_test.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package httpconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"golang.org/x/net/nettest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNetTest_TestConn(t *testing.T) {
|
||||||
|
nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) {
|
||||||
|
|
||||||
|
connCh := make(chan net.Conn, 1)
|
||||||
|
doneCh := make(chan bool)
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/connect", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
conn, err := Accept(w, r)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
connCh <- conn
|
||||||
|
<-doneCh
|
||||||
|
})
|
||||||
|
|
||||||
|
srv := httptest.NewServer(mux)
|
||||||
|
|
||||||
|
c1, err = DialHTTP(strings.TrimPrefix(srv.URL, "http://"), "/connect")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
c2 = <-connCh
|
||||||
|
|
||||||
|
return c1, c2, func() {
|
||||||
|
doneCh <- true
|
||||||
|
srv.Close()
|
||||||
|
}, nil
|
||||||
|
})
|
||||||
|
}
|
5
httpconn/go.mod
Normal file
5
httpconn/go.mod
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
module git.crumpington.com/lib/httpconn
|
||||||
|
|
||||||
|
go 1.23.2
|
||||||
|
|
||||||
|
require golang.org/x/net v0.30.0
|
2
httpconn/go.sum
Normal file
2
httpconn/go.sum
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
||||||
|
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
32
httpconn/server.go
Normal file
32
httpconn/server.go
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
package httpconn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Accept(w http.ResponseWriter, r *http.Request) (net.Conn, error) {
|
||||||
|
if r.Method != "CONNECT" {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||||
|
io.WriteString(w, "405 must CONNECT\n")
|
||||||
|
return nil, http.ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
hj, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
return nil, http.ErrNotSupported
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, _, err := hj.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, _ = io.WriteString(conn, "HTTP/1.0 200 OK\n\n")
|
||||||
|
conn.SetDeadline(time.Time{})
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
2
idgen/README.md
Normal file
2
idgen/README.md
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# idgen
|
||||||
|
|
3
idgen/go.mod
Normal file
3
idgen/go.mod
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
module git.crumpington.com/lib/idgen
|
||||||
|
|
||||||
|
go 1.23.2
|
53
idgen/idgen.go
Normal file
53
idgen/idgen.go
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
package idgen
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base32"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Creates a new, random token.
|
||||||
|
func NewToken() string {
|
||||||
|
buf := make([]byte, 20)
|
||||||
|
_, err := rand.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return base32.StdEncoding.EncodeToString(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
lock sync.Mutex
|
||||||
|
ts int64 = time.Now().Unix()
|
||||||
|
counter int64 = 1
|
||||||
|
counterMax int64 = 1 << 20
|
||||||
|
)
|
||||||
|
|
||||||
|
// NextID can generate ~1M ints per second for a given nodeID.
|
||||||
|
//
|
||||||
|
// nodeID must be less than 64.
|
||||||
|
func NextID(nodeID int64) int64 {
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
|
||||||
|
tt := time.Now().Unix()
|
||||||
|
if tt > ts {
|
||||||
|
ts = tt
|
||||||
|
counter = 1
|
||||||
|
} else {
|
||||||
|
counter++
|
||||||
|
if counter == counterMax {
|
||||||
|
panic("Too many IDs.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (ts << 26) + (nodeID << 20) + counter
|
||||||
|
}
|
||||||
|
|
||||||
|
func SplitID(id int64) (unixTime, nodeID, counter int64) {
|
||||||
|
counter = id & (0x00000000000FFFFF)
|
||||||
|
nodeID = (id >> 20) & (0x000000000000003F)
|
||||||
|
unixTime = id >> 26
|
||||||
|
return
|
||||||
|
}
|
19
idgen/idgen_test.go
Normal file
19
idgen/idgen_test.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package idgen
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkNext(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
NextID(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNextID(t *testing.T) {
|
||||||
|
id := NextID(32)
|
||||||
|
|
||||||
|
a, b, c := SplitID(id)
|
||||||
|
log.Print(a, b, c)
|
||||||
|
}
|
2
keyedmutex/README.md
Normal file
2
keyedmutex/README.md
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# keyedmutex
|
||||||
|
|
3
keyedmutex/go.mod
Normal file
3
keyedmutex/go.mod
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
module git.crumpington.com/lib/keyedmutex
|
||||||
|
|
||||||
|
go 1.23.2
|
67
keyedmutex/keyedmutex.go
Normal file
67
keyedmutex/keyedmutex.go
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
package keyedmutex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/list"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeyedMutex[K comparable] struct {
|
||||||
|
mu *sync.Mutex
|
||||||
|
waitList map[K]*list.List
|
||||||
|
}
|
||||||
|
|
||||||
|
func New[K comparable]() KeyedMutex[K] {
|
||||||
|
return KeyedMutex[K]{
|
||||||
|
mu: new(sync.Mutex),
|
||||||
|
waitList: map[K]*list.List{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m KeyedMutex[K]) Lock(key K) {
|
||||||
|
if ch := m.lock(key); ch != nil {
|
||||||
|
<-ch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m KeyedMutex[K]) lock(key K) chan struct{} {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if waitList, ok := m.waitList[key]; ok {
|
||||||
|
ch := make(chan struct{})
|
||||||
|
waitList.PushBack(ch)
|
||||||
|
return ch
|
||||||
|
}
|
||||||
|
|
||||||
|
m.waitList[key] = list.New()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m KeyedMutex[K]) TryLock(key K) bool {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
if _, ok := m.waitList[key]; ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.waitList[key] = list.New()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m KeyedMutex[K]) Unlock(key K) {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
waitList, ok := m.waitList[key]
|
||||||
|
if !ok {
|
||||||
|
panic("unlock of unlocked mutex")
|
||||||
|
}
|
||||||
|
|
||||||
|
if waitList.Len() == 0 {
|
||||||
|
delete(m.waitList, key)
|
||||||
|
} else {
|
||||||
|
ch := waitList.Remove(waitList.Front()).(chan struct{})
|
||||||
|
ch <- struct{}{}
|
||||||
|
}
|
||||||
|
}
|
123
keyedmutex/keyedmutex_test.go
Normal file
123
keyedmutex/keyedmutex_test.go
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
package keyedmutex
|
||||||
|
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestKeyedMutex(t *testing.T) {
|
||||||
|
checkState := func(t *testing.T, m KeyedMutex[string], keys ...string) {
|
||||||
|
if len(m.waitList) != len(keys) {
|
||||||
|
t.Fatal(m.waitList, keys)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, key := range keys {
|
||||||
|
if _, ok := m.waitList[key]; !ok {
|
||||||
|
t.Fatal(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m := New[string]()
|
||||||
|
checkState(t, m)
|
||||||
|
|
||||||
|
m.Lock("a")
|
||||||
|
checkState(t, m, "a")
|
||||||
|
m.Lock("b")
|
||||||
|
checkState(t, m, "a", "b")
|
||||||
|
m.Lock("c")
|
||||||
|
checkState(t, m, "a", "b", "c")
|
||||||
|
|
||||||
|
if m.TryLock("a") {
|
||||||
|
t.Fatal("a")
|
||||||
|
}
|
||||||
|
if m.TryLock("b") {
|
||||||
|
t.Fatal("b")
|
||||||
|
}
|
||||||
|
if m.TryLock("c") {
|
||||||
|
t.Fatal("c")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.TryLock("d") {
|
||||||
|
t.Fatal("d")
|
||||||
|
}
|
||||||
|
|
||||||
|
checkState(t, m, "a", "b", "c", "d")
|
||||||
|
|
||||||
|
if !m.TryLock("e") {
|
||||||
|
t.Fatal("e")
|
||||||
|
}
|
||||||
|
checkState(t, m, "a", "b", "c", "d", "e")
|
||||||
|
|
||||||
|
m.Unlock("c")
|
||||||
|
checkState(t, m, "a", "b", "d", "e")
|
||||||
|
m.Unlock("a")
|
||||||
|
checkState(t, m, "b", "d", "e")
|
||||||
|
m.Unlock("e")
|
||||||
|
checkState(t, m, "b", "d")
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
for i := 0; i < 8; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
m.Lock("b")
|
||||||
|
m.Unlock("b")
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
m.Unlock("b")
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
checkState(t, m, "d")
|
||||||
|
|
||||||
|
m.Unlock("d")
|
||||||
|
checkState(t, m)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeyedMutex_unlockUnlocked(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r == nil {
|
||||||
|
t.Fatal(r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
m := New[string]()
|
||||||
|
m.Unlock("aldkfj")
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUncontendedMutex(b *testing.B) {
|
||||||
|
m := New[string]()
|
||||||
|
key := "xyz"
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
m.Lock(key)
|
||||||
|
m.Unlock(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkContendedMutex(b *testing.B) {
|
||||||
|
m := New[string]()
|
||||||
|
key := "xyz"
|
||||||
|
|
||||||
|
m.Lock(key)
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
m.Lock(key)
|
||||||
|
m.Unlock(key)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
m.Unlock(key)
|
||||||
|
wg.Wait()
|
||||||
|
}
|
2
kvmemcache/README.md
Normal file
2
kvmemcache/README.md
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# kvmemcache
|
||||||
|
|
134
kvmemcache/cache.go
Normal file
134
kvmemcache/cache.go
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
package kvmemcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/list"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"git.crumpington.com/lib/keyedmutex"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Cache[K comparable, V any] struct {
|
||||||
|
updateLock keyedmutex.KeyedMutex[K]
|
||||||
|
src func(K) (V, error)
|
||||||
|
ttl time.Duration
|
||||||
|
maxSize int
|
||||||
|
|
||||||
|
// Lock protects variables below.
|
||||||
|
lock sync.Mutex
|
||||||
|
cache map[K]*list.Element
|
||||||
|
ll *list.List
|
||||||
|
stats Stats
|
||||||
|
}
|
||||||
|
|
||||||
|
type lruItem[K comparable, V any] struct {
|
||||||
|
key K
|
||||||
|
createdAt time.Time
|
||||||
|
value V
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type Config[K comparable, V any] struct {
|
||||||
|
MaxSize int
|
||||||
|
TTL time.Duration // Zero to ignore.
|
||||||
|
Src func(K) (V, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func New[K comparable, V any](conf Config[K, V]) *Cache[K, V] {
|
||||||
|
return &Cache[K, V]{
|
||||||
|
updateLock: keyedmutex.New[K](),
|
||||||
|
src: conf.Src,
|
||||||
|
ttl: conf.TTL,
|
||||||
|
maxSize: conf.MaxSize,
|
||||||
|
lock: sync.Mutex{},
|
||||||
|
cache: make(map[K]*list.Element, conf.MaxSize+1),
|
||||||
|
ll: list.New(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache[K, V]) Get(key K) (V, error) {
|
||||||
|
ok, val, err := c.get(key)
|
||||||
|
if ok {
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c.load(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache[K, V]) Evict(key K) {
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
c.evict(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache[K, V]) Stats() Stats {
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
return c.stats
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache[K, V]) put(key K, value V, err error) {
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
|
||||||
|
c.stats.Misses++
|
||||||
|
|
||||||
|
c.cache[key] = c.ll.PushFront(lruItem[K, V]{
|
||||||
|
key: key,
|
||||||
|
createdAt: time.Now(),
|
||||||
|
value: value,
|
||||||
|
err: err,
|
||||||
|
})
|
||||||
|
|
||||||
|
if c.maxSize != 0 && len(c.cache) > c.maxSize {
|
||||||
|
li := c.ll.Back()
|
||||||
|
c.ll.Remove(li)
|
||||||
|
delete(c.cache, li.Value.(lruItem[K, V]).key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache[K, V]) evict(key K) {
|
||||||
|
elem := c.cache[key]
|
||||||
|
if elem != nil {
|
||||||
|
delete(c.cache, key)
|
||||||
|
c.ll.Remove(elem)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache[K, V]) get(key K) (ok bool, val V, err error) {
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
|
||||||
|
li := c.cache[key]
|
||||||
|
if li == nil {
|
||||||
|
return false, val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
item := li.Value.(lruItem[K, V])
|
||||||
|
// Maybe evict.
|
||||||
|
if c.ttl != 0 && time.Since(item.createdAt) > c.ttl {
|
||||||
|
c.evict(key)
|
||||||
|
return false, val, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
c.stats.Hits++
|
||||||
|
|
||||||
|
c.ll.MoveToFront(li)
|
||||||
|
return true, item.value, item.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache[K, V]) load(key K) (V, error) {
|
||||||
|
c.updateLock.Lock(key)
|
||||||
|
defer c.updateLock.Unlock(key)
|
||||||
|
|
||||||
|
// Check again in case we lost the update race.
|
||||||
|
ok, val, err := c.get(key)
|
||||||
|
if ok {
|
||||||
|
return val, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Won the update race.
|
||||||
|
val, err = c.src(key)
|
||||||
|
c.put(key, val, err)
|
||||||
|
return val, err
|
||||||
|
}
|
249
kvmemcache/cache_test.go
Normal file
249
kvmemcache/cache_test.go
Normal file
@ -0,0 +1,249 @@
|
|||||||
|
package kvmemcache
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type State[K comparable] struct {
|
||||||
|
Keys []K
|
||||||
|
Stats Stats
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Cache[K,V]) assert(state State[K]) error {
|
||||||
|
c.lock.Lock()
|
||||||
|
defer c.lock.Unlock()
|
||||||
|
|
||||||
|
if len(c.cache) != len(state.Keys) {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"Expected %d keys but found %d.",
|
||||||
|
len(state.Keys),
|
||||||
|
len(c.cache))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, k := range state.Keys {
|
||||||
|
if _, ok := c.cache[k]; !ok {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"Expected key %v not found.",
|
||||||
|
k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.stats.Hits != state.Stats.Hits {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"Expected %d hits, but found %d.",
|
||||||
|
state.Stats.Hits,
|
||||||
|
c.stats.Hits)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.stats.Misses != state.Stats.Misses {
|
||||||
|
return fmt.Errorf(
|
||||||
|
"Expected %d misses, but found %d.",
|
||||||
|
state.Stats.Misses,
|
||||||
|
c.stats.Misses)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var ErrTest = errors.New("Hello")
|
||||||
|
|
||||||
|
func TestCache_basic(t *testing.T) {
|
||||||
|
c := New(Config[string, string]{
|
||||||
|
MaxSize: 4,
|
||||||
|
TTL: 50 * time.Millisecond,
|
||||||
|
Src: func(key string) (string, error) {
|
||||||
|
if key == "err" {
|
||||||
|
return "", ErrTest
|
||||||
|
}
|
||||||
|
return key, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
type testCase struct {
|
||||||
|
name string
|
||||||
|
sleep time.Duration
|
||||||
|
key string
|
||||||
|
evict bool
|
||||||
|
state State[string]
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []testCase{
|
||||||
|
{
|
||||||
|
name: "get a",
|
||||||
|
key: "a",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"a"},
|
||||||
|
Stats: Stats{Hits: 0, Misses: 1},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "get a again",
|
||||||
|
key: "a",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"a"},
|
||||||
|
Stats: Stats{Hits: 1, Misses: 1},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "sleep, then get a again",
|
||||||
|
sleep: 55 * time.Millisecond,
|
||||||
|
key: "a",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"a"},
|
||||||
|
Stats: Stats{Hits: 1, Misses: 2},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "get b",
|
||||||
|
key: "b",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"a", "b"},
|
||||||
|
Stats: Stats{Hits: 1, Misses: 3},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "get c",
|
||||||
|
key: "c",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"a", "b", "c"},
|
||||||
|
Stats: Stats{Hits: 1, Misses: 4},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "get d",
|
||||||
|
key: "d",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"a", "b", "c", "d"},
|
||||||
|
Stats: Stats{Hits: 1, Misses: 5},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "get e",
|
||||||
|
key: "e",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"b", "c", "d", "e"},
|
||||||
|
Stats: Stats{Hits: 1, Misses: 6},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "get c again",
|
||||||
|
key: "c",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"b", "c", "d", "e"},
|
||||||
|
Stats: Stats{Hits: 2, Misses: 6},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "get err",
|
||||||
|
key: "err",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"c", "d", "e", "err"},
|
||||||
|
Stats: Stats{Hits: 2, Misses: 7},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "get err again",
|
||||||
|
key: "err",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"c", "d", "e", "err"},
|
||||||
|
Stats: Stats{Hits: 3, Misses: 7},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "evict c",
|
||||||
|
key: "c",
|
||||||
|
evict: true,
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"d", "e", "err"},
|
||||||
|
Stats: Stats{Hits: 3, Misses: 7},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "reload-all a",
|
||||||
|
key: "a",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"a", "d", "e", "err"},
|
||||||
|
Stats: Stats{Hits: 3, Misses: 8},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "reload-all b",
|
||||||
|
key: "b",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"a", "b", "e", "err"},
|
||||||
|
Stats: Stats{Hits: 3, Misses: 9},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "reload-all c",
|
||||||
|
key: "c",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"a", "b", "c", "err"},
|
||||||
|
Stats: Stats{Hits: 3, Misses: 10},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "reload-all d",
|
||||||
|
key: "d",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"a", "b", "c", "d"},
|
||||||
|
Stats: Stats{Hits: 3, Misses: 11},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "read a again",
|
||||||
|
key: "a",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"b", "c", "d", "a"},
|
||||||
|
Stats: Stats{Hits: 4, Misses: 11},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
name: "read e, evicting b",
|
||||||
|
key: "e",
|
||||||
|
state: State[string]{
|
||||||
|
Keys: []string{"c", "d", "a", "e"},
|
||||||
|
Stats: Stats{Hits: 4, Misses: 12},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
time.Sleep(tc.sleep)
|
||||||
|
if !tc.evict {
|
||||||
|
val, err := c.Get(tc.key)
|
||||||
|
if tc.key == "err" && err != ErrTest {
|
||||||
|
t.Fatal(tc.name, val)
|
||||||
|
}
|
||||||
|
if tc.key != "err" && val != tc.key {
|
||||||
|
t.Fatal(tc.name, tc.key, val)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c.Evict(tc.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := c.assert(tc.state); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCache_thunderingHerd(t *testing.T) {
|
||||||
|
c := New(Config[string,string]{
|
||||||
|
MaxSize: 4,
|
||||||
|
Src: func(key string) (string, error) {
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
return key, nil
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
for i := 0; i < 16384; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
val, err := c.Get("a")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
if val != "a" {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
stats := c.Stats()
|
||||||
|
if stats.Hits != 16383 || stats.Misses != 1 {
|
||||||
|
t.Fatal(stats)
|
||||||
|
}
|
||||||
|
}
|
5
kvmemcache/go.mod
Normal file
5
kvmemcache/go.mod
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
module git.crumpington.com/lib/kvmemcache
|
||||||
|
|
||||||
|
go 1.23.2
|
||||||
|
|
||||||
|
require git.crumpington.com/lib/keyedmutex v1.0.1
|
2
kvmemcache/go.sum
Normal file
2
kvmemcache/go.sum
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
git.crumpington.com/lib/keyedmutex v1.0.1 h1:5ylwGXQzL9ojZIhlqkut6dpa4yt6Wz6bOWbf/tQBAMQ=
|
||||||
|
git.crumpington.com/lib/keyedmutex v1.0.1/go.mod h1:VxxJRU/XvvF61IuJZG7kUIv954Q8+Rh8bnVpEzGYrQ4=
|
6
kvmemcache/stats.go
Normal file
6
kvmemcache/stats.go
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
package kvmemcache
|
||||||
|
|
||||||
|
type Stats struct {
|
||||||
|
Hits uint64
|
||||||
|
Misses uint64
|
||||||
|
}
|
2
mmap/README.md
Normal file
2
mmap/README.md
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# mmap
|
||||||
|
|
100
mmap/file.go
Normal file
100
mmap/file.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package mmap
|
||||||
|
|
||||||
|
import "os"
|
||||||
|
|
||||||
|
type File struct {
|
||||||
|
f *os.File
|
||||||
|
Map []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func Create(path string, size int64) (*File, error) {
|
||||||
|
f, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := f.Truncate(size); err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m, err := Map(f, PROT_READ|PROT_WRITE)
|
||||||
|
if err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &File{f, m}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Opens a mapped file in read-only mode.
|
||||||
|
func Open(path string) (*File, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m, err := Map(f, PROT_READ)
|
||||||
|
if err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &File{f, m}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenFile(
|
||||||
|
path string,
|
||||||
|
fileFlags int,
|
||||||
|
perm os.FileMode,
|
||||||
|
size int64, // -1 for file size.
|
||||||
|
) (*File, error) {
|
||||||
|
f, err := os.OpenFile(path, fileFlags, perm)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
writable := fileFlags|os.O_RDWR != 0 || fileFlags|os.O_WRONLY != 0
|
||||||
|
|
||||||
|
fi, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if writable && size > 0 && size != fi.Size() {
|
||||||
|
if err := f.Truncate(size); err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mapFlags := PROT_READ
|
||||||
|
if writable {
|
||||||
|
mapFlags |= PROT_WRITE
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Map(f, mapFlags)
|
||||||
|
if err != nil {
|
||||||
|
f.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &File{f, m}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *File) Sync() error {
|
||||||
|
return Sync(f.Map)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *File) Close() error {
|
||||||
|
if f.Map != nil {
|
||||||
|
if err := Unmap(f.Map); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.Map = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.f != nil {
|
||||||
|
if err := f.f.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.f = nil
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
3
mmap/go.mod
Normal file
3
mmap/go.mod
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
module git.crumpington.com/lib/mmap
|
||||||
|
|
||||||
|
go 1.23.2
|
0
mmap/go.sum
Normal file
0
mmap/go.sum
Normal file
59
mmap/mmap.go
Normal file
59
mmap/mmap.go
Normal file
@ -0,0 +1,59 @@
|
|||||||
|
package mmap
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PROT_READ = syscall.PROT_READ
|
||||||
|
PROT_WRITE = syscall.PROT_WRITE
|
||||||
|
)
|
||||||
|
|
||||||
|
// Mmap creates a memory map of the given file. The flags argument should be a
|
||||||
|
// combination of PROT_READ and PROT_WRITE. The size of the map will be the
|
||||||
|
// file's size.
|
||||||
|
func Map(f *os.File, flags int) ([]byte, error) {
|
||||||
|
fi, err := f.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
size := fi.Size()
|
||||||
|
|
||||||
|
addr, _, errno := syscall.Syscall6(
|
||||||
|
syscall.SYS_MMAP,
|
||||||
|
0, // addr: 0 => allow kernel to choose
|
||||||
|
uintptr(size),
|
||||||
|
uintptr(flags),
|
||||||
|
uintptr(syscall.MAP_SHARED),
|
||||||
|
f.Fd(),
|
||||||
|
0) // offset: 0 => start of file
|
||||||
|
if errno != 0 {
|
||||||
|
return nil, syscall.Errno(errno)
|
||||||
|
}
|
||||||
|
|
||||||
|
return unsafe.Slice((*byte)(unsafe.Pointer(addr)), size), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Munmap unmaps the data obtained by Map.
|
||||||
|
func Unmap(data []byte) error {
|
||||||
|
_, _, errno := syscall.Syscall(
|
||||||
|
syscall.SYS_MUNMAP,
|
||||||
|
uintptr(unsafe.Pointer(&data[:1][0])),
|
||||||
|
uintptr(cap(data)),
|
||||||
|
0)
|
||||||
|
if errno != 0 {
|
||||||
|
return syscall.Errno(errno)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func Sync(b []byte) (err error) {
|
||||||
|
_p0 := unsafe.Pointer(&b[0])
|
||||||
|
_, _, errno := syscall.Syscall(syscall.SYS_MSYNC, uintptr(_p0), uintptr(len(b)), uintptr(syscall.MS_SYNC))
|
||||||
|
if errno != 0 {
|
||||||
|
err = syscall.Errno(errno)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
45
pgutil/README.md
Normal file
45
pgutil/README.md
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# pgutil
|
||||||
|
|
||||||
|
## Transactions
|
||||||
|
|
||||||
|
Simplify postgres transactions using `WithTx` for serializable transactions,
|
||||||
|
or `WithTxDefault` for the default isolation level. Use the `SerialTxRunner`
|
||||||
|
type to get automatic retries of serialization errors.
|
||||||
|
|
||||||
|
## Migrations
|
||||||
|
|
||||||
|
Put your migrations into a directory, for example `migrations`, ordered by name
|
||||||
|
(YYYY-MM-DD prefix, for example). Embed the directory and pass it to the
|
||||||
|
`Migrate` function:
|
||||||
|
|
||||||
|
```Go
|
||||||
|
//go:embed migrations
|
||||||
|
var migrations embed.FS
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
Migrate(db, migrations) // Check the error, of course.
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
In order to test this packge, we need to create a test user and database:
|
||||||
|
|
||||||
|
```
|
||||||
|
sudo su postgres
|
||||||
|
psql
|
||||||
|
|
||||||
|
CREATE DATABASE test;
|
||||||
|
CREATE USER test WITH ENCRYPTED PASSWORD 'test';
|
||||||
|
GRANT ALL PRIVILEGES ON DATABASE test TO test;
|
||||||
|
|
||||||
|
use test
|
||||||
|
|
||||||
|
GRANT ALL ON SCHEMA public TO test;
|
||||||
|
```
|
||||||
|
|
||||||
|
Check that you can connect via the command line:
|
||||||
|
|
||||||
|
```
|
||||||
|
psql -h 127.0.0.1 -U test --password test
|
||||||
|
```
|
42
pgutil/dropall.go
Normal file
42
pgutil/dropall.go
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
package pgutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
const dropTablesQueryQuery = `
|
||||||
|
SELECT 'DROP TABLE IF EXISTS "' || tablename || '" CASCADE;'
|
||||||
|
FROM
|
||||||
|
pg_tables
|
||||||
|
WHERE
|
||||||
|
schemaname='public'`
|
||||||
|
|
||||||
|
// Deletes all tables in the database. Useful for testing.
|
||||||
|
func DropAllTables(db *sql.DB) error {
|
||||||
|
rows, err := db.Query(dropTablesQueryQuery)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
queries := []string{}
|
||||||
|
for rows.Next() {
|
||||||
|
var s string
|
||||||
|
if err := rows.Scan(&s); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
queries = append(queries, s)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(queries) > 0 {
|
||||||
|
log.Printf("DROPPING ALL (%d) TABLES", len(queries))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, query := range queries {
|
||||||
|
if _, err := db.Exec(query); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
31
pgutil/errors.go
Normal file
31
pgutil/errors.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
package pgutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/lib/pq"
|
||||||
|
)
|
||||||
|
|
||||||
|
func ErrIsDuplicateKey(err error) bool {
|
||||||
|
return ErrHasCode(err, "23505")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrIsForeignKey(err error) bool {
|
||||||
|
return ErrHasCode(err, "23503")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrIsSerializationFaiilure(err error) bool {
|
||||||
|
return ErrHasCode(err, "40001")
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrHasCode(err error, code string) bool {
|
||||||
|
if err == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
var pErr *pq.Error
|
||||||
|
if errors.As(err, &pErr) {
|
||||||
|
return pErr.Code == pq.ErrorCode(code)
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
36
pgutil/errors_test.go
Normal file
36
pgutil/errors_test.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package pgutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestErrors(t *testing.T) {
|
||||||
|
db, err := sql.Open(
|
||||||
|
"postgres",
|
||||||
|
"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DropAllTables(db); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := Migrate(db, testMigrationFS); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (2, 'q@r.com')`)
|
||||||
|
if !ErrIsDuplicateKey(err) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (3, 'c@d.com')`)
|
||||||
|
if !ErrIsDuplicateKey(err) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
_, err = db.Exec(`INSERT INTO user_notes(UserID, NoteID, Note) VALUES (4, 1, 'hello')`)
|
||||||
|
if !ErrIsForeignKey(err) {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
5
pgutil/go.mod
Normal file
5
pgutil/go.mod
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
module git.crumpington.com/git/pgutil
|
||||||
|
|
||||||
|
go 1.23.2
|
||||||
|
|
||||||
|
require github.com/lib/pq v1.10.9
|
2
pgutil/go.sum
Normal file
2
pgutil/go.sum
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
82
pgutil/migrate.go
Normal file
82
pgutil/migrate.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
package pgutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"embed"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
const initMigrationTableQuery = `
|
||||||
|
CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);`
|
||||||
|
|
||||||
|
const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)`
|
||||||
|
|
||||||
|
const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)`
|
||||||
|
|
||||||
|
func Migrate(db *sql.DB, migrationFS embed.FS) error {
|
||||||
|
return WithTx(db, func(tx *sql.Tx) error {
|
||||||
|
if _, err := tx.Exec(initMigrationTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dirs, err := migrationFS.ReadDir(".")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(dirs) != 1 {
|
||||||
|
return errors.New("expected a single migrations directory")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !dirs[0].IsDir() {
|
||||||
|
return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
dirName := dirs[0].Name()
|
||||||
|
files, err := migrationFS.ReadDir(dirName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort sql files by name.
|
||||||
|
sort.Slice(files, func(i, j int) bool {
|
||||||
|
return files[i].Name() < files[j].Name()
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, dirEnt := range files {
|
||||||
|
if !dirEnt.Type().IsRegular() {
|
||||||
|
return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
name = dirEnt.Name()
|
||||||
|
exists bool
|
||||||
|
)
|
||||||
|
|
||||||
|
err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
migration, err := migrationFS.ReadFile(filepath.Join(dirName, name))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Exec(string(migration)); err != nil {
|
||||||
|
return fmt.Errorf("migration %s failed: %v", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Exec(insertMigrationQuery, name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
50
pgutil/migrate_test.go
Normal file
50
pgutil/migrate_test.go
Normal file
@ -0,0 +1,50 @@
|
|||||||
|
package pgutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"embed"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed test-migrations
|
||||||
|
var testMigrationFS embed.FS
|
||||||
|
|
||||||
|
func TestMigrate(t *testing.T) {
|
||||||
|
db, err := sql.Open(
|
||||||
|
"postgres",
|
||||||
|
"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DropAllTables(db); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := Migrate(db, testMigrationFS); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shouldn't have any effect.
|
||||||
|
if err := Migrate(db, testMigrationFS); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)`
|
||||||
|
var exists bool
|
||||||
|
|
||||||
|
if err = db.QueryRow(query, 1).Scan(&exists); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
t.Fatal("1 shouldn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = db.QueryRow(query, 2).Scan(&exists); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("2 should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
9
pgutil/test-migrations/000.sql
Normal file
9
pgutil/test-migrations/000.sql
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
CREATE TABLE users(
|
||||||
|
UserID BIGINT NOT NULL PRIMARY KEY,
|
||||||
|
Email TEXT NOT NULL UNIQUE);
|
||||||
|
|
||||||
|
CREATE TABLE user_notes(
|
||||||
|
UserID BIGINT NOT NULL REFERENCES users(UserID),
|
||||||
|
NoteID BIGINT NOT NULL,
|
||||||
|
Note Text NOT NULL,
|
||||||
|
PRIMARY KEY(UserID,NoteID));
|
1
pgutil/test-migrations/001.sql
Normal file
1
pgutil/test-migrations/001.sql
Normal file
@ -0,0 +1 @@
|
|||||||
|
INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com');
|
1
pgutil/test-migrations/002.sql
Normal file
1
pgutil/test-migrations/002.sql
Normal file
@ -0,0 +1 @@
|
|||||||
|
DELETE FROM users WHERE UserID=1;
|
70
pgutil/tx.go
Normal file
70
pgutil/tx.go
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
package pgutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"math/rand"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Postgres doesn't use serializable transactions by default. This wrapper will
|
||||||
|
// run the enclosed function within a serializable. Note: this may return an
|
||||||
|
// retriable serialization error (see ErrIsSerializationFaiilure).
|
||||||
|
func WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
|
||||||
|
return WithTxDefault(db, func(tx *sql.Tx) error {
|
||||||
|
if _, err := tx.Exec("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE"); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return fn(tx)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// This is a convenience function to provide a transaction wrapper with the
|
||||||
|
// default isolation level.
|
||||||
|
func WithTxDefault(db *sql.DB, fn func(*sql.Tx) error) error {
|
||||||
|
// Start a transaction.
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = fn(tx)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
err = tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SerialTxRunner attempts serializable transactions in a loop. If a
|
||||||
|
// transaction fails due to a serialization error, then the runner will retry
|
||||||
|
// with exponential backoff, until the sleep time reaches MaxTimeout.
|
||||||
|
//
|
||||||
|
// For example, if MinTimeout is 100 ms, and MaxTimeout is 800 ms, it may sleep
|
||||||
|
// for ~100, 200, 400, and 800 ms between retries.
|
||||||
|
//
|
||||||
|
// 10% jitter is added to the sleep time.
|
||||||
|
type SerialTxRunner struct {
|
||||||
|
MinTimeout time.Duration
|
||||||
|
MaxTimeout time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r SerialTxRunner) WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
|
||||||
|
timeout := r.MinTimeout
|
||||||
|
for {
|
||||||
|
err := WithTx(db, fn)
|
||||||
|
if err == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if timeout > r.MaxTimeout || !ErrIsSerializationFaiilure(err) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
sleepTimeout := timeout + time.Duration(rand.Int63n(int64(timeout/10)))
|
||||||
|
time.Sleep(sleepTimeout)
|
||||||
|
timeout *= 2
|
||||||
|
}
|
||||||
|
}
|
120
pgutil/tx_test.go
Normal file
120
pgutil/tx_test.go
Normal file
@ -0,0 +1,120 @@
|
|||||||
|
package pgutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestExecuteTx verifies transaction retry using the classic
|
||||||
|
// example of write skew in bank account balance transfers.
|
||||||
|
func TestWithTx(t *testing.T) {
|
||||||
|
db, err := sql.Open(
|
||||||
|
"postgres",
|
||||||
|
"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := DropAllTables(db); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer db.Close()
|
||||||
|
|
||||||
|
initStmt := `
|
||||||
|
CREATE TABLE t (acct INT PRIMARY KEY, balance INT);
|
||||||
|
INSERT INTO t (acct, balance) VALUES (1, 100), (2, 100);
|
||||||
|
`
|
||||||
|
if _, err := db.Exec(initStmt); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type queryI interface {
|
||||||
|
Query(string, ...interface{}) (*sql.Rows, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
getBalances := func(q queryI) (bal1, bal2 int, err error) {
|
||||||
|
var rows *sql.Rows
|
||||||
|
rows, err = q.Query(`SELECT balance FROM t WHERE acct IN (1, 2);`)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
balances := []*int{&bal1, &bal2}
|
||||||
|
i := 0
|
||||||
|
for ; rows.Next(); i++ {
|
||||||
|
if err = rows.Scan(balances[i]); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if i != 2 {
|
||||||
|
err = fmt.Errorf("expected two balances; got %d", i)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
txRunner := SerialTxRunner{100 * time.Millisecond, 800 * time.Millisecond}
|
||||||
|
|
||||||
|
runTxn := func(wg *sync.WaitGroup, iter *int) <-chan error {
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
*iter = 0
|
||||||
|
errCh <- txRunner.WithTx(db, func(tx *sql.Tx) error {
|
||||||
|
*iter++
|
||||||
|
bal1, bal2, err := getBalances(tx)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// If this is the first iteration, wait for the other tx to
|
||||||
|
// also read.
|
||||||
|
if *iter == 1 {
|
||||||
|
wg.Done()
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
// Now, subtract from one account and give to the other.
|
||||||
|
if bal1 > bal2 {
|
||||||
|
if _, err := tx.Exec(`
|
||||||
|
UPDATE t SET balance=balance-100 WHERE acct=1;
|
||||||
|
UPDATE t SET balance=balance+100 WHERE acct=2;
|
||||||
|
`); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, err := tx.Exec(`
|
||||||
|
UPDATE t SET balance=balance+100 WHERE acct=1;
|
||||||
|
UPDATE t SET balance=balance-100 WHERE acct=2;
|
||||||
|
`); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}()
|
||||||
|
return errCh
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
var iters1, iters2 int
|
||||||
|
txn1Err := runTxn(&wg, &iters1)
|
||||||
|
txn2Err := runTxn(&wg, &iters2)
|
||||||
|
if err := <-txn1Err; err != nil {
|
||||||
|
t.Errorf("expected success in txn1; got %s", err)
|
||||||
|
}
|
||||||
|
if err := <-txn2Err; err != nil {
|
||||||
|
t.Errorf("expected success in txn2; got %s", err)
|
||||||
|
}
|
||||||
|
if iters1+iters2 <= 2 {
|
||||||
|
t.Errorf("expected retries between the competing transactions; "+
|
||||||
|
"got txn1=%d, txn2=%d", iters1, iters2)
|
||||||
|
}
|
||||||
|
bal1, bal2, err := getBalances(db)
|
||||||
|
if err != nil || bal1 != 100 || bal2 != 100 {
|
||||||
|
t.Errorf("expected balances to be restored without error; "+
|
||||||
|
"got acct1=%d, acct2=%d: %s", bal1, bal2, err)
|
||||||
|
}
|
||||||
|
}
|
3
ratelimiter/go.mod
Normal file
3
ratelimiter/go.mod
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
module git.crumpington.com/lib/ratelimiter
|
||||||
|
|
||||||
|
go 1.23.2
|
86
ratelimiter/ratelimiter.go
Normal file
86
ratelimiter/ratelimiter.go
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
package ratelimiter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrBackoff = errors.New("Backoff")
|
||||||
|
|
||||||
|
type Config struct {
|
||||||
|
BurstLimit int64 // Number of requests to allow to burst.
|
||||||
|
FillPeriod time.Duration // Add one per period.
|
||||||
|
MaxWaitCount int64 // Max number of waiting requests. 0 disables.
|
||||||
|
}
|
||||||
|
|
||||||
|
type Limiter struct {
|
||||||
|
lock sync.Mutex
|
||||||
|
|
||||||
|
fillPeriod time.Duration
|
||||||
|
minWaitTime time.Duration
|
||||||
|
maxWaitTime time.Duration
|
||||||
|
|
||||||
|
waitTime time.Duration // If waitTime < 0, no waiting occurs.
|
||||||
|
lastRequest time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(conf Config) *Limiter {
|
||||||
|
if conf.BurstLimit < 0 {
|
||||||
|
panic(conf.BurstLimit)
|
||||||
|
}
|
||||||
|
if conf.FillPeriod <= 0 {
|
||||||
|
panic(conf.FillPeriod)
|
||||||
|
}
|
||||||
|
if conf.MaxWaitCount < 0 {
|
||||||
|
panic(conf.MaxWaitCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
lim := &Limiter{
|
||||||
|
lastRequest: time.Now(),
|
||||||
|
fillPeriod: conf.FillPeriod,
|
||||||
|
waitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit),
|
||||||
|
minWaitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit),
|
||||||
|
maxWaitTime: conf.FillPeriod * time.Duration(conf.MaxWaitCount),
|
||||||
|
}
|
||||||
|
|
||||||
|
lim.waitTime = lim.minWaitTime
|
||||||
|
|
||||||
|
return lim
|
||||||
|
}
|
||||||
|
|
||||||
|
func (lim *Limiter) limit(count int64) (time.Duration, error) {
|
||||||
|
lim.lock.Lock()
|
||||||
|
defer lim.lock.Unlock()
|
||||||
|
|
||||||
|
dt := time.Since(lim.lastRequest)
|
||||||
|
|
||||||
|
waitTime := lim.waitTime - dt + time.Duration(count)*lim.fillPeriod
|
||||||
|
|
||||||
|
if waitTime < lim.minWaitTime {
|
||||||
|
waitTime = lim.minWaitTime
|
||||||
|
} else if waitTime > lim.maxWaitTime {
|
||||||
|
return 0, ErrBackoff
|
||||||
|
}
|
||||||
|
|
||||||
|
lim.waitTime = waitTime
|
||||||
|
lim.lastRequest = lim.lastRequest.Add(dt)
|
||||||
|
|
||||||
|
return lim.waitTime, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply the limiter to the calling thread. The function may sleep for up to
|
||||||
|
// maxWaitTime before returning. If the timeout would need to be more than
|
||||||
|
// maxWaitTime to enforce the rate limit, ErrBackoff is returned.
|
||||||
|
func (lim *Limiter) Limit() error {
|
||||||
|
dt, err := lim.limit(1)
|
||||||
|
time.Sleep(dt) // Will return immediately for dt <= 0.
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply the limiter for multiple items at once.
|
||||||
|
func (lim *Limiter) LimitMultiple(count int64) error {
|
||||||
|
dt, err := lim.limit(count)
|
||||||
|
time.Sleep(dt)
|
||||||
|
return err
|
||||||
|
}
|
101
ratelimiter/ratelimiter_test.go
Normal file
101
ratelimiter/ratelimiter_test.go
Normal file
@ -0,0 +1,101 @@
|
|||||||
|
package ratelimiter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRateLimiter_Limit_Errors(t *testing.T) {
|
||||||
|
type TestCase struct {
|
||||||
|
Name string
|
||||||
|
Conf Config
|
||||||
|
N int
|
||||||
|
ErrCount int
|
||||||
|
DT time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []TestCase{
|
||||||
|
{
|
||||||
|
Name: "no burst, no wait",
|
||||||
|
Conf: Config{
|
||||||
|
BurstLimit: 0,
|
||||||
|
FillPeriod: 100 * time.Millisecond,
|
||||||
|
MaxWaitCount: 0,
|
||||||
|
},
|
||||||
|
N: 32,
|
||||||
|
ErrCount: 32,
|
||||||
|
DT: 0,
|
||||||
|
}, {
|
||||||
|
Name: "no wait",
|
||||||
|
Conf: Config{
|
||||||
|
BurstLimit: 10,
|
||||||
|
FillPeriod: 100 * time.Millisecond,
|
||||||
|
MaxWaitCount: 0,
|
||||||
|
},
|
||||||
|
N: 32,
|
||||||
|
ErrCount: 22,
|
||||||
|
DT: 0,
|
||||||
|
}, {
|
||||||
|
Name: "no burst",
|
||||||
|
Conf: Config{
|
||||||
|
BurstLimit: 0,
|
||||||
|
FillPeriod: 10 * time.Millisecond,
|
||||||
|
MaxWaitCount: 10,
|
||||||
|
},
|
||||||
|
N: 32,
|
||||||
|
ErrCount: 22,
|
||||||
|
DT: 100 * time.Millisecond,
|
||||||
|
}, {
|
||||||
|
Name: "burst and wait",
|
||||||
|
Conf: Config{
|
||||||
|
BurstLimit: 10,
|
||||||
|
FillPeriod: 10 * time.Millisecond,
|
||||||
|
MaxWaitCount: 10,
|
||||||
|
},
|
||||||
|
N: 32,
|
||||||
|
ErrCount: 12,
|
||||||
|
DT: 100 * time.Millisecond,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
wg := sync.WaitGroup{}
|
||||||
|
l := New(tc.Conf)
|
||||||
|
errs := make([]error, tc.N)
|
||||||
|
|
||||||
|
t0 := time.Now()
|
||||||
|
|
||||||
|
for i := 0; i < tc.N; i++ {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i int) {
|
||||||
|
errs[i] = l.Limit()
|
||||||
|
wg.Done()
|
||||||
|
}(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
dt := time.Since(t0)
|
||||||
|
|
||||||
|
errCount := 0
|
||||||
|
for _, err := range errs {
|
||||||
|
if err != nil {
|
||||||
|
errCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if errCount != tc.ErrCount {
|
||||||
|
t.Fatalf("%s: Expected %d errors but got %d.",
|
||||||
|
tc.Name, tc.ErrCount, errCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dt < tc.DT {
|
||||||
|
t.Fatal(tc.Name, dt, tc.DT)
|
||||||
|
}
|
||||||
|
|
||||||
|
if dt > tc.DT+10*time.Millisecond {
|
||||||
|
t.Fatal(tc.Name, dt, tc.DT)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
22
sqlgen/README.md
Normal file
22
sqlgen/README.md
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# sqlgen
|
||||||
|
|
||||||
|
## Installing
|
||||||
|
|
||||||
|
```
|
||||||
|
go install git.crumpington.com/lib/sqlgen/cmd/sqlgen@latest
|
||||||
|
```
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
```
|
||||||
|
sqlgen [driver] [defs-path] [output-path]
|
||||||
|
```
|
||||||
|
|
||||||
|
## File Format
|
||||||
|
|
||||||
|
```
|
||||||
|
TABLE [sql-name] OF [go-type] <NoInsert> <NoUpdate> <NoDelete> (
|
||||||
|
[sql-column] [go-type] <AS go-name> <PK> <NoInsert> <NoUpdate>,
|
||||||
|
...
|
||||||
|
);
|
||||||
|
```
|
7
sqlgen/cmd/sqlgen/main.go
Normal file
7
sqlgen/cmd/sqlgen/main.go
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import "git.crumpington.com/lib/sqlgen"
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
sqlgen.Main()
|
||||||
|
}
|
3
sqlgen/go.mod
Normal file
3
sqlgen/go.mod
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
module git.crumpington.com/lib/sqlgen
|
||||||
|
|
||||||
|
go 1.23.2
|
43
sqlgen/main.go
Normal file
43
sqlgen/main.go
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
package sqlgen
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Main() {
|
||||||
|
usage := func() {
|
||||||
|
fmt.Fprintf(os.Stderr, `
|
||||||
|
%s DRIVER DEFS_PATH OUTPUT_PATH
|
||||||
|
|
||||||
|
Drivers are one of: sqlite, postgres
|
||||||
|
`,
|
||||||
|
os.Args[0])
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(os.Args) != 4 {
|
||||||
|
usage()
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
template string
|
||||||
|
driver = os.Args[1]
|
||||||
|
defsPath = os.Args[2]
|
||||||
|
outputPath = os.Args[3]
|
||||||
|
)
|
||||||
|
|
||||||
|
switch driver {
|
||||||
|
case "sqlite":
|
||||||
|
template = sqliteTemplate
|
||||||
|
default:
|
||||||
|
fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver)
|
||||||
|
usage()
|
||||||
|
}
|
||||||
|
|
||||||
|
err := render(template, defsPath, outputPath)
|
||||||
|
if err != nil {
|
||||||
|
fmt.Fprintf(os.Stderr, "Error: %v", err)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
143
sqlgen/parse.go
Normal file
143
sqlgen/parse.go
Normal file
@ -0,0 +1,143 @@
|
|||||||
|
package sqlgen
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func parsePath(filePath string) (*schema, error) {
|
||||||
|
fileBytes, err := os.ReadFile(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return parseBytes(fileBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseBytes(fileBytes []byte) (*schema, error) {
|
||||||
|
s := string(fileBytes)
|
||||||
|
for _, c := range []string{",", "(", ")", ";"} {
|
||||||
|
s = strings.ReplaceAll(s, c, " "+c+" ")
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
tokens = strings.Fields(s)
|
||||||
|
schema = &schema{}
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
for len(tokens) > 0 {
|
||||||
|
switch tokens[0] {
|
||||||
|
case "TABLE":
|
||||||
|
tokens, err = parseTable(schema, tokens)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, errors.New("invalid token: " + tokens[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return schema, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTable(schema *schema, tokens []string) ([]string, error) {
|
||||||
|
tokens = tokens[1:]
|
||||||
|
if len(tokens) < 3 {
|
||||||
|
return tokens, errors.New("incomplete table definition")
|
||||||
|
}
|
||||||
|
if tokens[1] != "OF" {
|
||||||
|
return tokens, errors.New("expected OF in table definition")
|
||||||
|
}
|
||||||
|
|
||||||
|
table := &table{
|
||||||
|
Name: tokens[0],
|
||||||
|
Type: tokens[2],
|
||||||
|
}
|
||||||
|
schema.Tables = append(schema.Tables, table)
|
||||||
|
|
||||||
|
tokens = tokens[3:]
|
||||||
|
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
return tokens, errors.New("missing table definition body")
|
||||||
|
}
|
||||||
|
|
||||||
|
for len(tokens) > 0 {
|
||||||
|
switch tokens[0] {
|
||||||
|
case "NoInsert":
|
||||||
|
table.NoInsert = true
|
||||||
|
tokens = tokens[1:]
|
||||||
|
case "NoUpdate":
|
||||||
|
table.NoUpdate = true
|
||||||
|
tokens = tokens[1:]
|
||||||
|
case "NoDelete":
|
||||||
|
table.NoDelete = true
|
||||||
|
tokens = tokens[1:]
|
||||||
|
case "(":
|
||||||
|
return parseTableBody(table, tokens[1:])
|
||||||
|
default:
|
||||||
|
return tokens, errors.New("unexpected token in table definition: " + tokens[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens, errors.New("incomplete table definition")
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTableBody(table *table, tokens []string) ([]string, error) {
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for len(tokens) > 0 && tokens[0] != ";" {
|
||||||
|
tokens, err = parseTableColumn(table, tokens)
|
||||||
|
if err != nil {
|
||||||
|
return tokens, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokens) < 1 || tokens[0] != ";" {
|
||||||
|
return tokens, errors.New("incomplete table column definitions")
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens[1:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseTableColumn(table *table, tokens []string) ([]string, error) {
|
||||||
|
if len(tokens) < 2 {
|
||||||
|
return tokens, errors.New("incomplete column definition")
|
||||||
|
}
|
||||||
|
column := &column{
|
||||||
|
Name: tokens[0],
|
||||||
|
Type: tokens[1],
|
||||||
|
SqlName: tokens[0],
|
||||||
|
}
|
||||||
|
table.Columns = append(table.Columns, column)
|
||||||
|
|
||||||
|
tokens = tokens[2:]
|
||||||
|
for len(tokens) > 0 && tokens[0] != "," && tokens[0] != ")" {
|
||||||
|
switch tokens[0] {
|
||||||
|
case "AS":
|
||||||
|
if len(tokens) < 2 {
|
||||||
|
return tokens, errors.New("incomplete AS clause in column definition")
|
||||||
|
}
|
||||||
|
column.Name = tokens[1]
|
||||||
|
tokens = tokens[2:]
|
||||||
|
case "PK":
|
||||||
|
column.PK = true
|
||||||
|
tokens = tokens[1:]
|
||||||
|
case "NoInsert":
|
||||||
|
column.NoInsert = true
|
||||||
|
tokens = tokens[1:]
|
||||||
|
case "NoUpdate":
|
||||||
|
column.NoUpdate = true
|
||||||
|
tokens = tokens[1:]
|
||||||
|
default:
|
||||||
|
return tokens, errors.New("unexpected token in column definition: " + tokens[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
return tokens, errors.New("incomplete column definition")
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokens[1:], nil
|
||||||
|
}
|
45
sqlgen/parse_test.go
Normal file
45
sqlgen/parse_test.go
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package sqlgen
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParse(t *testing.T) {
|
||||||
|
toString := func(v any) string {
|
||||||
|
txt, _ := json.MarshalIndent(v, "", " ")
|
||||||
|
return string(txt)
|
||||||
|
}
|
||||||
|
|
||||||
|
paths, err := filepath.Glob("test-files/TestParse/*.def")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, defPath := range paths {
|
||||||
|
t.Run(filepath.Base(defPath), func(t *testing.T) {
|
||||||
|
parsed, err := parsePath(defPath)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := os.ReadFile(strings.TrimSuffix(defPath, "def") + "json")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expected := &schema{}
|
||||||
|
if err := json.Unmarshal(b, expected); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(parsed, expected) {
|
||||||
|
t.Fatalf("%s != %s", toString(parsed), toString(expected))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
263
sqlgen/schema.go
Normal file
263
sqlgen/schema.go
Normal file
@ -0,0 +1,263 @@
|
|||||||
|
package sqlgen
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type schema struct {
|
||||||
|
Tables []*table
|
||||||
|
}
|
||||||
|
|
||||||
|
type table struct {
|
||||||
|
Name string // Name in SQL
|
||||||
|
Type string // Go type
|
||||||
|
NoInsert bool
|
||||||
|
NoUpdate bool
|
||||||
|
NoDelete bool
|
||||||
|
Columns []*column
|
||||||
|
}
|
||||||
|
|
||||||
|
type column struct {
|
||||||
|
Name string
|
||||||
|
Type string
|
||||||
|
SqlName string // Defaults to Name
|
||||||
|
PK bool // PK won't be updated
|
||||||
|
NoInsert bool
|
||||||
|
NoUpdate bool // Don't update column in update function
|
||||||
|
}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
func (t *table) colSQLNames() []string {
|
||||||
|
names := make([]string, len(t.Columns))
|
||||||
|
for i := range names {
|
||||||
|
names[i] = t.Columns[i].SqlName
|
||||||
|
}
|
||||||
|
return names
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) SelectQuery() string {
|
||||||
|
return fmt.Sprintf(`SELECT %s FROM %s`,
|
||||||
|
strings.Join(t.colSQLNames(), ","),
|
||||||
|
t.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) insertCols() (cols []*column) {
|
||||||
|
for _, c := range t.Columns {
|
||||||
|
if !c.NoInsert {
|
||||||
|
cols = append(cols, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cols
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) InsertQuery() string {
|
||||||
|
cols := t.insertCols()
|
||||||
|
b := &strings.Builder{}
|
||||||
|
b.WriteString(`INSERT INTO `)
|
||||||
|
b.WriteString(t.Name)
|
||||||
|
b.WriteString(`(`)
|
||||||
|
for i, c := range cols {
|
||||||
|
if i != 0 {
|
||||||
|
b.WriteString(`,`)
|
||||||
|
}
|
||||||
|
b.WriteString(c.SqlName)
|
||||||
|
}
|
||||||
|
b.WriteString(`) VALUES(`)
|
||||||
|
for i := range cols {
|
||||||
|
if i != 0 {
|
||||||
|
b.WriteString(`,`)
|
||||||
|
}
|
||||||
|
b.WriteString(`?`)
|
||||||
|
}
|
||||||
|
b.WriteString(`)`)
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) InsertArgs() string {
|
||||||
|
args := []string{}
|
||||||
|
for i, col := range t.Columns {
|
||||||
|
if !col.NoInsert {
|
||||||
|
args = append(args, "row."+t.Columns[i].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strings.Join(args, ", ")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) UpdateCols() (cols []*column) {
|
||||||
|
for _, c := range t.Columns {
|
||||||
|
if !(c.PK || c.NoUpdate) {
|
||||||
|
cols = append(cols, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cols
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) UpdateQuery() string {
|
||||||
|
cols := t.UpdateCols()
|
||||||
|
|
||||||
|
b := &strings.Builder{}
|
||||||
|
b.WriteString(`UPDATE `)
|
||||||
|
b.WriteString(t.Name + ` SET `)
|
||||||
|
for i, col := range cols {
|
||||||
|
if i != 0 {
|
||||||
|
b.WriteByte(',')
|
||||||
|
}
|
||||||
|
b.WriteString(col.SqlName + `=?`)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(` WHERE`)
|
||||||
|
|
||||||
|
for i, c := range t.pkCols() {
|
||||||
|
if i != 0 {
|
||||||
|
b.WriteString(` AND`)
|
||||||
|
}
|
||||||
|
b.WriteString(` ` + c.SqlName + `=?`)
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) UpdateArgs() string {
|
||||||
|
cols := t.UpdateCols()
|
||||||
|
|
||||||
|
b := &strings.Builder{}
|
||||||
|
for i, col := range cols {
|
||||||
|
if i != 0 {
|
||||||
|
b.WriteString(`, `)
|
||||||
|
}
|
||||||
|
b.WriteString("row." + col.Name)
|
||||||
|
}
|
||||||
|
for _, col := range t.pkCols() {
|
||||||
|
b.WriteString(", row." + col.Name)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) UpdateFullCols() (cols []*column) {
|
||||||
|
for _, c := range t.Columns {
|
||||||
|
if !c.PK {
|
||||||
|
cols = append(cols, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cols
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) UpdateFullQuery() string {
|
||||||
|
cols := t.UpdateFullCols()
|
||||||
|
|
||||||
|
b := &strings.Builder{}
|
||||||
|
b.WriteString(`UPDATE `)
|
||||||
|
b.WriteString(t.Name + ` SET `)
|
||||||
|
for i, col := range cols {
|
||||||
|
if i != 0 {
|
||||||
|
b.WriteByte(',')
|
||||||
|
}
|
||||||
|
b.WriteString(col.SqlName + `=?`)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.WriteString(` WHERE`)
|
||||||
|
|
||||||
|
for i, c := range t.pkCols() {
|
||||||
|
if i != 0 {
|
||||||
|
b.WriteString(` AND`)
|
||||||
|
}
|
||||||
|
b.WriteString(` ` + c.SqlName + `=?`)
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) UpdateFullArgs() string {
|
||||||
|
cols := t.UpdateFullCols()
|
||||||
|
|
||||||
|
b := &strings.Builder{}
|
||||||
|
for i, col := range cols {
|
||||||
|
if i != 0 {
|
||||||
|
b.WriteString(`, `)
|
||||||
|
}
|
||||||
|
b.WriteString("row." + col.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, col := range t.pkCols() {
|
||||||
|
b.WriteString(", row." + col.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) pkCols() (cols []*column) {
|
||||||
|
for _, c := range t.Columns {
|
||||||
|
if c.PK {
|
||||||
|
cols = append(cols, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cols
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) PKFunctionArgs() string {
|
||||||
|
b := &strings.Builder{}
|
||||||
|
for _, col := range t.pkCols() {
|
||||||
|
b.WriteString(col.Name)
|
||||||
|
b.WriteString(` `)
|
||||||
|
b.WriteString(col.Type)
|
||||||
|
b.WriteString(",\n")
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) DeleteQuery() string {
|
||||||
|
cols := t.pkCols()
|
||||||
|
|
||||||
|
b := &strings.Builder{}
|
||||||
|
b.WriteString(`DELETE FROM `)
|
||||||
|
b.WriteString(t.Name)
|
||||||
|
b.WriteString(` WHERE `)
|
||||||
|
|
||||||
|
for i, col := range cols {
|
||||||
|
if i != 0 {
|
||||||
|
b.WriteString(` AND `)
|
||||||
|
}
|
||||||
|
b.WriteString(col.SqlName)
|
||||||
|
b.WriteString(`=?`)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) DeleteArgs() string {
|
||||||
|
cols := t.pkCols()
|
||||||
|
b := &strings.Builder{}
|
||||||
|
|
||||||
|
for i, col := range cols {
|
||||||
|
if i != 0 {
|
||||||
|
b.WriteString(`,`)
|
||||||
|
}
|
||||||
|
b.WriteString(col.Name)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) GetQuery() string {
|
||||||
|
b := &strings.Builder{}
|
||||||
|
b.WriteString(t.SelectQuery())
|
||||||
|
b.WriteString(` WHERE `)
|
||||||
|
for i, col := range t.pkCols() {
|
||||||
|
if i != 0 {
|
||||||
|
b.WriteString(` AND `)
|
||||||
|
}
|
||||||
|
b.WriteString(col.SqlName + `=?`)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *table) ScanArgs() string {
|
||||||
|
b := &strings.Builder{}
|
||||||
|
for i, col := range t.Columns {
|
||||||
|
if i != 0 {
|
||||||
|
b.WriteString(`, `)
|
||||||
|
}
|
||||||
|
b.WriteString(`&row.` + col.Name)
|
||||||
|
}
|
||||||
|
return b.String()
|
||||||
|
}
|
206
sqlgen/sqlite.go.tmpl
Normal file
206
sqlgen/sqlite.go.tmpl
Normal file
@ -0,0 +1,206 @@
|
|||||||
|
package {{.PackageName}}
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"iter"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TX interface {
|
||||||
|
Exec(query string, args ...any) (sql.Result, error)
|
||||||
|
Query(query string, args ...any) (*sql.Rows, error)
|
||||||
|
QueryRow(query string, args ...any) *sql.Row
|
||||||
|
}
|
||||||
|
|
||||||
|
{{range .Schema.Tables}}
|
||||||
|
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
// Table: {{.Name}}
|
||||||
|
// ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
type {{.Type}} struct {
|
||||||
|
{{- range .Columns}}
|
||||||
|
{{.Name}} {{.Type}}{{end}}
|
||||||
|
}
|
||||||
|
|
||||||
|
const {{.Type}}_SelectQuery = "{{.SelectQuery}}"
|
||||||
|
|
||||||
|
{{if not .NoInsert -}}
|
||||||
|
|
||||||
|
func {{.Type}}_Insert(
|
||||||
|
tx TX,
|
||||||
|
row *{{.Type}},
|
||||||
|
) (err error) {
|
||||||
|
{{.Type}}_Sanitize(row)
|
||||||
|
if err = {{.Type}}_Validate(row); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}})
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
{{- end}} {{/* if not .NoInsert */}}
|
||||||
|
|
||||||
|
{{if not .NoUpdate -}}
|
||||||
|
|
||||||
|
{{if .UpdateCols -}}
|
||||||
|
|
||||||
|
func {{.Type}}_Update(
|
||||||
|
tx TX,
|
||||||
|
row *{{.Type}},
|
||||||
|
) (found bool, err error) {
|
||||||
|
{{.Type}}_Sanitize(row)
|
||||||
|
if err = {{.Type}}_Validate(row); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}})
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n > 1 {
|
||||||
|
panic("multiple rows updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
return n != 0, nil
|
||||||
|
}
|
||||||
|
{{- end}}
|
||||||
|
|
||||||
|
{{if .UpdateFullCols -}}
|
||||||
|
|
||||||
|
func {{.Type}}_UpdateFull(
|
||||||
|
tx TX,
|
||||||
|
row *{{.Type}},
|
||||||
|
) (found bool, err error) {
|
||||||
|
{{.Type}}_Sanitize(row)
|
||||||
|
if err = {{.Type}}_Validate(row); err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}})
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n > 1 {
|
||||||
|
panic("multiple rows updated")
|
||||||
|
}
|
||||||
|
|
||||||
|
return n != 0, nil
|
||||||
|
}
|
||||||
|
{{- end}}
|
||||||
|
|
||||||
|
|
||||||
|
{{- end}} {{/* if not .NoUpdate */}}
|
||||||
|
|
||||||
|
{{if not .NoDelete -}}
|
||||||
|
|
||||||
|
func {{.Type}}_Delete(
|
||||||
|
tx TX,
|
||||||
|
{{.PKFunctionArgs -}}
|
||||||
|
) (found bool, err error) {
|
||||||
|
result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}})
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := result.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if n > 1 {
|
||||||
|
panic("multiple rows deleted")
|
||||||
|
}
|
||||||
|
|
||||||
|
return n != 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
{{- end}}
|
||||||
|
|
||||||
|
func {{.Type}}_Get(
|
||||||
|
tx TX,
|
||||||
|
{{.PKFunctionArgs -}}
|
||||||
|
) (
|
||||||
|
row *{{.Type}},
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
row = &{{.Type}}{}
|
||||||
|
r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}})
|
||||||
|
err = r.Scan({{.ScanArgs}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
func {{.Type}}_GetWhere(
|
||||||
|
tx TX,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) (
|
||||||
|
row *{{.Type}},
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
row = &{{.Type}}{}
|
||||||
|
r := tx.QueryRow(query, args...)
|
||||||
|
err = r.Scan({{.ScanArgs}})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
func {{.Type}}_Iterate(
|
||||||
|
tx TX,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) (
|
||||||
|
iter.Seq2[*{{.Type}}, error],
|
||||||
|
) {
|
||||||
|
rows, err := tx.Query(query, args...)
|
||||||
|
if err != nil {
|
||||||
|
return func(yield func(*{{.Type}}, error) bool) {
|
||||||
|
yield(nil, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(yield func(*{{.Type}}, error) bool) {
|
||||||
|
defer rows.Close()
|
||||||
|
for rows.Next() {
|
||||||
|
row := &{{.Type}}{}
|
||||||
|
err := rows.Scan({{.ScanArgs}})
|
||||||
|
if !yield(row, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func {{.Type}}_List(
|
||||||
|
tx TX,
|
||||||
|
query string,
|
||||||
|
args ...any,
|
||||||
|
) (
|
||||||
|
l []*{{.Type}},
|
||||||
|
err error,
|
||||||
|
) {
|
||||||
|
for row, err := range {{.Type}}_Iterate(tx, query, args...) {
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
l = append(l, row)
|
||||||
|
}
|
||||||
|
return l, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
{{end}} {{/* range .Schema.Tables */}}
|
36
sqlgen/template.go
Normal file
36
sqlgen/template.go
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
package sqlgen
|
||||||
|
|
||||||
|
import (
|
||||||
|
_ "embed"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"path/filepath"
|
||||||
|
"text/template"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed sqlite.go.tmpl
|
||||||
|
var sqliteTemplate string
|
||||||
|
|
||||||
|
func render(templateStr, schemaPath, outputPath string) error {
|
||||||
|
sch, err := parsePath(schemaPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl := template.Must(template.New("").Parse(templateStr))
|
||||||
|
fOut, err := os.Create(outputPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer fOut.Close()
|
||||||
|
|
||||||
|
err = tmpl.Execute(fOut, struct {
|
||||||
|
PackageName string
|
||||||
|
Schema *schema
|
||||||
|
}{filepath.Base(filepath.Dir(outputPath)), sch})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return exec.Command("gofmt", "-w", outputPath).Run()
|
||||||
|
}
|
3
sqlgen/test-files/TestParse/000.def
Normal file
3
sqlgen/test-files/TestParse/000.def
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
TABLE users OF User NoDelete (
|
||||||
|
user_id string AS UserID PK
|
||||||
|
);
|
17
sqlgen/test-files/TestParse/000.json
Normal file
17
sqlgen/test-files/TestParse/000.json
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
{
|
||||||
|
"Tables": [
|
||||||
|
{
|
||||||
|
"Name": "users",
|
||||||
|
"Type": "User",
|
||||||
|
"NoDelete": true,
|
||||||
|
"Columns": [
|
||||||
|
{
|
||||||
|
"Name": "UserID",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "user_id",
|
||||||
|
"PK": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
4
sqlgen/test-files/TestParse/001.def
Normal file
4
sqlgen/test-files/TestParse/001.def
Normal file
@ -0,0 +1,4 @@
|
|||||||
|
TABLE users OF User NoDelete (
|
||||||
|
user_id string AS UserID PK,
|
||||||
|
email string AS Email NoUpdate
|
||||||
|
);
|
22
sqlgen/test-files/TestParse/001.json
Normal file
22
sqlgen/test-files/TestParse/001.json
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
{
|
||||||
|
"Tables": [
|
||||||
|
{
|
||||||
|
"Name": "users",
|
||||||
|
"Type": "User",
|
||||||
|
"NoDelete": true,
|
||||||
|
"Columns": [
|
||||||
|
{
|
||||||
|
"Name": "UserID",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "user_id",
|
||||||
|
"PK": true
|
||||||
|
}, {
|
||||||
|
"Name": "Email",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "email",
|
||||||
|
"NoUpdate": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
6
sqlgen/test-files/TestParse/002.def
Normal file
6
sqlgen/test-files/TestParse/002.def
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
TABLE users OF User NoDelete (
|
||||||
|
user_id string AS UserID PK,
|
||||||
|
email string AS Email NoUpdate,
|
||||||
|
name string AS Name NoInsert,
|
||||||
|
admin bool AS Admin NoInsert NoUpdate
|
||||||
|
);
|
33
sqlgen/test-files/TestParse/002.json
Normal file
33
sqlgen/test-files/TestParse/002.json
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
{
|
||||||
|
"Tables": [
|
||||||
|
{
|
||||||
|
"Name": "users",
|
||||||
|
"Type": "User",
|
||||||
|
"NoDelete": true,
|
||||||
|
"Columns": [
|
||||||
|
{
|
||||||
|
"Name": "UserID",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "user_id",
|
||||||
|
"PK": true
|
||||||
|
}, {
|
||||||
|
"Name": "Email",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "email",
|
||||||
|
"NoUpdate": true
|
||||||
|
}, {
|
||||||
|
"Name": "Name",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "name",
|
||||||
|
"NoInsert": true
|
||||||
|
}, {
|
||||||
|
"Name": "Admin",
|
||||||
|
"Type": "bool",
|
||||||
|
"SqlName": "admin",
|
||||||
|
"NoInsert": true,
|
||||||
|
"NoUpdate": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
12
sqlgen/test-files/TestParse/003.def
Normal file
12
sqlgen/test-files/TestParse/003.def
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
TABLE users OF User NoDelete (
|
||||||
|
user_id string AS UserID PK,
|
||||||
|
email string AS Email NoUpdate,
|
||||||
|
name string AS Name NoInsert,
|
||||||
|
admin bool AS Admin NoInsert NoUpdate
|
||||||
|
);
|
||||||
|
|
||||||
|
TABLE users_view OF UserView NoInsert NoUpdate NoDelete (
|
||||||
|
user_id string AS UserID PK,
|
||||||
|
email string AS Email,
|
||||||
|
name string AS Name
|
||||||
|
);
|
61
sqlgen/test-files/TestParse/003.json
Normal file
61
sqlgen/test-files/TestParse/003.json
Normal file
@ -0,0 +1,61 @@
|
|||||||
|
{
|
||||||
|
"Tables": [
|
||||||
|
{
|
||||||
|
"Name": "users",
|
||||||
|
"Type": "User",
|
||||||
|
"NoDelete": true,
|
||||||
|
"Columns": [
|
||||||
|
{
|
||||||
|
"Name": "UserID",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "user_id",
|
||||||
|
"PK": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "Email",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "email",
|
||||||
|
"NoUpdate": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "Name",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "name",
|
||||||
|
"NoInsert": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "Admin",
|
||||||
|
"Type": "bool",
|
||||||
|
"SqlName": "admin",
|
||||||
|
"NoInsert": true,
|
||||||
|
"NoUpdate": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "users_view",
|
||||||
|
"Type": "UserView",
|
||||||
|
"NoInsert": true,
|
||||||
|
"NoUpdate": true,
|
||||||
|
"NoDelete": true,
|
||||||
|
"Columns": [
|
||||||
|
{
|
||||||
|
"Name": "UserID",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "user_id",
|
||||||
|
"PK": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "Email",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "email"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "Name",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "name"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
13
sqlgen/test-files/TestParse/004.def
Normal file
13
sqlgen/test-files/TestParse/004.def
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
TABLE users OF User NoDelete (
|
||||||
|
user_id string AS UserID PK,
|
||||||
|
email string AS Email NoUpdate,
|
||||||
|
name string AS Name NoInsert,
|
||||||
|
admin bool AS Admin NoInsert NoUpdate,
|
||||||
|
SSN string NoUpdate
|
||||||
|
);
|
||||||
|
|
||||||
|
TABLE users_view OF UserView NoInsert NoUpdate NoDelete (
|
||||||
|
user_id string AS UserID PK,
|
||||||
|
email string AS Email,
|
||||||
|
name string AS Name
|
||||||
|
);
|
66
sqlgen/test-files/TestParse/004.json
Normal file
66
sqlgen/test-files/TestParse/004.json
Normal file
@ -0,0 +1,66 @@
|
|||||||
|
{
|
||||||
|
"Tables": [
|
||||||
|
{
|
||||||
|
"Name": "users",
|
||||||
|
"Type": "User",
|
||||||
|
"NoDelete": true,
|
||||||
|
"Columns": [
|
||||||
|
{
|
||||||
|
"Name": "UserID",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "user_id",
|
||||||
|
"PK": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "Email",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "email",
|
||||||
|
"NoUpdate": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "Name",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "name",
|
||||||
|
"NoInsert": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "Admin",
|
||||||
|
"Type": "bool",
|
||||||
|
"SqlName": "admin",
|
||||||
|
"NoInsert": true,
|
||||||
|
"NoUpdate": true
|
||||||
|
}, {
|
||||||
|
"Name": "SSN",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "SSN",
|
||||||
|
"NoUpdate": true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "users_view",
|
||||||
|
"Type": "UserView",
|
||||||
|
"NoInsert": true,
|
||||||
|
"NoUpdate": true,
|
||||||
|
"NoDelete": true,
|
||||||
|
"Columns": [
|
||||||
|
{
|
||||||
|
"Name": "UserID",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "user_id",
|
||||||
|
"PK": true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "Email",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "email"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"Name": "Name",
|
||||||
|
"Type": "string",
|
||||||
|
"SqlName": "name"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
45
sqliteutil/README.md
Normal file
45
sqliteutil/README.md
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# sqliteutil
|
||||||
|
|
||||||
|
## Transactions
|
||||||
|
|
||||||
|
Simplify postgres transactions using `WithTx` for serializable transactions,
|
||||||
|
or `WithTxDefault` for the default isolation level. Use the `SerialTxRunner`
|
||||||
|
type to get automatic retries of serialization errors.
|
||||||
|
|
||||||
|
## Migrations
|
||||||
|
|
||||||
|
Put your migrations into a directory, for example `migrations`, ordered by name
|
||||||
|
(YYYY-MM-DD prefix, for example). Embed the directory and pass it to the
|
||||||
|
`Migrate` function:
|
||||||
|
|
||||||
|
```Go
|
||||||
|
//go:embed migrations
|
||||||
|
var migrations embed.FS
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
Migrate(db, migrations) // Check the error, of course.
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
In order to test this packge, we need to create a test user and database:
|
||||||
|
|
||||||
|
```
|
||||||
|
sudo su postgres
|
||||||
|
psql
|
||||||
|
|
||||||
|
CREATE DATABASE test;
|
||||||
|
CREATE USER test WITH ENCRYPTED PASSWORD 'test';
|
||||||
|
GRANT ALL PRIVILEGES ON DATABASE test TO test;
|
||||||
|
|
||||||
|
use test
|
||||||
|
|
||||||
|
GRANT ALL ON SCHEMA public TO test;
|
||||||
|
```
|
||||||
|
|
||||||
|
Check that you can connect via the command line:
|
||||||
|
|
||||||
|
```
|
||||||
|
psql -h 127.0.0.1 -U test --password test
|
||||||
|
```
|
5
sqliteutil/go.mod
Normal file
5
sqliteutil/go.mod
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
module git.crumpington.com/lib/sqliteutil
|
||||||
|
|
||||||
|
go 1.23.2
|
||||||
|
|
||||||
|
require github.com/mattn/go-sqlite3 v1.14.24 // indirect
|
2
sqliteutil/go.sum
Normal file
2
sqliteutil/go.sum
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
82
sqliteutil/migrate.go
Normal file
82
sqliteutil/migrate.go
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
package sqliteutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"embed"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
const initMigrationTableQuery = `
|
||||||
|
CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);`
|
||||||
|
|
||||||
|
const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)`
|
||||||
|
|
||||||
|
const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)`
|
||||||
|
|
||||||
|
func Migrate(db *sql.DB, migrationFS embed.FS) error {
|
||||||
|
return WithTx(db, func(tx *sql.Tx) error {
|
||||||
|
if _, err := tx.Exec(initMigrationTableQuery); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
dirs, err := migrationFS.ReadDir(".")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(dirs) != 1 {
|
||||||
|
return errors.New("expected a single migrations directory")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !dirs[0].IsDir() {
|
||||||
|
return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
dirName := dirs[0].Name()
|
||||||
|
files, err := migrationFS.ReadDir(dirName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort sql files by name.
|
||||||
|
sort.Slice(files, func(i, j int) bool {
|
||||||
|
return files[i].Name() < files[j].Name()
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, dirEnt := range files {
|
||||||
|
if !dirEnt.Type().IsRegular() {
|
||||||
|
return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
name = dirEnt.Name()
|
||||||
|
exists bool
|
||||||
|
)
|
||||||
|
|
||||||
|
err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
migration, err := migrationFS.ReadFile(filepath.Join(dirName, name))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Exec(string(migration)); err != nil {
|
||||||
|
return fmt.Errorf("migration %s failed: %v", name, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := tx.Exec(insertMigrationQuery, name); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
}
|
44
sqliteutil/migrate_test.go
Normal file
44
sqliteutil/migrate_test.go
Normal file
@ -0,0 +1,44 @@
|
|||||||
|
package sqliteutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"embed"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed test-migrations
|
||||||
|
var testMigrationFS embed.FS
|
||||||
|
|
||||||
|
func TestMigrate(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := Migrate(db, testMigrationFS); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shouldn't have any effect.
|
||||||
|
if err := Migrate(db, testMigrationFS); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)`
|
||||||
|
var exists bool
|
||||||
|
|
||||||
|
if err = db.QueryRow(query, 1).Scan(&exists); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if exists {
|
||||||
|
t.Fatal("1 shouldn't exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = db.QueryRow(query, 2).Scan(&exists); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if !exists {
|
||||||
|
t.Fatal("2 should exist")
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
9
sqliteutil/test-migrations/000.sql
Normal file
9
sqliteutil/test-migrations/000.sql
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
CREATE TABLE users(
|
||||||
|
UserID BIGINT NOT NULL PRIMARY KEY,
|
||||||
|
Email TEXT NOT NULL UNIQUE);
|
||||||
|
|
||||||
|
CREATE TABLE user_notes(
|
||||||
|
UserID BIGINT NOT NULL REFERENCES users(UserID),
|
||||||
|
NoteID BIGINT NOT NULL,
|
||||||
|
Note Text NOT NULL,
|
||||||
|
PRIMARY KEY(UserID,NoteID));
|
1
sqliteutil/test-migrations/001.sql
Normal file
1
sqliteutil/test-migrations/001.sql
Normal file
@ -0,0 +1 @@
|
|||||||
|
INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com');
|
1
sqliteutil/test-migrations/002.sql
Normal file
1
sqliteutil/test-migrations/002.sql
Normal file
@ -0,0 +1 @@
|
|||||||
|
DELETE FROM users WHERE UserID=1;
|
28
sqliteutil/tx.go
Normal file
28
sqliteutil/tx.go
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
package sqliteutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
// This is a convenience function to run a function within a transaction.
|
||||||
|
func WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
|
||||||
|
// Start a transaction.
|
||||||
|
tx, err := db.Begin()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = fn(tx)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
err = tx.Commit()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
_ = tx.Rollback()
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
2
tagengine/README.md
Normal file
2
tagengine/README.md
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
# tagengine
|
||||||
|
|
3
tagengine/go.mod
Normal file
3
tagengine/go.mod
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
module git.crumpington.com/lib/tagengine
|
||||||
|
|
||||||
|
go 1.23.2
|
0
tagengine/go.sum
Normal file
0
tagengine/go.sum
Normal file
30
tagengine/ngram.go
Normal file
30
tagengine/ngram.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package tagengine
|
||||||
|
|
||||||
|
import "unicode"
|
||||||
|
|
||||||
|
func ngramLength(s string) int {
|
||||||
|
N := len(s)
|
||||||
|
i := 0
|
||||||
|
count := 0
|
||||||
|
|
||||||
|
for {
|
||||||
|
// Eat spaces.
|
||||||
|
for i < N && unicode.IsSpace(rune(s[i])) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
// Done?
|
||||||
|
if i == N {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// Non-space!
|
||||||
|
count++
|
||||||
|
|
||||||
|
// Eat non-spaces.
|
||||||
|
for i < N && !unicode.IsSpace(rune(s[i])) {
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return count
|
||||||
|
}
|
31
tagengine/ngram_test.go
Normal file
31
tagengine/ngram_test.go
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
package tagengine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNGramLength(t *testing.T) {
|
||||||
|
type Case struct {
|
||||||
|
Input string
|
||||||
|
Length int
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []Case{
|
||||||
|
{"a b c", 3},
|
||||||
|
{" xyz\nlkj dflaj a", 4},
|
||||||
|
{"a", 1},
|
||||||
|
{" a", 1},
|
||||||
|
{"a", 1},
|
||||||
|
{" a\n", 1},
|
||||||
|
{" a ", 1},
|
||||||
|
{"\tx\ny\nz q ", 4},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
length := ngramLength(tc.Input)
|
||||||
|
if length != tc.Length {
|
||||||
|
log.Fatalf("%s: %d != %d", tc.Input, length, tc.Length)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
79
tagengine/node.go
Normal file
79
tagengine/node.go
Normal file
@ -0,0 +1,79 @@
|
|||||||
|
package tagengine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type node struct {
|
||||||
|
Token string
|
||||||
|
Matches []*Rule // If a list of tokens reaches this node, it matches these.
|
||||||
|
Children map[string]*node
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *node) AddRule(r *Rule) {
|
||||||
|
n.addRule(r, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *node) addRule(r *Rule, idx int) {
|
||||||
|
if len(r.Includes) == idx {
|
||||||
|
n.Matches = append(n.Matches, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
token := r.Includes[idx]
|
||||||
|
|
||||||
|
child, ok := n.Children[token]
|
||||||
|
if !ok {
|
||||||
|
child = &node{
|
||||||
|
Token: token,
|
||||||
|
Children: map[string]*node{},
|
||||||
|
}
|
||||||
|
n.Children[token] = child
|
||||||
|
}
|
||||||
|
|
||||||
|
child.addRule(r, idx+1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note that tokens must be sorted. This is the case for tokens created from
|
||||||
|
// the tokenize function.
|
||||||
|
func (n *node) Match(tokens []string) (rules []*Rule) {
|
||||||
|
return n.match(tokens, rules)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *node) match(tokens []string, rules []*Rule) []*Rule {
|
||||||
|
// Check for a match.
|
||||||
|
if n.Matches != nil {
|
||||||
|
rules = append(rules, n.Matches...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(tokens) == 0 {
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to match children.
|
||||||
|
for i := 0; i < len(tokens); i++ {
|
||||||
|
token := tokens[i]
|
||||||
|
if child, ok := n.Children[token]; ok {
|
||||||
|
rules = child.match(tokens[i+1:], rules)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *node) Dump() {
|
||||||
|
n.dump(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *node) dump(depth int) {
|
||||||
|
indent := strings.Repeat(" ", 2*depth)
|
||||||
|
tag := ""
|
||||||
|
for _, m := range n.Matches {
|
||||||
|
tag += " " + m.Tag
|
||||||
|
}
|
||||||
|
fmt.Printf("%s%s%s\n", indent, n.Token, tag)
|
||||||
|
for _, child := range n.Children {
|
||||||
|
child.dump(depth + 1)
|
||||||
|
}
|
||||||
|
}
|
159
tagengine/rule.go
Normal file
159
tagengine/rule.go
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
package tagengine
|
||||||
|
|
||||||
|
type Rule struct {
|
||||||
|
// The purpose of a Rule is to attach it's Tag to matching text.
|
||||||
|
Tag string
|
||||||
|
|
||||||
|
// Includes is a list of strings that must be found in the input in order to
|
||||||
|
// match.
|
||||||
|
Includes []string
|
||||||
|
|
||||||
|
// Excludes is a list of strings that can exclude a match for this rule.
|
||||||
|
Excludes []string
|
||||||
|
|
||||||
|
// Blocks: If this rule is matched, then it will block matches of any tags
|
||||||
|
// listed here.
|
||||||
|
Blocks []string
|
||||||
|
|
||||||
|
// The Score encodes the complexity of the Rule. A higher score indicates a
|
||||||
|
// more specific match. A Rule more includes, or includes with multiple words
|
||||||
|
// should havee a higher Score than a Rule with fewer includes or less
|
||||||
|
// complex includes.
|
||||||
|
Score int
|
||||||
|
|
||||||
|
excludes map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRule(tag string) Rule {
|
||||||
|
return Rule{Tag: tag}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Rule) Inc(l ...string) Rule {
|
||||||
|
return Rule{
|
||||||
|
Tag: r.Tag,
|
||||||
|
Includes: append(r.Includes, l...),
|
||||||
|
Excludes: r.Excludes,
|
||||||
|
Blocks: r.Blocks,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Rule) Exc(l ...string) Rule {
|
||||||
|
return Rule{
|
||||||
|
Tag: r.Tag,
|
||||||
|
Includes: r.Includes,
|
||||||
|
Excludes: append(r.Excludes, l...),
|
||||||
|
Blocks: r.Blocks,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Rule) Block(l ...string) Rule {
|
||||||
|
return Rule{
|
||||||
|
Tag: r.Tag,
|
||||||
|
Includes: r.Includes,
|
||||||
|
Excludes: r.Excludes,
|
||||||
|
Blocks: append(r.Blocks, l...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rule *Rule) normalize(sanitize func(string) string) {
|
||||||
|
for i, token := range rule.Includes {
|
||||||
|
rule.Includes[i] = sanitize(token)
|
||||||
|
}
|
||||||
|
for i, token := range rule.Excludes {
|
||||||
|
rule.Excludes[i] = sanitize(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
sortTokens(rule.Includes)
|
||||||
|
sortTokens(rule.Excludes)
|
||||||
|
|
||||||
|
rule.excludes = map[string]struct{}{}
|
||||||
|
for _, s := range rule.Excludes {
|
||||||
|
rule.excludes[s] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
rule.Score = rule.computeScore()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Rule) maxNGram() int {
|
||||||
|
max := 0
|
||||||
|
for _, s := range r.Includes {
|
||||||
|
n := ngramLength(s)
|
||||||
|
if n > max {
|
||||||
|
max = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, s := range r.Excludes {
|
||||||
|
n := ngramLength(s)
|
||||||
|
if n > max {
|
||||||
|
max = n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return max
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Rule) isExcluded(tokens []string) bool {
|
||||||
|
// This is most often the case.
|
||||||
|
if len(r.excludes) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, s := range tokens {
|
||||||
|
if _, ok := r.excludes[s]; ok {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r Rule) computeScore() (score int) {
|
||||||
|
for _, token := range r.Includes {
|
||||||
|
n := ngramLength(token)
|
||||||
|
score += n * (n + 1) / 2
|
||||||
|
}
|
||||||
|
return score
|
||||||
|
}
|
||||||
|
|
||||||
|
func ruleLess(lhs, rhs *Rule) bool {
|
||||||
|
// If scores differ, sort by score.
|
||||||
|
if lhs.Score != rhs.Score {
|
||||||
|
return lhs.Score < rhs.Score
|
||||||
|
}
|
||||||
|
|
||||||
|
// If include depth differs, sort by depth.
|
||||||
|
lDepth := len(lhs.Includes)
|
||||||
|
rDepth := len(rhs.Includes)
|
||||||
|
|
||||||
|
if lDepth != rDepth {
|
||||||
|
return lDepth < rDepth
|
||||||
|
}
|
||||||
|
|
||||||
|
// If exclude depth differs, sort by depth.
|
||||||
|
lDepth = len(lhs.Excludes)
|
||||||
|
rDepth = len(rhs.Excludes)
|
||||||
|
|
||||||
|
if lDepth != rDepth {
|
||||||
|
return lDepth < rDepth
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort alphabetically by includes.
|
||||||
|
for i := range lhs.Includes {
|
||||||
|
if lhs.Includes[i] != rhs.Includes[i] {
|
||||||
|
return lhs.Includes[i] < rhs.Includes[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by alphabetically by excludes.
|
||||||
|
for i := range lhs.Excludes {
|
||||||
|
if lhs.Excludes[i] != rhs.Excludes[i] {
|
||||||
|
return lhs.Excludes[i] < rhs.Excludes[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by tag.
|
||||||
|
if lhs.Tag != rhs.Tag {
|
||||||
|
return lhs.Tag < rhs.Tag
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
58
tagengine/rulegroup.go
Normal file
58
tagengine/rulegroup.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package tagengine
|
||||||
|
|
||||||
|
// A RuleGroup can be converted into a list of rules. Each rule will point to
|
||||||
|
// the same tag, and have the same exclude set and blocks.
|
||||||
|
type RuleGroup struct {
|
||||||
|
Tag string
|
||||||
|
Includes [][]string
|
||||||
|
Excludes []string
|
||||||
|
Blocks []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRuleGroup(tag string) RuleGroup {
|
||||||
|
return RuleGroup{
|
||||||
|
Tag: tag,
|
||||||
|
Includes: [][]string{},
|
||||||
|
Excludes: []string{},
|
||||||
|
Blocks: []string{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g RuleGroup) Inc(l ...string) RuleGroup {
|
||||||
|
return RuleGroup{
|
||||||
|
Tag: g.Tag,
|
||||||
|
Includes: append(g.Includes, l),
|
||||||
|
Excludes: g.Excludes,
|
||||||
|
Blocks: g.Blocks,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g RuleGroup) Exc(l ...string) RuleGroup {
|
||||||
|
return RuleGroup{
|
||||||
|
Tag: g.Tag,
|
||||||
|
Includes: g.Includes,
|
||||||
|
Excludes: append(g.Excludes, l...),
|
||||||
|
Blocks: g.Blocks,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g RuleGroup) Block(l ...string) RuleGroup {
|
||||||
|
return RuleGroup{
|
||||||
|
Tag: g.Tag,
|
||||||
|
Includes: g.Includes,
|
||||||
|
Excludes: g.Excludes,
|
||||||
|
Blocks: append(g.Blocks, l...),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g RuleGroup) ToList() (l []Rule) {
|
||||||
|
for _, includes := range g.Includes {
|
||||||
|
l = append(l, Rule{
|
||||||
|
Tag: g.Tag,
|
||||||
|
Excludes: g.Excludes,
|
||||||
|
Includes: includes,
|
||||||
|
Blocks: g.Blocks,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
162
tagengine/ruleset.go
Normal file
162
tagengine/ruleset.go
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
package tagengine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
)
|
||||||
|
|
||||||
|
type RuleSet struct {
|
||||||
|
root *node
|
||||||
|
maxNgram int
|
||||||
|
sanitize func(string) string
|
||||||
|
rules []*Rule
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRuleSet() *RuleSet {
|
||||||
|
return &RuleSet{
|
||||||
|
root: &node{
|
||||||
|
Token: "/",
|
||||||
|
Children: map[string]*node{},
|
||||||
|
},
|
||||||
|
sanitize: BasicSanitizer,
|
||||||
|
rules: []*Rule{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRuleSetFromList(rules []Rule) *RuleSet {
|
||||||
|
rs := NewRuleSet()
|
||||||
|
rs.AddRule(rules...)
|
||||||
|
return rs
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *RuleSet) Add(ruleOrGroup ...interface{}) {
|
||||||
|
for _, ix := range ruleOrGroup {
|
||||||
|
switch x := ix.(type) {
|
||||||
|
case Rule:
|
||||||
|
t.AddRule(x)
|
||||||
|
case RuleGroup:
|
||||||
|
t.AddRuleGroup(x)
|
||||||
|
default:
|
||||||
|
panic("Add expects either Rule or RuleGroup objects.")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *RuleSet) AddRule(rules ...Rule) {
|
||||||
|
for _, rule := range rules {
|
||||||
|
rule := rule
|
||||||
|
|
||||||
|
// Make sure rule is well-formed.
|
||||||
|
rule.normalize(t.sanitize)
|
||||||
|
|
||||||
|
// Update maxNgram.
|
||||||
|
N := rule.maxNGram()
|
||||||
|
if N > t.maxNgram {
|
||||||
|
t.maxNgram = N
|
||||||
|
}
|
||||||
|
|
||||||
|
t.rules = append(t.rules, &rule)
|
||||||
|
t.root.AddRule(&rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *RuleSet) AddRuleGroup(ruleGroups ...RuleGroup) {
|
||||||
|
for _, rg := range ruleGroups {
|
||||||
|
t.AddRule(rg.ToList()...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MatchRules will return a list of all matching rules. The rules are sorted by
|
||||||
|
// the match's Score. The best match will be first.
|
||||||
|
func (t *RuleSet) MatchRules(input string) (rules []*Rule) {
|
||||||
|
input = t.sanitize(input)
|
||||||
|
tokens := Tokenize(input, t.maxNgram)
|
||||||
|
|
||||||
|
rules = t.root.Match(tokens)
|
||||||
|
if len(rules) == 0 {
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check excludes.
|
||||||
|
l := rules[:0]
|
||||||
|
for _, r := range rules {
|
||||||
|
if !r.isExcluded(tokens) {
|
||||||
|
l = append(l, r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
rules = l
|
||||||
|
|
||||||
|
// Sort rules descending.
|
||||||
|
sort.Slice(rules, func(i, j int) bool {
|
||||||
|
return ruleLess(rules[j], rules[i])
|
||||||
|
})
|
||||||
|
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
type Match struct {
|
||||||
|
Tag string
|
||||||
|
|
||||||
|
// Confidence is used to sort all matches, and is normalized so the sum of
|
||||||
|
// Confidence values for all matches is 1. Confidence is relative to the
|
||||||
|
// number of matches and the size of matches in terms of number of tokens.
|
||||||
|
Confidence float64 // In the range (0,1].
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a list of matches with confidence. This is useful if you'd like to
|
||||||
|
// find the best matching rule out of all the matched rules.
|
||||||
|
//
|
||||||
|
// If you just want to find all matching rules, then use MatchRules.
|
||||||
|
func (t *RuleSet) Match(input string) []Match {
|
||||||
|
rules := t.MatchRules(input)
|
||||||
|
if len(rules) == 0 {
|
||||||
|
return []Match{}
|
||||||
|
}
|
||||||
|
if len(rules) == 1 {
|
||||||
|
return []Match{{
|
||||||
|
Tag: rules[0].Tag,
|
||||||
|
Confidence: 1,
|
||||||
|
}}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create list of blocked tags.
|
||||||
|
blocks := map[string]struct{}{}
|
||||||
|
for _, rule := range rules {
|
||||||
|
for _, tag := range rule.Blocks {
|
||||||
|
blocks[tag] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove rules for blocked tags.
|
||||||
|
iOut := 0
|
||||||
|
for _, rule := range rules {
|
||||||
|
if _, ok := blocks[rule.Tag]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rules[iOut] = rule
|
||||||
|
iOut++
|
||||||
|
}
|
||||||
|
rules = rules[:iOut]
|
||||||
|
|
||||||
|
// Matches by index.
|
||||||
|
matches := map[string]int{}
|
||||||
|
out := []Match{}
|
||||||
|
sum := float64(0)
|
||||||
|
|
||||||
|
for _, rule := range rules {
|
||||||
|
idx, ok := matches[rule.Tag]
|
||||||
|
if !ok {
|
||||||
|
idx = len(matches)
|
||||||
|
matches[rule.Tag] = idx
|
||||||
|
out = append(out, Match{Tag: rule.Tag})
|
||||||
|
}
|
||||||
|
out[idx].Confidence += float64(rule.Score)
|
||||||
|
sum += float64(rule.Score)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range out {
|
||||||
|
out[i].Confidence /= sum
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
84
tagengine/ruleset_test.go
Normal file
84
tagengine/ruleset_test.go
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
package tagengine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRulesSet(t *testing.T) {
|
||||||
|
rs := NewRuleSet()
|
||||||
|
rs.AddRule(Rule{
|
||||||
|
Tag: "cc/2",
|
||||||
|
Includes: []string{"cola", "coca"},
|
||||||
|
})
|
||||||
|
rs.AddRule(Rule{
|
||||||
|
Tag: "cc/0",
|
||||||
|
Includes: []string{"coca cola"},
|
||||||
|
})
|
||||||
|
rs.AddRule(Rule{
|
||||||
|
Tag: "cz/2",
|
||||||
|
Includes: []string{"coca", "zero"},
|
||||||
|
})
|
||||||
|
rs.AddRule(Rule{
|
||||||
|
Tag: "cc0/3",
|
||||||
|
Includes: []string{"zero", "coca", "cola"},
|
||||||
|
})
|
||||||
|
rs.AddRule(Rule{
|
||||||
|
Tag: "cc0/3.1",
|
||||||
|
Includes: []string{"coca", "cola", "zero"},
|
||||||
|
Excludes: []string{"pepsi"},
|
||||||
|
})
|
||||||
|
rs.AddRule(Rule{
|
||||||
|
Tag: "spa",
|
||||||
|
Includes: []string{"spa"},
|
||||||
|
Blocks: []string{"cc/0", "cc0/3", "cc0/3.1"},
|
||||||
|
})
|
||||||
|
|
||||||
|
type TestCase struct {
|
||||||
|
Input string
|
||||||
|
Matches []Match
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []TestCase{
|
||||||
|
{
|
||||||
|
Input: "coca-cola zero",
|
||||||
|
Matches: []Match{
|
||||||
|
{"cc0/3.1", 0.3},
|
||||||
|
{"cc0/3", 0.3},
|
||||||
|
{"cz/2", 0.2},
|
||||||
|
{"cc/2", 0.2},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Input: "coca cola",
|
||||||
|
Matches: []Match{
|
||||||
|
{"cc/0", 0.6},
|
||||||
|
{"cc/2", 0.4},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Input: "coca cola zero pepsi",
|
||||||
|
Matches: []Match{
|
||||||
|
{"cc0/3", 0.3},
|
||||||
|
{"cc/0", 0.3},
|
||||||
|
{"cz/2", 0.2},
|
||||||
|
{"cc/2", 0.2},
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Input: "fanta orange",
|
||||||
|
Matches: []Match{},
|
||||||
|
}, {
|
||||||
|
Input: "coca-cola zero / fanta / spa",
|
||||||
|
Matches: []Match{
|
||||||
|
{"cz/2", 0.4},
|
||||||
|
{"cc/2", 0.4},
|
||||||
|
{"spa", 0.2},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
matches := rs.Match(tc.Input)
|
||||||
|
if !reflect.DeepEqual(matches, tc.Matches) {
|
||||||
|
t.Fatalf("%v != %v", matches, tc.Matches)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
20
tagengine/sanitize.go
Normal file
20
tagengine/sanitize.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package tagengine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.crumpington.com/lib/tagengine/sanitize"
|
||||||
|
)
|
||||||
|
|
||||||
|
// The basic sanitizer:
|
||||||
|
// * lower-case
|
||||||
|
// * put spaces around numbers
|
||||||
|
// * put slaces around punctuation
|
||||||
|
// * collapse multiple spaces
|
||||||
|
func BasicSanitizer(s string) string {
|
||||||
|
s = strings.ToLower(s)
|
||||||
|
s = sanitize.SpaceNumbers(s)
|
||||||
|
s = sanitize.SpacePunctuation(s)
|
||||||
|
s = sanitize.CollapseSpaces(s)
|
||||||
|
return s
|
||||||
|
}
|
91
tagengine/sanitize/sanitize.go
Normal file
91
tagengine/sanitize/sanitize.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
package sanitize
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
"unicode"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SpaceNumbers(s string) string {
|
||||||
|
if len(s) == 0 {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
isDigit := func(b rune) bool {
|
||||||
|
switch b {
|
||||||
|
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
b := strings.Builder{}
|
||||||
|
|
||||||
|
var first rune
|
||||||
|
for _, c := range s {
|
||||||
|
first = c
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
digit := isDigit(first)
|
||||||
|
|
||||||
|
// Range over runes.
|
||||||
|
for _, c := range s {
|
||||||
|
thisDigit := isDigit(c)
|
||||||
|
if thisDigit != digit {
|
||||||
|
b.WriteByte(' ')
|
||||||
|
digit = thisDigit
|
||||||
|
}
|
||||||
|
b.WriteRune(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func SpacePunctuation(s string) string {
|
||||||
|
needsSpace := func(r rune) bool {
|
||||||
|
switch r {
|
||||||
|
case '`', '~', '!', '@', '#', '%', '^', '&', '*', '(', ')',
|
||||||
|
'-', '_', '+', '=', '[', '{', ']', '}', '\\', '|',
|
||||||
|
':', ';', '"', '\'', ',', '<', '.', '>', '?', '/':
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
b := strings.Builder{}
|
||||||
|
|
||||||
|
// Range over runes.
|
||||||
|
for _, r := range s {
|
||||||
|
if needsSpace(r) {
|
||||||
|
b.WriteRune(' ')
|
||||||
|
b.WriteRune(r)
|
||||||
|
b.WriteRune(' ')
|
||||||
|
} else {
|
||||||
|
b.WriteRune(r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func CollapseSpaces(s string) string {
|
||||||
|
// Trim leading and trailing spaces.
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
|
||||||
|
b := strings.Builder{}
|
||||||
|
wasSpace := false
|
||||||
|
|
||||||
|
// Range over runes.
|
||||||
|
for _, c := range s {
|
||||||
|
if unicode.IsSpace(c) {
|
||||||
|
wasSpace = true
|
||||||
|
continue
|
||||||
|
} else if wasSpace {
|
||||||
|
wasSpace = false
|
||||||
|
b.WriteRune(' ')
|
||||||
|
}
|
||||||
|
b.WriteRune(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
return b.String()
|
||||||
|
}
|
30
tagengine/sanitize_test.go
Normal file
30
tagengine/sanitize_test.go
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
package tagengine
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestSanitize(t *testing.T) {
|
||||||
|
sanitize := BasicSanitizer
|
||||||
|
|
||||||
|
type Case struct {
|
||||||
|
In string
|
||||||
|
Out string
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []Case{
|
||||||
|
{"", ""},
|
||||||
|
{"123abc", "123 abc"},
|
||||||
|
{"abc123", "abc 123"},
|
||||||
|
{"abc123xyz", "abc 123 xyz"},
|
||||||
|
{"1f2", "1 f 2"},
|
||||||
|
{" abc", "abc"},
|
||||||
|
{" ; KitKat/m&m's (bottle) @ ", "; kitkat / m & m ' s ( bottle ) @"},
|
||||||
|
{"€", "€"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
out := sanitize(tc.In)
|
||||||
|
if out != tc.Out {
|
||||||
|
t.Fatalf("%v != %v", out, tc.Out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
63
tagengine/tokenize.go
Normal file
63
tagengine/tokenize.go
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package tagengine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ignoreTokens = map[string]struct{}{}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// These on their own are ignored.
|
||||||
|
tokens := []string{
|
||||||
|
"`", `~`, `!`, `@`, `#`, `%`, `^`, `&`, `*`, `(`, `)`,
|
||||||
|
`-`, `_`, `+`, `=`, `[`, `{`, `]`, `}`, `\`, `|`,
|
||||||
|
`:`, `;`, `"`, `'`, `,`, `<`, `.`, `>`, `?`, `/`,
|
||||||
|
}
|
||||||
|
for _, s := range tokens {
|
||||||
|
ignoreTokens[s] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Tokenize(
|
||||||
|
input string,
|
||||||
|
maxNgram int,
|
||||||
|
) (
|
||||||
|
tokens []string,
|
||||||
|
) {
|
||||||
|
// Avoid duplicate ngrams.
|
||||||
|
ignored := map[string]bool{}
|
||||||
|
|
||||||
|
fields := strings.Fields(input)
|
||||||
|
|
||||||
|
if len(fields) < maxNgram {
|
||||||
|
maxNgram = len(fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i < maxNgram+1; i++ {
|
||||||
|
jMax := len(fields) - i + 1
|
||||||
|
|
||||||
|
for j := 0; j < jMax; j++ {
|
||||||
|
ngram := strings.Join(fields[j:i+j], " ")
|
||||||
|
if _, ok := ignoreTokens[ngram]; !ok {
|
||||||
|
if _, ok := ignored[ngram]; !ok {
|
||||||
|
tokens = append(tokens, ngram)
|
||||||
|
ignored[ngram] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sortTokens(tokens)
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortTokens(tokens []string) {
|
||||||
|
sort.Slice(tokens, func(i, j int) bool {
|
||||||
|
if len(tokens[i]) != len(tokens[j]) {
|
||||||
|
return len(tokens[i]) < len(tokens[j])
|
||||||
|
}
|
||||||
|
return tokens[i] < tokens[j]
|
||||||
|
})
|
||||||
|
}
|
55
tagengine/tokenize_test.go
Normal file
55
tagengine/tokenize_test.go
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
package tagengine
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTokenize(t *testing.T) {
|
||||||
|
type Case struct {
|
||||||
|
Input string
|
||||||
|
MaxNgram int
|
||||||
|
Output []string
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []Case{
|
||||||
|
{
|
||||||
|
Input: "a bb c d",
|
||||||
|
MaxNgram: 3,
|
||||||
|
Output: []string{
|
||||||
|
"a", "c", "d", "bb",
|
||||||
|
"c d", "a bb", "bb c",
|
||||||
|
"a bb c", "bb c d",
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Input: "a b",
|
||||||
|
MaxNgram: 3,
|
||||||
|
Output: []string{
|
||||||
|
"a", "b", "a b",
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Input: "- b c d",
|
||||||
|
MaxNgram: 3,
|
||||||
|
Output: []string{
|
||||||
|
"b", "c", "d",
|
||||||
|
"- b", "b c", "c d",
|
||||||
|
"- b c", "b c d",
|
||||||
|
},
|
||||||
|
}, {
|
||||||
|
Input: "a a b c d c d",
|
||||||
|
MaxNgram: 3,
|
||||||
|
Output: []string{
|
||||||
|
"a", "b", "c", "d",
|
||||||
|
"a a", "a b", "b c", "c d", "d c",
|
||||||
|
"a a b", "a b c", "b c d", "c d c", "d c d",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
output := Tokenize(tc.Input, tc.MaxNgram)
|
||||||
|
if !reflect.DeepEqual(output, tc.Output) {
|
||||||
|
t.Fatalf("%s: %#v", tc.Input, output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
5
webutil/README.md
Normal file
5
webutil/README.md
Normal file
@ -0,0 +1,5 @@
|
|||||||
|
# webutil
|
||||||
|
|
||||||
|
## Roadmap
|
||||||
|
|
||||||
|
* logging middleware
|
10
webutil/go.mod
Normal file
10
webutil/go.mod
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
module git.crumpington.com/lib/webutil
|
||||||
|
|
||||||
|
go 1.23.2
|
||||||
|
|
||||||
|
require golang.org/x/crypto v0.28.0
|
||||||
|
|
||||||
|
require (
|
||||||
|
golang.org/x/net v0.21.0 // indirect
|
||||||
|
golang.org/x/text v0.19.0 // indirect
|
||||||
|
)
|
6
webutil/go.sum
Normal file
6
webutil/go.sum
Normal file
@ -0,0 +1,6 @@
|
|||||||
|
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
||||||
|
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
||||||
|
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||||
|
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||||
|
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
|
||||||
|
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
24
webutil/listenandserve.go
Normal file
24
webutil/listenandserve.go
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
package webutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/acme/autocert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Serve requests using the given http.Server. If srv.Addr has the format
|
||||||
|
// `hostname:https`, then use autocert to manage certificates for the domain.
|
||||||
|
//
|
||||||
|
// For http on port 80, you can use :http.
|
||||||
|
func ListenAndServe(srv *http.Server) error {
|
||||||
|
if strings.HasSuffix(srv.Addr, ":https") {
|
||||||
|
hostname := strings.TrimSuffix(srv.Addr, ":https")
|
||||||
|
if len(hostname) == 0 {
|
||||||
|
return errors.New("https requires a hostname")
|
||||||
|
}
|
||||||
|
return srv.Serve(autocert.NewListener(hostname))
|
||||||
|
}
|
||||||
|
return srv.ListenAndServe()
|
||||||
|
}
|
47
webutil/middleware-logging.go
Normal file
47
webutil/middleware-logging.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package webutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _log = log.New(os.Stderr, "", 0)
|
||||||
|
|
||||||
|
type responseWriterWrapper struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
httpStatus int
|
||||||
|
responseSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseWriterWrapper) WriteHeader(status int) {
|
||||||
|
w.httpStatus = status
|
||||||
|
w.ResponseWriter.WriteHeader(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *responseWriterWrapper) Write(b []byte) (int, error) {
|
||||||
|
if w.httpStatus == 0 {
|
||||||
|
w.httpStatus = 200
|
||||||
|
}
|
||||||
|
w.responseSize += len(b)
|
||||||
|
return w.ResponseWriter.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithLogging(inner http.HandlerFunc) http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
t := time.Now()
|
||||||
|
wrapper := responseWriterWrapper{w, 0, 0}
|
||||||
|
|
||||||
|
inner(&wrapper, r)
|
||||||
|
_log.Printf("%s \"%s %s %s\" %d %d %v\n",
|
||||||
|
r.RemoteAddr,
|
||||||
|
r.Method,
|
||||||
|
r.URL.Path,
|
||||||
|
r.Proto,
|
||||||
|
wrapper.httpStatus,
|
||||||
|
wrapper.responseSize,
|
||||||
|
time.Since(t),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
100
webutil/template.go
Normal file
100
webutil/template.go
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
package webutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"html/template"
|
||||||
|
"io/fs"
|
||||||
|
"log"
|
||||||
|
"path"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ParseTemplateSet parses sets of templates from an embed.FS.
|
||||||
|
//
|
||||||
|
// Each directory constitutes a set of templates that are parsed together.
|
||||||
|
//
|
||||||
|
// Structure (within a directory):
|
||||||
|
// - share/* are always parsed.
|
||||||
|
// - base.html will be parsed with each other file in same dir
|
||||||
|
//
|
||||||
|
// Call a template with m[path].Execute(w, data) (root dir name is excluded).
|
||||||
|
//
|
||||||
|
// For example, if you have
|
||||||
|
// - /user/share/*
|
||||||
|
// - /user/base.html
|
||||||
|
// - /user/home.html
|
||||||
|
//
|
||||||
|
// Then you call m["/user/home.html"].Execute(w, data).
|
||||||
|
func ParseTemplateSet(funcs template.FuncMap, fs embed.FS) map[string]*template.Template {
|
||||||
|
m := map[string]*template.Template{}
|
||||||
|
rootDir := readDir(fs, ".")[0].Name()
|
||||||
|
loadTemplateDir(fs, funcs, m, rootDir, rootDir)
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func loadTemplateDir(
|
||||||
|
fs embed.FS,
|
||||||
|
funcs template.FuncMap,
|
||||||
|
m map[string]*template.Template,
|
||||||
|
dirPath string,
|
||||||
|
rootDir string,
|
||||||
|
) map[string]*template.Template {
|
||||||
|
t := template.New("")
|
||||||
|
if funcs != nil {
|
||||||
|
t = t.Funcs(funcs)
|
||||||
|
}
|
||||||
|
|
||||||
|
shareDir := path.Join(dirPath, "share")
|
||||||
|
if _, err := fs.ReadDir(shareDir); err == nil {
|
||||||
|
log.Printf("Parsing %s...", path.Join(shareDir, "*"))
|
||||||
|
t = template.Must(t.ParseFS(fs, path.Join(shareDir, "*")))
|
||||||
|
}
|
||||||
|
|
||||||
|
if data, _ := fs.ReadFile(path.Join(dirPath, "base.html")); data != nil {
|
||||||
|
log.Printf("Parsing %s...", path.Join(dirPath, "base.html"))
|
||||||
|
t = template.Must(t.Parse(string(data)))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, ent := range readDir(fs, dirPath) {
|
||||||
|
if ent.Type().IsDir() {
|
||||||
|
if ent.Name() != "share" {
|
||||||
|
m = loadTemplateDir(fs, funcs, m, path.Join(dirPath, ent.Name()), rootDir)
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ent.Type().IsRegular() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ent.Name() == "base.html" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
filePath := path.Join(dirPath, ent.Name())
|
||||||
|
log.Printf("Parsing %s...", filePath)
|
||||||
|
|
||||||
|
key := strings.TrimPrefix(path.Join(dirPath, ent.Name()), rootDir)
|
||||||
|
tt := template.Must(t.Clone())
|
||||||
|
tt = template.Must(tt.Parse(readFile(fs, filePath)))
|
||||||
|
m[key] = tt
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func readDir(fs embed.FS, dirPath string) []fs.DirEntry {
|
||||||
|
ents, err := fs.ReadDir(dirPath)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return ents
|
||||||
|
}
|
||||||
|
|
||||||
|
func readFile(fs embed.FS, path string) string {
|
||||||
|
data, err := fs.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
49
webutil/template_test.go
Normal file
49
webutil/template_test.go
Normal file
@ -0,0 +1,49 @@
|
|||||||
|
package webutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"embed"
|
||||||
|
"html/template"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed all:test-templates
|
||||||
|
var testFS embed.FS
|
||||||
|
|
||||||
|
func TestParseTemplateSet(t *testing.T) {
|
||||||
|
funcs := template.FuncMap{"join": strings.Join}
|
||||||
|
m := ParseTemplateSet(funcs, testFS)
|
||||||
|
|
||||||
|
type TestCase struct {
|
||||||
|
Key string
|
||||||
|
Data any
|
||||||
|
Out string
|
||||||
|
}
|
||||||
|
|
||||||
|
cases := []TestCase{
|
||||||
|
{
|
||||||
|
Key: "/home.html",
|
||||||
|
Data: "DATA",
|
||||||
|
Out: "<p>HOME!</p>",
|
||||||
|
}, {
|
||||||
|
Key: "/about.html",
|
||||||
|
Data: "DATA",
|
||||||
|
Out: "<p><b>DATA</b></p>",
|
||||||
|
}, {
|
||||||
|
Key: "/contact.html",
|
||||||
|
Data: []string{"a", "b", "c"},
|
||||||
|
Out: "<p>a,b,c</p>",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
b := &bytes.Buffer{}
|
||||||
|
m[tc.Key].Execute(b, tc.Data)
|
||||||
|
out := strings.TrimSpace(b.String())
|
||||||
|
if out != tc.Out {
|
||||||
|
t.Fatalf("%s != %s", out, tc.Out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
1
webutil/test-templates/about.html
Normal file
1
webutil/test-templates/about.html
Normal file
@ -0,0 +1 @@
|
|||||||
|
{{define "body"}}{{template "bold" .}}{{end}}
|
1
webutil/test-templates/base.html
Normal file
1
webutil/test-templates/base.html
Normal file
@ -0,0 +1 @@
|
|||||||
|
<p>{{block "body" .}}default{{end}}</p>
|
1
webutil/test-templates/contact.html
Normal file
1
webutil/test-templates/contact.html
Normal file
@ -0,0 +1 @@
|
|||||||
|
{{define "body"}}{{join . ","}}{{end}}
|
1
webutil/test-templates/home.html
Normal file
1
webutil/test-templates/home.html
Normal file
@ -0,0 +1 @@
|
|||||||
|
{{define "body"}}HOME!{{end}}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user