This commit is contained in:
jdl 2024-11-11 06:36:55 +01:00
parent d0587cc585
commit c5419d662e
102 changed files with 4181 additions and 0 deletions

2
flock/README.md Normal file
View File

@ -0,0 +1,2 @@
# flock

69
flock/flock.go Normal file
View File

@ -0,0 +1,69 @@
// The flock package provides a file-system mediated locking mechanism on linux
// using the `flock` system call.
package flock
import (
"errors"
"os"
"syscall"
)
// Lock opens the file at path (creating it when it does not exist) and takes
// an exclusive flock on it. The returned file holds the lock; release it with
// Unlock.
func Lock(path string) (*os.File, error) {
	const flags = syscall.LOCK_EX
	return lock(path, flags)
}
// TryLock is the non-blocking counterpart of Lock. When the file is already
// locked it returns a nil file together with a nil error.
func TryLock(path string) (*os.File, error) {
	const flags = syscall.LOCK_EX | syscall.LOCK_NB
	return lock(path, flags)
}
// LockFile takes an exclusive lock on the already opened file f. Any error
// reported by the underlying flock call is returned.
func LockFile(f *os.File) error {
	if _, err := lockFile(f, syscall.LOCK_EX); err != nil {
		return err
	}
	return nil
}
// TryLockFile attempts a non-blocking exclusive lock on f. It reports true
// when the lock was acquired and false (with a nil error) when the file is
// already locked.
func TryLockFile(f *os.File) (bool, error) {
	acquired, err := lockFile(f, syscall.LOCK_EX|syscall.LOCK_NB)
	return acquired, err
}
// lockFile applies flock with the given flags to f. It reports whether the
// lock was obtained. For non-blocking requests (LOCK_NB set) an EAGAIN result
// is not treated as an error: the function returns (false, nil) instead.
func lockFile(f *os.File, flags int) (bool, error) {
	err := flock(int(f.Fd()), flags)
	switch {
	case err == nil:
		return true, nil
	case flags&syscall.LOCK_NB != 0 && errors.Is(err, syscall.EAGAIN):
		// Somebody else holds the lock and the caller asked not to wait.
		return false, nil
	default:
		return false, err
	}
}
// flock issues the raw flock system call for file descriptor fd with the
// operation how, translating a non-zero errno into an error.
func flock(fd int, how int) error {
	if _, _, errno := syscall.Syscall(syscall.SYS_FLOCK, uintptr(fd), uintptr(how), 0); errno != 0 {
		return syscall.Errno(errno)
	}
	return nil
}
// lock opens (or creates, mode 0600) the file at path for reading and writing
// and applies flock with the given flags. The file is returned only when the
// lock was acquired; otherwise it is closed and nil is returned, with a nil
// error if a non-blocking attempt simply found the file locked.
func lock(path string, flags int) (*os.File, error) {
	f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0600)
	if err != nil {
		return nil, err
	}

	acquired, err := lockFile(f, flags)
	if acquired && err == nil {
		return f, nil
	}

	// Not locked: do not hand out (or leak) the descriptor.
	f.Close()
	return nil, err
}
// Unlock releases a lock obtained from Lock or TryLock by closing the file
// that holds it. The error from closing the file is returned.
func Unlock(f *os.File) error {
	err := f.Close()
	return err
}

66
flock/flock_test.go Normal file
View File

@ -0,0 +1,66 @@
package flock
import (
"testing"
"time"
)
// Test_Lock_basic checks that a second Lock on the same path blocks until the
// first holder unlocks: the background goroutine publishes a value on ch
// right before releasing the lock, so the value must be available once the
// second Lock returns, and must not be available before it is attempted.
func Test_Lock_basic(t *testing.T) {
	ch := make(chan int, 1)
	f, err := Lock("/tmp/fsutil-test-lock")
	if err != nil {
		t.Fatal(err)
	}

	go func() {
		time.Sleep(time.Second)
		ch <- 10
		Unlock(f)
	}()

	// The holder is still sleeping, so nothing may have been sent yet.
	select {
	case x := <-ch:
		t.Fatal(x)
	default:
	}

	// Blocks until the goroutine above releases the lock.
	f2, err := Lock("/tmp/fsutil-test-lock")
	if err != nil {
		t.Fatal(err)
	}
	defer Unlock(f2)

	select {
	case i := <-ch:
		if i != 10 {
			t.Fatal(i)
		}
	default:
		t.Fatal("No value available.")
	}
}
// Test_Lock_badPath checks that locking a file inside a directory that does
// not exist fails.
func Test_Lock_badPath(t *testing.T) {
	if _, err := Lock("./dne/file.lock"); err == nil {
		t.Fatal("expected an error when locking a file in a missing directory")
	}
}
// TestTryLock checks that a second TryLock on an already locked path returns
// a nil file without an error, and that the first lock can be released.
func TestTryLock(t *testing.T) {
	lockPath := "/tmp/fsutil-test-lock"

	f, err := TryLock(lockPath)
	if err != nil {
		t.Fatalf("%#v", err)
	}

	// The file is held by f, so this attempt must come back empty-handed.
	f2, err := TryLock(lockPath)
	if err != nil {
		t.Fatal(err)
	}
	if f2 != nil {
		t.Fatal(f2)
	}

	if err := Unlock(f); err != nil {
		t.Fatal(err)
	}
}

3
flock/go.mod Normal file
View File

@ -0,0 +1,3 @@
module git.crumpington.com/lib/flock
go 1.23.0

0
flock/go.sum Normal file
View File

2
httpconn/README.md Normal file
View File

@ -0,0 +1,2 @@
# httpconn

104
httpconn/client.go Normal file
View File

@ -0,0 +1,104 @@
package httpconn
import (
"bufio"
"context"
"crypto/tls"
"errors"
"io"
"net"
"net/http"
"net/url"
"time"
)
var (
	// ErrUnknownScheme is returned by Dial when the URL scheme is neither
	// "http" nor "https".
	ErrUnknownScheme = errors.New("unknown scheme")
)
// Dialer opens network connections that are tunnelled through an HTTP CONNECT
// request.
type Dialer struct {
// timeout bounds the TCP/TLS dial and, separately, the CONNECT handshake.
timeout time.Duration
}
// NewDialer returns a Dialer with a 10 second timeout.
func NewDialer() *Dialer {
	d := new(Dialer)
	d.timeout = 10 * time.Second
	return d
}
// SetTimeout replaces the timeout used for subsequent dials.
func (d *Dialer) SetTimeout(timeout time.Duration) {
d.timeout = timeout
}
// Dial parses rawURL and opens a tunnelled connection to it. The scheme
// selects TLS ("https") or plain TCP ("http"); any other scheme yields
// ErrUnknownScheme. When the URL carries no explicit port the scheme's
// default (443 or 80) is used; an explicit port is kept as is.
func (d *Dialer) Dial(rawURL string) (net.Conn, error) {
	u, err := url.Parse(rawURL)
	if err != nil {
		return nil, err
	}

	// withPort returns the dial address, adding the default port only when
	// the URL does not already name one. JoinHostPort re-adds the brackets
	// that Hostname strips from IPv6 literals.
	withPort := func(defaultPort string) string {
		if u.Port() != "" {
			return u.Host
		}
		return net.JoinHostPort(u.Hostname(), defaultPort)
	}

	switch u.Scheme {
	case "https":
		return d.DialHTTPS(withPort("443"), u.Path)
	case "http":
		return d.DialHTTP(withPort("80"), u.Path)
	default:
		return nil, ErrUnknownScheme
	}
}
// DialHTTPS opens a TLS connection to host (in "host:port" form), bounded by
// the dialer's timeout, and then performs the CONNECT handshake for path.
func (d *Dialer) DialHTTPS(host, path string) (net.Conn, error) {
	conn, err := func() (net.Conn, error) {
		// The context only governs connection establishment; it is released
		// as soon as the dial attempt returns.
		ctx, cancel := context.WithTimeout(context.Background(), d.timeout)
		defer cancel()
		var tlsDialer tls.Dialer
		return tlsDialer.DialContext(ctx, "tcp", host)
	}()
	if err != nil {
		return nil, err
	}
	return d.finishDialing(conn, host, path)
}
// DialHTTP opens a plain TCP connection to host (in "host:port" form), bounded
// by the dialer's timeout, and then performs the CONNECT handshake for path.
func (d *Dialer) DialHTTP(host, path string) (net.Conn, error) {
	conn, err := net.DialTimeout("tcp", host, d.timeout)
	if err == nil {
		return d.finishDialing(conn, host, path)
	}
	return nil, err
}
// finishDialing performs the CONNECT handshake on an established connection:
// it sends "CONNECT <path>" with a Host header and requires a 200 response
// before handing the connection back. The whole exchange is bounded by the
// dialer's timeout; the deadline is cleared again on success.
//
// On any failure the connection is closed and a non-nil error is returned.
//
// NOTE(review): the response is parsed through a temporary bufio.Reader, so
// bytes the server sends immediately after the response headers could be
// consumed by that reader and never reach the caller — confirm peers wait for
// client data first.
func (d *Dialer) finishDialing(conn net.Conn, host, path string) (net.Conn, error) {
	conn.SetDeadline(time.Now().Add(d.timeout))

	request := "CONNECT " + path + " HTTP/1.0\n" + "Host: " + host + "\n\n"
	if _, err := io.WriteString(conn, request); err != nil {
		conn.Close()
		return nil, err
	}

	// Require successful HTTP response before using the conn.
	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
	if err != nil {
		conn.Close()
		return nil, err
	}
	if resp.StatusCode != http.StatusOK {
		conn.Close()
		return nil, errors.New("httpconn: unexpected response status: " + resp.Status)
	}

	conn.SetDeadline(time.Time{})
	return conn, nil
}
// Dial is a convenience wrapper that calls Dial on a fresh Dialer created with
// NewDialer.
func Dial(rawURL string) (net.Conn, error) {
return NewDialer().Dial(rawURL)
}
// DialHTTPS is a convenience wrapper that calls DialHTTPS on a fresh Dialer
// created with NewDialer.
func DialHTTPS(host, path string) (net.Conn, error) {
return NewDialer().DialHTTPS(host, path)
}
// DialHTTP is a convenience wrapper that calls DialHTTP on a fresh Dialer
// created with NewDialer.
func DialHTTP(host, path string) (net.Conn, error) {
return NewDialer().DialHTTP(host, path)
}

42
httpconn/conn_test.go Normal file
View File

@ -0,0 +1,42 @@
package httpconn
import (
"net"
"net/http"
"net/http/httptest"
"strings"
"testing"
"golang.org/x/net/nettest"
)
// TestNetTest_TestConn runs the nettest.TestConn suite against a connection
// pair made of a client conn from DialHTTP and the server conn returned by
// Accept.
func TestNetTest_TestConn(t *testing.T) {
nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) {
// connCh hands the accepted server-side conn to the test; doneCh keeps the
// HTTP handler blocked until stop is called.
connCh := make(chan net.Conn, 1)
doneCh := make(chan bool)
mux := http.NewServeMux()
mux.HandleFunc("/connect", func(w http.ResponseWriter, r *http.Request) {
conn, err := Accept(w, r)
if err != nil {
panic(err)
}
connCh <- conn
<-doneCh
})
srv := httptest.NewServer(mux)
// srv.URL has the form "http://host:port"; DialHTTP wants only host:port.
c1, err = DialHTTP(strings.TrimPrefix(srv.URL, "http://"), "/connect")
if err != nil {
panic(err)
}
c2 = <-connCh
return c1, c2, func() {
doneCh <- true
srv.Close()
}, nil
})
}

5
httpconn/go.mod Normal file
View File

@ -0,0 +1,5 @@
module git.crumpington.com/lib/httpconn
go 1.23.2
require golang.org/x/net v0.30.0

2
httpconn/go.sum Normal file
View File

@ -0,0 +1,2 @@
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=

32
httpconn/server.go Normal file
View File

@ -0,0 +1,32 @@
package httpconn
import (
"io"
"net"
"net/http"
"time"
)
// Accept turns an incoming CONNECT request into a raw network connection.
//
// Requests with any other method are answered with 405 and rejected with
// http.ErrNotSupported, as are response writers that cannot be hijacked. After
// hijacking, a "200 OK" response is written to the peer and any deadline on
// the connection is cleared. If the response cannot be written the connection
// is closed and the write error is returned.
func Accept(w http.ResponseWriter, r *http.Request) (net.Conn, error) {
	if r.Method != "CONNECT" {
		w.Header().Set("Content-Type", "text/plain; charset=utf-8")
		w.WriteHeader(http.StatusMethodNotAllowed)
		io.WriteString(w, "405 must CONNECT\n")
		return nil, http.ErrNotSupported
	}

	hj, ok := w.(http.Hijacker)
	if !ok {
		return nil, http.ErrNotSupported
	}

	conn, _, err := hj.Hijack()
	if err != nil {
		return nil, err
	}

	// The client waits for this response before using the connection, so a
	// failed write means the tunnel is unusable.
	if _, err := io.WriteString(conn, "HTTP/1.0 200 OK\n\n"); err != nil {
		conn.Close()
		return nil, err
	}

	conn.SetDeadline(time.Time{})
	return conn, nil
}

2
idgen/README.md Normal file
View File

@ -0,0 +1,2 @@
# idgen

3
idgen/go.mod Normal file
View File

@ -0,0 +1,3 @@
module git.crumpington.com/lib/idgen
go 1.23.2

53
idgen/idgen.go Normal file
View File

@ -0,0 +1,53 @@
package idgen
import (
"crypto/rand"
"encoding/base32"
"sync"
"time"
)
// NewToken returns a fresh random token: 20 bytes from crypto/rand encoded
// with standard base32. It panics if the random source fails.
func NewToken() string {
	var raw [20]byte
	if _, err := rand.Read(raw[:]); err != nil {
		panic(err)
	}
	return base32.StdEncoding.EncodeToString(raw[:])
}
// State shared by NextID.
var (
// lock serialises access to ts and counter.
lock sync.Mutex
// ts is the Unix second the current counter sequence belongs to.
ts int64 = time.Now().Unix()
// counter is the sequence number within ts.
counter int64 = 1
// counterMax is the first counter value that no longer fits in the 20 bits
// reserved for it in an ID.
counterMax int64 = 1 << 20
)
// NextID returns the next ID for nodeID. An ID packs the Unix timestamp in
// seconds (shifted left by 26 bits), the node ID (shifted left by 20 bits) and
// a per-second counter in the low 20 bits, which allows roughly one million
// IDs per second.
//
// nodeID must be less than 64. NextID panics when the counter for the current
// second is exhausted.
func NextID(nodeID int64) int64 {
	lock.Lock()
	defer lock.Unlock()

	if now := time.Now().Unix(); now > ts {
		// New second: restart the sequence.
		ts, counter = now, 1
	} else if counter++; counter == counterMax {
		panic("Too many IDs.")
	}

	return ts<<26 + nodeID<<20 + counter
}
// SplitID decomposes an ID into its parts: the low 20 bits are the counter,
// the next 6 bits the node ID, and the remaining high bits the Unix timestamp.
func SplitID(id int64) (unixTime, nodeID, counter int64) {
	const (
		counterBits = 20
		nodeBits    = 6
	)
	const (
		counterMask int64 = 1<<counterBits - 1
		nodeMask    int64 = 1<<nodeBits - 1
	)
	return id >> (counterBits + nodeBits), (id >> counterBits) & nodeMask, id & counterMask
}

19
idgen/idgen_test.go Normal file
View File

@ -0,0 +1,19 @@
package idgen
import (
"log"
"testing"
)
// BenchmarkNext measures the cost of generating IDs for node 0.
func BenchmarkNext(b *testing.B) {
	for remaining := b.N; remaining > 0; remaining-- {
		NextID(0)
	}
}
// TestNextID checks that an ID produced by NextID splits back into the node ID
// it was generated for, a positive counter and a positive timestamp.
func TestNextID(t *testing.T) {
	id := NextID(32)
	a, b, c := SplitID(id)
	log.Print(a, b, c)

	if a <= 0 {
		t.Fatalf("expected a positive timestamp, got %d", a)
	}
	if b != 32 {
		t.Fatalf("expected node ID 32, got %d", b)
	}
	if c < 1 {
		t.Fatalf("expected a counter of at least 1, got %d", c)
	}
}

2
keyedmutex/README.md Normal file
View File

@ -0,0 +1,2 @@
# keyedmutex

3
keyedmutex/go.mod Normal file
View File

@ -0,0 +1,3 @@
module git.crumpington.com/lib/keyedmutex
go 1.23.2

67
keyedmutex/keyedmutex.go Normal file
View File

@ -0,0 +1,67 @@
package keyedmutex
import (
"container/list"
"sync"
)
// KeyedMutex is a set of mutual-exclusion locks addressed by key. Locks for
// different keys are independent of each other.
//
// Copies of a KeyedMutex share the same state (both fields are references),
// which is why the methods use value receivers. Create instances with New.
type KeyedMutex[K comparable] struct {
	// mu guards waitList.
	mu *sync.Mutex
	// waitList has an entry for every key that is currently locked. The list
	// holds one channel per goroutine waiting for that key, in arrival order.
	waitList map[K]*list.List
}

// New returns an empty KeyedMutex.
func New[K comparable]() KeyedMutex[K] {
	return KeyedMutex[K]{
		mu:       new(sync.Mutex),
		waitList: map[K]*list.List{},
	}
}

// Lock acquires the lock for key, blocking until it is available. Waiters are
// served in the order in which they queued.
func (m KeyedMutex[K]) Lock(key K) {
	if ch := m.lock(key); ch != nil {
		<-ch
	}
}

// lock either takes the lock for key immediately (returning nil) or enqueues
// a new wait channel and returns it; the caller owns the lock once a value
// arrives on that channel.
func (m KeyedMutex[K]) lock(key K) chan struct{} {
	m.mu.Lock()
	defer m.mu.Unlock()

	if waiters, ok := m.waitList[key]; ok {
		// Buffered so that Unlock can hand over ownership without blocking
		// while it holds mu.
		ch := make(chan struct{}, 1)
		waiters.PushBack(ch)
		return ch
	}

	m.waitList[key] = list.New()
	return nil
}

// TryLock acquires the lock for key if it is free and reports whether it did
// so. It never blocks waiting for the key.
func (m KeyedMutex[K]) TryLock(key K) bool {
	m.mu.Lock()
	defer m.mu.Unlock()

	if _, ok := m.waitList[key]; ok {
		return false
	}

	m.waitList[key] = list.New()
	return true
}

// Unlock releases the lock for key. If goroutines are waiting, ownership
// passes directly to the one that has waited longest; otherwise the key
// becomes free. Unlock panics if key is not locked.
func (m KeyedMutex[K]) Unlock(key K) {
	m.mu.Lock()
	defer m.mu.Unlock()

	waiters, ok := m.waitList[key]
	if !ok {
		panic("unlock of unlocked mutex")
	}

	if waiters.Len() == 0 {
		delete(m.waitList, key)
		return
	}

	// The key stays in waitList: the lock is transferred, not released.
	ch := waiters.Remove(waiters.Front()).(chan struct{})
	ch <- struct{}{}
}

View File

@ -0,0 +1,123 @@
package keyedmutex
import (
"sync"
"testing"
"time"
)
// TestKeyedMutex walks a KeyedMutex through a sequence of Lock, TryLock and
// Unlock calls, checking after each step that exactly the expected keys are
// recorded as locked.
func TestKeyedMutex(t *testing.T) {
// checkState fails the test unless the set of locked keys equals keys.
checkState := func(t *testing.T, m KeyedMutex[string], keys ...string) {
if len(m.waitList) != len(keys) {
t.Fatal(m.waitList, keys)
}
for _, key := range keys {
if _, ok := m.waitList[key]; !ok {
t.Fatal(key)
}
}
}
m := New[string]()
checkState(t, m)
m.Lock("a")
checkState(t, m, "a")
m.Lock("b")
checkState(t, m, "a", "b")
m.Lock("c")
checkState(t, m, "a", "b", "c")
// Keys that are held cannot be try-locked.
if m.TryLock("a") {
t.Fatal("a")
}
if m.TryLock("b") {
t.Fatal("b")
}
if m.TryLock("c") {
t.Fatal("c")
}
// Free keys can.
if !m.TryLock("d") {
t.Fatal("d")
}
checkState(t, m, "a", "b", "c", "d")
if !m.TryLock("e") {
t.Fatal("e")
}
checkState(t, m, "a", "b", "c", "d", "e")
m.Unlock("c")
checkState(t, m, "a", "b", "d", "e")
m.Unlock("a")
checkState(t, m, "b", "d", "e")
m.Unlock("e")
checkState(t, m, "b", "d")
// Eight goroutines contend for "b", which is still held by this goroutine.
wg := sync.WaitGroup{}
for i := 0; i < 8; i++ {
wg.Add(1)
go func() {
defer wg.Done()
m.Lock("b")
m.Unlock("b")
}()
}
// Give the goroutines time to queue up before releasing "b".
time.Sleep(100 * time.Millisecond)
m.Unlock("b")
wg.Wait()
checkState(t, m, "d")
m.Unlock("d")
checkState(t, m)
}
// TestKeyedMutex_unlockUnlocked checks that unlocking a key that is not locked
// panics.
func TestKeyedMutex_unlockUnlocked(t *testing.T) {
	defer func() {
		if r := recover(); r == nil {
			t.Fatal("expected Unlock of an unlocked key to panic")
		}
	}()

	m := New[string]()
	m.Unlock("aldkfj")
}
// BenchmarkUncontendedMutex measures a Lock/Unlock pair on a single key from a
// single goroutine.
func BenchmarkUncontendedMutex(b *testing.B) {
	const key = "xyz"
	m := New[string]()
	for remaining := b.N; remaining > 0; remaining-- {
		m.Lock(key)
		m.Unlock(key)
	}
}
// BenchmarkContendedMutex starts b.N goroutines that all want the same key
// while it is held, waits a second for them to start, and then times how long
// it takes for every one of them to acquire and release the key.
func BenchmarkContendedMutex(b *testing.B) {
	const key = "xyz"
	m := New[string]()
	m.Lock(key)

	var wg sync.WaitGroup
	worker := func() {
		defer wg.Done()
		m.Lock(key)
		m.Unlock(key)
	}
	for started := 0; started < b.N; started++ {
		wg.Add(1)
		go worker()
	}

	time.Sleep(time.Second)
	b.ResetTimer()
	m.Unlock(key)
	wg.Wait()
}

2
kvmemcache/README.md Normal file
View File

@ -0,0 +1,2 @@
# kvmemcache

134
kvmemcache/cache.go Normal file
View File

@ -0,0 +1,134 @@
package kvmemcache
import (
"container/list"
"sync"
"time"
"git.crumpington.com/lib/keyedmutex"
)
// Cache is a read-through cache with least-recently-used eviction and an
// optional time-to-live. Values are loaded on demand through src.
type Cache[K comparable, V any] struct {
// updateLock serialises loads of the same key.
updateLock keyedmutex.KeyedMutex[K]
// src loads the value for a key that is not cached.
src func(K) (V, error)
// ttl is the maximum age of an entry; zero disables expiry.
ttl time.Duration
// maxSize is the maximum number of entries; zero disables the limit.
maxSize int
// Lock protects variables below.
lock sync.Mutex
// cache maps a key to its element in ll.
cache map[K]*list.Element
// ll orders entries from most recently used (front) to least (back).
ll *list.List
stats Stats
}
// lruItem is the value stored in each list element of a Cache. It records the
// result of one call to the source function, including its error.
type lruItem[K comparable, V any] struct {
key K
// createdAt is when the entry was stored; it is compared against the TTL.
createdAt time.Time
value V
err error
}
// Config holds the parameters for New.
type Config[K comparable, V any] struct {
// MaxSize is the maximum number of cached entries. Zero means no limit.
MaxSize int
TTL time.Duration // Zero to ignore.
// Src loads the value for a key that is missing from the cache.
Src func(K) (V, error)
}
// New builds an empty Cache from conf. The key map is pre-sized for
// conf.MaxSize+1 entries.
func New[K comparable, V any](conf Config[K, V]) *Cache[K, V] {
	c := new(Cache[K, V])
	c.updateLock = keyedmutex.New[K]()
	c.src = conf.Src
	c.ttl = conf.TTL
	c.maxSize = conf.MaxSize
	c.cache = make(map[K]*list.Element, conf.MaxSize+1)
	c.ll = list.New()
	return c
}
// Get returns the cached result for key, loading it through the source
// function when it is absent or expired. Errors returned by the source are
// cached and returned just like values.
func (c *Cache[K, V]) Get(key K) (V, error) {
	if found, val, err := c.get(key); found {
		return val, err
	}
	return c.load(key)
}
// Evict removes key from the cache if it is present.
func (c *Cache[K, V]) Evict(key K) {
c.lock.Lock()
defer c.lock.Unlock()
c.evict(key)
}
// Stats returns a copy of the current hit and miss counters.
func (c *Cache[K, V]) Stats() Stats {
	c.lock.Lock()
	snapshot := c.stats
	c.lock.Unlock()
	return snapshot
}
// put records a freshly loaded result (value and error) for key as the most
// recently used entry and counts a miss. When a size limit is configured and
// exceeded, the least recently used entry is dropped.
func (c *Cache[K, V]) put(key K, value V, err error) {
	c.lock.Lock()
	defer c.lock.Unlock()

	c.stats.Misses++

	entry := lruItem[K, V]{key: key, createdAt: time.Now(), value: value, err: err}
	c.cache[key] = c.ll.PushFront(entry)

	if c.maxSize == 0 || len(c.cache) <= c.maxSize {
		return
	}

	// Over capacity: the back of the list is the least recently used entry.
	oldest := c.ll.Back()
	c.ll.Remove(oldest)
	delete(c.cache, oldest.Value.(lruItem[K, V]).key)
}
// evict removes key from both the map and the LRU list. It does nothing when
// the key is not cached. The caller must hold c.lock.
func (c *Cache[K, V]) evict(key K) {
	if elem, ok := c.cache[key]; ok && elem != nil {
		delete(c.cache, key)
		c.ll.Remove(elem)
	}
}
// get looks key up in the cache. On a hit it counts the hit, marks the entry
// as most recently used and returns ok=true with the cached value and error.
// A missing entry, or one older than the TTL (which is evicted on the spot),
// yields ok=false.
func (c *Cache[K, V]) get(key K) (ok bool, val V, err error) {
	c.lock.Lock()
	defer c.lock.Unlock()

	elem := c.cache[key]
	if elem == nil {
		return false, val, nil
	}

	entry := elem.Value.(lruItem[K, V])

	// Maybe evict.
	if expired := c.ttl != 0 && time.Since(entry.createdAt) > c.ttl; expired {
		c.evict(key)
		return false, val, nil
	}

	c.stats.Hits++
	c.ll.MoveToFront(elem)
	return true, entry.value, entry.err
}
// load fetches key from the source while holding the per-key update lock, so
// that concurrent callers for the same key trigger a single source call. The
// cache is re-checked after the lock is obtained, because another goroutine
// may have stored the result while this one was waiting.
func (c *Cache[K, V]) load(key K) (V, error) {
	c.updateLock.Lock(key)
	defer c.updateLock.Unlock(key)

	// Lost the update race: somebody else already filled the entry.
	if found, cached, cachedErr := c.get(key); found {
		return cached, cachedErr
	}

	// Won the update race.
	fresh, err := c.src(key)
	c.put(key, fresh, err)
	return fresh, err
}

249
kvmemcache/cache_test.go Normal file
View File

@ -0,0 +1,249 @@
package kvmemcache
import (
"errors"
"fmt"
"sync"
"testing"
"time"
)
// State describes the expected contents of a Cache in tests: the exact set of
// cached keys and the expected counters.
type State[K comparable] struct {
Keys []K
Stats Stats
}
// assert compares the cache against the expected state and returns an error
// describing the first mismatch: number of keys, a missing key, hit count or
// miss count, in that order. It returns nil when everything matches.
func (c *Cache[K, V]) assert(state State[K]) error {
	c.lock.Lock()
	defer c.lock.Unlock()

	if got, want := len(c.cache), len(state.Keys); got != want {
		return fmt.Errorf("Expected %d keys but found %d.", want, got)
	}

	for _, k := range state.Keys {
		if _, ok := c.cache[k]; ok {
			continue
		}
		return fmt.Errorf("Expected key %v not found.", k)
	}

	switch {
	case c.stats.Hits != state.Stats.Hits:
		return fmt.Errorf("Expected %d hits, but found %d.", state.Stats.Hits, c.stats.Hits)
	case c.stats.Misses != state.Stats.Misses:
		return fmt.Errorf("Expected %d misses, but found %d.", state.Stats.Misses, c.stats.Misses)
	}
	return nil
}
// ErrTest is the error the test source function returns for the key "err".
var ErrTest = errors.New("Hello")
// TestCache_basic drives a cache with MaxSize 4 and a 50ms TTL through a fixed
// sequence of Get and Evict calls. After every step the set of cached keys and
// the hit/miss counters are compared against the expected state. The cases
// build on each other, so their order matters.
func TestCache_basic(t *testing.T) {
c := New(Config[string, string]{
MaxSize: 4,
TTL: 50 * time.Millisecond,
// The source echoes the key, except for "err", which fails with ErrTest.
Src: func(key string) (string, error) {
if key == "err" {
return "", ErrTest
}
return key, nil
},
})
type testCase struct {
name string
sleep time.Duration // Slept before the step runs.
key string
evict bool // Evict the key instead of getting it.
state State[string] // Expected state after the step.
}
cases := []testCase{
{
name: "get a",
key: "a",
state: State[string]{
Keys: []string{"a"},
Stats: Stats{Hits: 0, Misses: 1},
},
}, {
name: "get a again",
key: "a",
state: State[string]{
Keys: []string{"a"},
Stats: Stats{Hits: 1, Misses: 1},
},
}, {
// The sleep exceeds the TTL, so "a" is reloaded.
name: "sleep, then get a again",
sleep: 55 * time.Millisecond,
key: "a",
state: State[string]{
Keys: []string{"a"},
Stats: Stats{Hits: 1, Misses: 2},
},
}, {
name: "get b",
key: "b",
state: State[string]{
Keys: []string{"a", "b"},
Stats: Stats{Hits: 1, Misses: 3},
},
}, {
name: "get c",
key: "c",
state: State[string]{
Keys: []string{"a", "b", "c"},
Stats: Stats{Hits: 1, Misses: 4},
},
}, {
name: "get d",
key: "d",
state: State[string]{
Keys: []string{"a", "b", "c", "d"},
Stats: Stats{Hits: 1, Misses: 5},
},
}, {
// The fifth key exceeds MaxSize; "a" is the least recently used.
name: "get e",
key: "e",
state: State[string]{
Keys: []string{"b", "c", "d", "e"},
Stats: Stats{Hits: 1, Misses: 6},
},
}, {
name: "get c again",
key: "c",
state: State[string]{
Keys: []string{"b", "c", "d", "e"},
Stats: Stats{Hits: 2, Misses: 6},
},
}, {
name: "get err",
key: "err",
state: State[string]{
Keys: []string{"c", "d", "e", "err"},
Stats: Stats{Hits: 2, Misses: 7},
},
}, {
// Errors are cached too, so this is a hit.
name: "get err again",
key: "err",
state: State[string]{
Keys: []string{"c", "d", "e", "err"},
Stats: Stats{Hits: 3, Misses: 7},
},
}, {
name: "evict c",
key: "c",
evict: true,
state: State[string]{
Keys: []string{"d", "e", "err"},
Stats: Stats{Hits: 3, Misses: 7},
},
}, {
name: "reload-all a",
key: "a",
state: State[string]{
Keys: []string{"a", "d", "e", "err"},
Stats: Stats{Hits: 3, Misses: 8},
},
}, {
name: "reload-all b",
key: "b",
state: State[string]{
Keys: []string{"a", "b", "e", "err"},
Stats: Stats{Hits: 3, Misses: 9},
},
}, {
name: "reload-all c",
key: "c",
state: State[string]{
Keys: []string{"a", "b", "c", "err"},
Stats: Stats{Hits: 3, Misses: 10},
},
}, {
name: "reload-all d",
key: "d",
state: State[string]{
Keys: []string{"a", "b", "c", "d"},
Stats: Stats{Hits: 3, Misses: 11},
},
}, {
name: "read a again",
key: "a",
state: State[string]{
Keys: []string{"b", "c", "d", "a"},
Stats: Stats{Hits: 4, Misses: 11},
},
}, {
name: "read e, evicting b",
key: "e",
state: State[string]{
Keys: []string{"c", "d", "a", "e"},
Stats: Stats{Hits: 4, Misses: 12},
},
},
}
for _, tc := range cases {
time.Sleep(tc.sleep)
if !tc.evict {
val, err := c.Get(tc.key)
if tc.key == "err" && err != ErrTest {
t.Fatal(tc.name, val)
}
if tc.key != "err" && val != tc.key {
t.Fatal(tc.name, tc.key, val)
}
} else {
c.Evict(tc.key)
}
if err := c.assert(tc.state); err != nil {
t.Fatal(err)
}
}
}
// TestCache_thunderingHerd issues 16384 concurrent Gets for the same key
// against a source that takes one second. The source must be consulted only
// once: exactly one miss, and a hit for every other caller.
func TestCache_thunderingHerd(t *testing.T) {
	c := New(Config[string, string]{
		MaxSize: 4,
		Src: func(key string) (string, error) {
			time.Sleep(time.Second)
			return key, nil
		},
	})

	wg := sync.WaitGroup{}
	for i := 0; i < 16384; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			val, err := c.Get("a")
			if err != nil {
				panic(err)
			}
			if val != "a" {
				panic(fmt.Sprintf("unexpected value %q for key \"a\"", val))
			}
		}()
	}

	wg.Wait()

	stats := c.Stats()
	if stats.Hits != 16383 || stats.Misses != 1 {
		t.Fatal(stats)
	}
}

5
kvmemcache/go.mod Normal file
View File

@ -0,0 +1,5 @@
module git.crumpington.com/lib/kvmemcache
go 1.23.2
require git.crumpington.com/lib/keyedmutex v1.0.1

2
kvmemcache/go.sum Normal file
View File

@ -0,0 +1,2 @@
git.crumpington.com/lib/keyedmutex v1.0.1 h1:5ylwGXQzL9ojZIhlqkut6dpa4yt6Wz6bOWbf/tQBAMQ=
git.crumpington.com/lib/keyedmutex v1.0.1/go.mod h1:VxxJRU/XvvF61IuJZG7kUIv954Q8+Rh8bnVpEzGYrQ4=

6
kvmemcache/stats.go Normal file
View File

@ -0,0 +1,6 @@
package kvmemcache
// Stats holds a cache's counters.
type Stats struct {
// Hits counts lookups answered from the cache.
Hits uint64
// Misses counts results stored after calling the source function.
Misses uint64
}

2
mmap/README.md Normal file
View File

@ -0,0 +1,2 @@
# mmap

100
mmap/file.go Normal file
View File

@ -0,0 +1,100 @@
package mmap
import "os"
// File couples an open file with a memory map of its contents.
type File struct {
f *os.File
// Map is the mapped file content. It is set to nil by Close.
Map []byte
}
// Create creates (or truncates) the file at path, sets its length to size and
// maps it for reading and writing. The file is closed again if any step after
// its creation fails.
func Create(path string, size int64) (*File, error) {
	f, err := os.Create(path)
	if err != nil {
		return nil, err
	}

	// fail releases the descriptor before reporting err.
	fail := func(err error) (*File, error) {
		f.Close()
		return nil, err
	}

	if err := f.Truncate(size); err != nil {
		return fail(err)
	}
	m, err := Map(f, PROT_READ|PROT_WRITE)
	if err != nil {
		return fail(err)
	}
	return &File{f: f, Map: m}, nil
}
// Open opens the file at path and maps its contents read-only. The file is
// closed again if mapping fails.
func Open(path string) (*File, error) {
	f, err := os.Open(path)
	if err != nil {
		return nil, err
	}

	m, err := Map(f, PROT_READ)
	if err == nil {
		return &File{f: f, Map: m}, nil
	}

	f.Close()
	return nil, err
}
// OpenFile opens the file at path with the given os.OpenFile flags and
// permissions and maps it into memory.
//
// The file counts as writable when fileFlags contains os.O_RDWR or
// os.O_WRONLY. A writable file is mapped with PROT_READ|PROT_WRITE and, when
// size is positive and differs from the current length, is first truncated or
// extended to size. A read-only file is mapped with PROT_READ and size is
// ignored. Pass -1 as size to keep the file's current length.
//
// The file is closed again if any step after opening fails.
func OpenFile(
	path string,
	fileFlags int,
	perm os.FileMode,
	size int64, // -1 for file size.
) (*File, error) {
	f, err := os.OpenFile(path, fileFlags, perm)
	if err != nil {
		return nil, err
	}

	// os.O_RDONLY is zero, so write access is exactly the presence of one of
	// these two bits.
	writable := fileFlags&(os.O_RDWR|os.O_WRONLY) != 0

	fi, err := f.Stat()
	if err != nil {
		f.Close()
		return nil, err
	}

	if writable && size > 0 && size != fi.Size() {
		if err := f.Truncate(size); err != nil {
			f.Close()
			return nil, err
		}
	}

	mapFlags := PROT_READ
	if writable {
		mapFlags |= PROT_WRITE
	}

	m, err := Map(f, mapFlags)
	if err != nil {
		f.Close()
		return nil, err
	}
	return &File{f, m}, nil
}
// Sync calls the package-level Sync on the file's map.
func (f *File) Sync() error {
return Sync(f.Map)
}
// Close unmaps the memory map and then closes the underlying file. Each part
// is cleared once it has been released, so Close can be called again after a
// failure to retry only what is left; calling it on a fully closed File does
// nothing.
func (f *File) Close() error {
	if f.Map != nil {
		err := Unmap(f.Map)
		if err != nil {
			return err
		}
		f.Map = nil
	}

	if f.f == nil {
		return nil
	}
	if err := f.f.Close(); err != nil {
		return err
	}
	f.f = nil
	return nil
}

3
mmap/go.mod Normal file
View File

@ -0,0 +1,3 @@
module git.crumpington.com/lib/mmap
go 1.23.2

0
mmap/go.sum Normal file
View File

59
mmap/mmap.go Normal file
View File

@ -0,0 +1,59 @@
package mmap
import (
"os"
"syscall"
"unsafe"
)
// Protection flags accepted by Map, re-exported from the syscall package.
const (
PROT_READ = syscall.PROT_READ
PROT_WRITE = syscall.PROT_WRITE
)
// Mmap creates a memory map of the given file. The flags argument should be a
// combination of PROT_READ and PROT_WRITE. The size of the map will be the
// file's size.
func Map(f *os.File, flags int) ([]byte, error) {
fi, err := f.Stat()
if err != nil {
return nil, err
}
size := fi.Size()
addr, _, errno := syscall.Syscall6(
syscall.SYS_MMAP,
0, // addr: 0 => allow kernel to choose
uintptr(size),
uintptr(flags),
uintptr(syscall.MAP_SHARED),
f.Fd(),
0) // offset: 0 => start of file
if errno != 0 {
return nil, syscall.Errno(errno)
}
return unsafe.Slice((*byte)(unsafe.Pointer(addr)), size), nil
}
// Unmap unmaps the data obtained by Map. The full capacity of data is
// released, starting at its first byte. A slice without any capacity (for
// example nil) has nothing to unmap and is accepted without error.
func Unmap(data []byte) error {
	if cap(data) == 0 {
		return nil
	}

	_, _, errno := syscall.Syscall(
		syscall.SYS_MUNMAP,
		uintptr(unsafe.Pointer(&data[:1][0])), // re-slice: len may be 0 while cap is not
		uintptr(cap(data)),
		0)
	if errno != 0 {
		return syscall.Errno(errno)
	}
	return nil
}
// Sync synchronously flushes the mapped bytes in b to the underlying file
// using msync with MS_SYNC. An empty slice has nothing to flush and is
// accepted without error.
func Sync(b []byte) (err error) {
	if len(b) == 0 {
		return nil
	}

	_p0 := unsafe.Pointer(&b[0])
	_, _, errno := syscall.Syscall(syscall.SYS_MSYNC, uintptr(_p0), uintptr(len(b)), uintptr(syscall.MS_SYNC))
	if errno != 0 {
		err = syscall.Errno(errno)
	}
	return
}

45
pgutil/README.md Normal file
View File

@ -0,0 +1,45 @@
# pgutil
## Transactions
Simplify postgres transactions using `WithTx` for serializable transactions,
or `WithTxDefault` for the default isolation level. Use the `SerialTxRunner`
type to get automatic retries of serialization errors.
## Migrations
Put your migrations into a directory, for example `migrations`, ordered by name
(YYYY-MM-DD prefix, for example). Embed the directory and pass it to the
`Migrate` function:
```Go
//go:embed migrations
var migrations embed.FS
func init() {
Migrate(db, migrations) // Check the error, of course.
}
```
## Testing
In order to test this package, we need to create a test user and database:
```
sudo su postgres
psql
CREATE DATABASE test;
CREATE USER test WITH ENCRYPTED PASSWORD 'test';
GRANT ALL PRIVILEGES ON DATABASE test TO test;
\c test
GRANT ALL ON SCHEMA public TO test;
```
Check that you can connect via the command line:
```
psql -h 127.0.0.1 -U test --password test
```

42
pgutil/dropall.go Normal file
View File

@ -0,0 +1,42 @@
package pgutil
import (
"database/sql"
"log"
)
// dropTablesQueryQuery is a query whose result rows are themselves SQL
// statements: one DROP TABLE ... CASCADE per table in the public schema.
const dropTablesQueryQuery = `
SELECT 'DROP TABLE IF EXISTS "' || tablename || '" CASCADE;'
FROM
pg_tables
WHERE
schemaname='public'`
// DropAllTables drops every table in the database's public schema. Useful for
// testing.
//
// It first collects one DROP statement per table and then executes them one by
// one, stopping at the first error. A message is logged when there is at least
// one table to drop.
func DropAllTables(db *sql.DB) error {
	rows, err := db.Query(dropTablesQueryQuery)
	if err != nil {
		return err
	}
	defer rows.Close()

	var queries []string
	for rows.Next() {
		var s string
		if err := rows.Scan(&s); err != nil {
			return err
		}
		queries = append(queries, s)
	}
	// Next also returns false when iteration failed part-way through.
	if err := rows.Err(); err != nil {
		return err
	}
	// Release the result set before issuing further statements.
	if err := rows.Close(); err != nil {
		return err
	}

	if len(queries) > 0 {
		log.Printf("DROPPING ALL (%d) TABLES", len(queries))
	}

	for _, query := range queries {
		if _, err := db.Exec(query); err != nil {
			return err
		}
	}

	return nil
}

31
pgutil/errors.go Normal file
View File

@ -0,0 +1,31 @@
package pgutil
import (
"errors"
"github.com/lib/pq"
)
// ErrIsDuplicateKey reports whether err carries the Postgres error code 23505
// (unique_violation).
func ErrIsDuplicateKey(err error) bool {
return ErrHasCode(err, "23505")
}
// ErrIsForeignKey reports whether err carries the Postgres error code 23503
// (foreign_key_violation).
func ErrIsForeignKey(err error) bool {
return ErrHasCode(err, "23503")
}
// ErrIsSerializationFaiilure reports whether err carries the Postgres error
// code 40001 (serialization_failure). The misspelling is part of the exported
// name.
func ErrIsSerializationFaiilure(err error) bool {
return ErrHasCode(err, "40001")
}
// ErrHasCode reports whether err is, or wraps, a *pq.Error whose code equals
// code. A nil error never matches.
func ErrHasCode(err error, code string) bool {
	var pErr *pq.Error
	return err != nil && errors.As(err, &pErr) && pErr.Code == pq.ErrorCode(code)
}

36
pgutil/errors_test.go Normal file
View File

@ -0,0 +1,36 @@
package pgutil
import (
"database/sql"
"testing"
)
// TestErrors checks the error classification helpers against a real database:
// after applying the test migrations, inserts that violate the primary key,
// the unique email constraint and the foreign key must be recognised as such.
func TestErrors(t *testing.T) {
	db, err := sql.Open(
		"postgres",
		"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	if err := DropAllTables(db); err != nil {
		t.Fatal(err)
	}

	if err := Migrate(db, testMigrationFS); err != nil {
		t.Fatal(err)
	}

	// UserID 2 already exists.
	_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (2, 'q@r.com')`)
	if !ErrIsDuplicateKey(err) {
		t.Fatal(err)
	}

	// The email c@d.com already exists.
	_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (3, 'c@d.com')`)
	if !ErrIsDuplicateKey(err) {
		t.Fatal(err)
	}

	// There is no user 4 to attach the note to.
	_, err = db.Exec(`INSERT INTO user_notes(UserID, NoteID, Note) VALUES (4, 1, 'hello')`)
	if !ErrIsForeignKey(err) {
		t.Fatal(err)
	}
}

5
pgutil/go.mod Normal file
View File

@ -0,0 +1,5 @@
module git.crumpington.com/git/pgutil
go 1.23.2
require github.com/lib/pq v1.10.9

2
pgutil/go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=

82
pgutil/migrate.go Normal file
View File

@ -0,0 +1,82 @@
package pgutil
import (
"database/sql"
"embed"
"errors"
"fmt"
"path/filepath"
"sort"
)
// initMigrationTableQuery creates the bookkeeping table that records the file
// name of every applied migration.
const initMigrationTableQuery = `
CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);`
// insertMigrationQuery records a migration file as applied.
const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)`
// checkMigrationAppliedQuery reports whether a migration file is recorded as
// applied.
const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)`
// Migrate applies all pending migrations from migrationFS inside a single
// serializable transaction.
//
// migrationFS must contain exactly one directory at its root, and that
// directory must contain only regular files. The files are applied in order of
// their names. A file whose name is already recorded in the migrations table
// is skipped; every newly applied file is recorded there. Any failure rolls
// back the whole run.
func Migrate(db *sql.DB, migrationFS embed.FS) error {
	return WithTx(db, func(tx *sql.Tx) error {
		if _, err := tx.Exec(initMigrationTableQuery); err != nil {
			return err
		}

		dirs, err := migrationFS.ReadDir(".")
		if err != nil {
			return err
		}

		if len(dirs) != 1 {
			return errors.New("expected a single migrations directory")
		}

		if !dirs[0].IsDir() {
			return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name())
		}

		dirName := dirs[0].Name()
		files, err := migrationFS.ReadDir(dirName)
		if err != nil {
			return err
		}

		// Sort sql files by name.
		sort.Slice(files, func(i, j int) bool {
			return files[i].Name() < files[j].Name()
		})

		for _, dirEnt := range files {
			if !dirEnt.Type().IsRegular() {
				return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name())
			}

			var (
				name   = dirEnt.Name()
				exists bool
			)

			err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists)
			if err != nil {
				return err
			}

			if exists {
				continue
			}

			// Paths inside an embed.FS always use forward slashes, whatever
			// the host operating system's separator is.
			migration, err := migrationFS.ReadFile(filepath.ToSlash(filepath.Join(dirName, name)))
			if err != nil {
				return err
			}

			if _, err := tx.Exec(string(migration)); err != nil {
				// Wrap with %w so callers can still inspect the database error.
				return fmt.Errorf("migration %s failed: %w", name, err)
			}

			if _, err := tx.Exec(insertMigrationQuery, name); err != nil {
				return err
			}
		}

		return nil
	})
}

50
pgutil/migrate_test.go Normal file
View File

@ -0,0 +1,50 @@
package pgutil
import (
"database/sql"
"embed"
"testing"
)
// testMigrationFS holds the embedded test-migrations directory used by the
// tests in this package.
//
//go:embed test-migrations
var testMigrationFS embed.FS
// TestMigrate applies the test migrations to an empty database twice and then
// checks their net effect: user 1 was inserted and deleted again, user 2 is
// still present.
func TestMigrate(t *testing.T) {
	db, err := sql.Open(
		"postgres",
		"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
	if err != nil {
		t.Fatal(err)
	}
	defer db.Close()

	if err := DropAllTables(db); err != nil {
		t.Fatal(err)
	}

	if err := Migrate(db, testMigrationFS); err != nil {
		t.Fatal(err)
	}

	// Shouldn't have any effect.
	if err := Migrate(db, testMigrationFS); err != nil {
		t.Fatal(err)
	}

	query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)`
	var exists bool

	if err = db.QueryRow(query, 1).Scan(&exists); err != nil {
		t.Fatal(err)
	}
	if exists {
		t.Fatal("1 shouldn't exist")
	}

	if err = db.QueryRow(query, 2).Scan(&exists); err != nil {
		t.Fatal(err)
	}
	if !exists {
		t.Fatal("2 should exist")
	}
}

View File

@ -0,0 +1,9 @@
-- Users, identified by UserID; each email address may be used only once.
CREATE TABLE users(
UserID BIGINT NOT NULL PRIMARY KEY,
Email TEXT NOT NULL UNIQUE);
-- Notes belonging to a user; NoteID is unique per user.
CREATE TABLE user_notes(
UserID BIGINT NOT NULL REFERENCES users(UserID),
NoteID BIGINT NOT NULL,
Note Text NOT NULL,
PRIMARY KEY(UserID,NoteID));

View File

@ -0,0 +1 @@
-- Seed two users.
INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com');

View File

@ -0,0 +1 @@
-- Remove user 1 again.
DELETE FROM users WHERE UserID=1;

70
pgutil/tx.go Normal file
View File

@ -0,0 +1,70 @@
package pgutil
import (
"database/sql"
"math/rand"
"time"
)
// WithTx runs fn inside a serializable transaction. Postgres doesn't use
// serializable transactions by default, so the isolation level is set
// explicitly before fn is called. Commit and rollback are handled as in
// WithTxDefault.
//
// Note: this may return a retriable serialization error (see
// ErrIsSerializationFaiilure).
func WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
	serializable := func(tx *sql.Tx) error {
		_, err := tx.Exec("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE")
		if err != nil {
			return err
		}
		return fn(tx)
	}
	return WithTxDefault(db, serializable)
}
// WithTxDefault runs fn inside a transaction with the default isolation level.
// The transaction is committed when fn returns nil. If fn or the commit fails,
// a rollback is attempted and the failure is returned.
func WithTxDefault(db *sql.DB, fn func(*sql.Tx) error) error {
	// Start a transaction.
	tx, err := db.Begin()
	if err != nil {
		return err
	}

	if err = fn(tx); err != nil {
		_ = tx.Rollback()
		return err
	}

	if err = tx.Commit(); err != nil {
		_ = tx.Rollback()
		return err
	}

	return nil
}
// SerialTxRunner attempts serializable transactions in a loop. If a
// transaction fails due to a serialization error, then the runner will retry
// with exponential backoff, until the sleep time reaches MaxTimeout.
//
// For example, if MinTimeout is 100 ms, and MaxTimeout is 800 ms, it may sleep
// for ~100, 200, 400, and 800 ms between retries.
//
// 10% jitter is added to the sleep time.
type SerialTxRunner struct {
// MinTimeout is the sleep time before the first retry.
MinTimeout time.Duration
// MaxTimeout is the largest base sleep time; once doubling exceeds it, the
// next failure is returned instead of retried.
MaxTimeout time.Duration
}
// WithTx runs fn in a serializable transaction (see the package-level WithTx)
// and retries it while it fails with a serialization error.
//
// Between attempts it sleeps for the current timeout plus up to 10% random
// jitter, then doubles the timeout. Any other error, or a serialization error
// once the timeout has grown beyond MaxTimeout, is returned as is.
//
// A MinTimeout of zero or less is treated as one millisecond, so that the
// backoff can actually grow and the loop terminates.
func (r SerialTxRunner) WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
	timeout := r.MinTimeout
	if timeout <= 0 {
		timeout = time.Millisecond
	}

	for {
		err := WithTx(db, fn)
		if err == nil {
			return nil
		}
		if timeout > r.MaxTimeout || !ErrIsSerializationFaiilure(err) {
			return err
		}

		sleepTimeout := timeout
		// rand.Int63n panics for a non-positive bound, which happens for
		// timeouts below 10ns.
		if jitter := int64(timeout / 10); jitter > 0 {
			sleepTimeout += time.Duration(rand.Int63n(jitter))
		}
		time.Sleep(sleepTimeout)
		timeout *= 2
	}
}

120
pgutil/tx_test.go Normal file
View File

@ -0,0 +1,120 @@
package pgutil
import (
"database/sql"
"fmt"
"sync"
"testing"
"time"
)
// TestWithTx verifies transaction retry using the classic
// example of write skew in bank account balance transfers.
//
// Two transactions read both balances, wait for each other, and then each move
// 100 between the accounts based on what they read. Under serializable
// isolation one of them has to be retried, after which the transfers cancel
// out and both balances are back at 100.
func TestWithTx(t *testing.T) {
db, err := sql.Open(
"postgres",
"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
if err != nil {
t.Fatal(err)
}
if err := DropAllTables(db); err != nil {
t.Fatal(err)
}
defer db.Close()
initStmt := `
CREATE TABLE t (acct INT PRIMARY KEY, balance INT);
INSERT INTO t (acct, balance) VALUES (1, 100), (2, 100);
`
if _, err := db.Exec(initStmt); err != nil {
t.Fatal(err)
}
// queryI is satisfied by both *sql.DB and *sql.Tx.
type queryI interface {
Query(string, ...interface{}) (*sql.Rows, error)
}
// getBalances reads the two balances in the order the rows are returned.
getBalances := func(q queryI) (bal1, bal2 int, err error) {
var rows *sql.Rows
rows, err = q.Query(`SELECT balance FROM t WHERE acct IN (1, 2);`)
if err != nil {
return
}
defer rows.Close()
balances := []*int{&bal1, &bal2}
i := 0
for ; rows.Next(); i++ {
if err = rows.Scan(balances[i]); err != nil {
return
}
}
if i != 2 {
err = fmt.Errorf("expected two balances; got %d", i)
return
}
return
}
txRunner := SerialTxRunner{100 * time.Millisecond, 800 * time.Millisecond}
// runTxn runs one transfer in the background, counting attempts in *iter and
// delivering the final result on the returned channel.
runTxn := func(wg *sync.WaitGroup, iter *int) <-chan error {
errCh := make(chan error, 1)
go func() {
*iter = 0
errCh <- txRunner.WithTx(db, func(tx *sql.Tx) error {
*iter++
bal1, bal2, err := getBalances(tx)
if err != nil {
return err
}
// If this is the first iteration, wait for the other tx to
// also read.
if *iter == 1 {
wg.Done()
wg.Wait()
}
// Now, subtract from one account and give to the other.
if bal1 > bal2 {
if _, err := tx.Exec(`
UPDATE t SET balance=balance-100 WHERE acct=1;
UPDATE t SET balance=balance+100 WHERE acct=2;
`); err != nil {
return err
}
} else {
if _, err := tx.Exec(`
UPDATE t SET balance=balance+100 WHERE acct=1;
UPDATE t SET balance=balance-100 WHERE acct=2;
`); err != nil {
return err
}
}
return nil
})
}()
return errCh
}
var wg sync.WaitGroup
wg.Add(2)
var iters1, iters2 int
txn1Err := runTxn(&wg, &iters1)
txn2Err := runTxn(&wg, &iters2)
if err := <-txn1Err; err != nil {
t.Errorf("expected success in txn1; got %s", err)
}
if err := <-txn2Err; err != nil {
t.Errorf("expected success in txn2; got %s", err)
}
// Two attempts in total would mean that neither transaction was retried.
if iters1+iters2 <= 2 {
t.Errorf("expected retries between the competing transactions; "+
"got txn1=%d, txn2=%d", iters1, iters2)
}
bal1, bal2, err := getBalances(db)
if err != nil || bal1 != 100 || bal2 != 100 {
t.Errorf("expected balances to be restored without error; "+
"got acct1=%d, acct2=%d: %s", bal1, bal2, err)
}
}

3
ratelimiter/go.mod Normal file
View File

@ -0,0 +1,3 @@
module git.crumpington.com/lib/ratelimiter
go 1.23.2

View File

@ -0,0 +1,86 @@
package ratelimiter
import (
"errors"
"sync"
"time"
)
// ErrBackoff is returned when enforcing the rate limit would require waiting
// longer than the limiter's maximum wait time.
var ErrBackoff = errors.New("Backoff")
// Config holds the parameters for New. New panics if BurstLimit or
// MaxWaitCount is negative or if FillPeriod is not positive.
type Config struct {
BurstLimit int64 // Number of requests to allow to burst.
FillPeriod time.Duration // Add one per period.
MaxWaitCount int64 // Max number of waiting requests. 0 disables.
}
type Limiter struct {
lock sync.Mutex
fillPeriod time.Duration
minWaitTime time.Duration
maxWaitTime time.Duration
waitTime time.Duration // If waitTime < 0, no waiting occurs.
lastRequest time.Time
}
func New(conf Config) *Limiter {
if conf.BurstLimit < 0 {
panic(conf.BurstLimit)
}
if conf.FillPeriod <= 0 {
panic(conf.FillPeriod)
}
if conf.MaxWaitCount < 0 {
panic(conf.MaxWaitCount)
}
lim := &Limiter{
lastRequest: time.Now(),
fillPeriod: conf.FillPeriod,
waitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit),
minWaitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit),
maxWaitTime: conf.FillPeriod * time.Duration(conf.MaxWaitCount),
}
lim.waitTime = lim.minWaitTime
return lim
}
func (lim *Limiter) limit(count int64) (time.Duration, error) {
lim.lock.Lock()
defer lim.lock.Unlock()
dt := time.Since(lim.lastRequest)
waitTime := lim.waitTime - dt + time.Duration(count)*lim.fillPeriod
if waitTime < lim.minWaitTime {
waitTime = lim.minWaitTime
} else if waitTime > lim.maxWaitTime {
return 0, ErrBackoff
}
lim.waitTime = waitTime
lim.lastRequest = lim.lastRequest.Add(dt)
return lim.waitTime, nil
}
// Apply the limiter to the calling thread. The function may sleep for up to
// maxWaitTime before returning. If the timeout would need to be more than
// maxWaitTime to enforce the rate limit, ErrBackoff is returned.
func (lim *Limiter) Limit() error {
dt, err := lim.limit(1)
time.Sleep(dt) // Will return immediately for dt <= 0.
return err
}
// Apply the limiter for multiple items at once.
func (lim *Limiter) LimitMultiple(count int64) error {
dt, err := lim.limit(count)
time.Sleep(dt)
return err
}

View File

@ -0,0 +1,101 @@
package ratelimiter
import (
"sync"
"testing"
"time"
)
// TestRateLimiter_Limit_Errors launches N concurrent Limit calls against a
// fresh limiter for each configuration and checks both how many of them are
// rejected and how long the whole batch takes (within a 10ms tolerance).
func TestRateLimiter_Limit_Errors(t *testing.T) {
	type testCase struct {
		Name     string
		Conf     Config
		N        int
		ErrCount int
		DT       time.Duration
	}

	// Allowed overshoot on top of the expected batch duration.
	const slack = 10 * time.Millisecond

	tests := []testCase{
		{
			Name: "no burst, no wait",
			Conf: Config{
				BurstLimit:   0,
				FillPeriod:   100 * time.Millisecond,
				MaxWaitCount: 0,
			},
			N:        32,
			ErrCount: 32,
			DT:       0,
		},
		{
			Name: "no wait",
			Conf: Config{
				BurstLimit:   10,
				FillPeriod:   100 * time.Millisecond,
				MaxWaitCount: 0,
			},
			N:        32,
			ErrCount: 22,
			DT:       0,
		},
		{
			Name: "no burst",
			Conf: Config{
				BurstLimit:   0,
				FillPeriod:   10 * time.Millisecond,
				MaxWaitCount: 10,
			},
			N:        32,
			ErrCount: 22,
			DT:       100 * time.Millisecond,
		},
		{
			Name: "burst and wait",
			Conf: Config{
				BurstLimit:   10,
				FillPeriod:   10 * time.Millisecond,
				MaxWaitCount: 10,
			},
			N:        32,
			ErrCount: 12,
			DT:       100 * time.Millisecond,
		},
	}

	for _, tc := range tests {
		var (
			wg      sync.WaitGroup
			limiter = New(tc.Conf)
			results = make([]error, tc.N)
			start   = time.Now()
		)

		wg.Add(tc.N)
		for i := range results {
			go func(slot int) {
				defer wg.Done()
				results[slot] = limiter.Limit()
			}(i)
		}
		wg.Wait()
		elapsed := time.Since(start)

		rejected := 0
		for _, err := range results {
			if err != nil {
				rejected++
			}
		}

		if rejected != tc.ErrCount {
			t.Fatalf("%s: Expected %d errors but got %d.",
				tc.Name, tc.ErrCount, rejected)
		}
		if elapsed < tc.DT || elapsed > tc.DT+slack {
			t.Fatal(tc.Name, elapsed, tc.DT)
		}
	}
}

22
sqlgen/README.md Normal file
View File

@ -0,0 +1,22 @@
# sqlgen
## Installing
```
go install git.crumpington.com/lib/sqlgen/cmd/sqlgen@latest
```
## Usage
```
sqlgen [driver] [defs-path] [output-path]
```
## File Format
```
TABLE [sql-name] OF [go-type] <NoInsert> <NoUpdate> <NoDelete> (
[sql-column] [go-type] <AS go-name> <PK> <NoInsert> <NoUpdate>,
...
);
```

View File

@ -0,0 +1,7 @@
package main
import "git.crumpington.com/lib/sqlgen"
// main is the sqlgen command entry point; all of the work is delegated to
// sqlgen.Main.
func main() {
sqlgen.Main()
}

3
sqlgen/go.mod Normal file
View File

@ -0,0 +1,3 @@
module git.crumpington.com/lib/sqlgen
go 1.23.2

43
sqlgen/main.go Normal file
View File

@ -0,0 +1,43 @@
package sqlgen
import (
"fmt"
"os"
)
// Main implements the sqlgen command line tool. It expects exactly three
// arguments (driver, definitions path, output path), selects the template for
// the driver and renders the generated code to the output path.
//
// On any failure a message is written to stderr and the process exits with
// status 1.
func Main() {
	// usage prints the command synopsis and terminates the process.
	usage := func() {
		fmt.Fprintf(os.Stderr, `
%s DRIVER DEFS_PATH OUTPUT_PATH
Drivers are one of: sqlite
`,
			os.Args[0])
		os.Exit(1)
	}

	if len(os.Args) != 4 {
		usage()
	}

	var (
		template   string
		driver     = os.Args[1]
		defsPath   = os.Args[2]
		outputPath = os.Args[3]
	)

	// Only drivers with an embedded template are accepted; the usage text
	// above must list exactly the cases handled here.
	switch driver {
	case "sqlite":
		template = sqliteTemplate
	default:
		fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver)
		usage()
	}

	if err := render(template, defsPath, outputPath); err != nil {
		// Terminate the message with a newline so the shell prompt does not
		// end up on the same line as the error.
		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
		os.Exit(1)
	}
}

143
sqlgen/parse.go Normal file
View File

@ -0,0 +1,143 @@
package sqlgen
import (
"errors"
"os"
"strings"
)
// parsePath reads the definitions file at filePath and parses its contents
// into a schema. Read errors are returned unchanged.
func parsePath(filePath string) (*schema, error) {
	contents, readErr := os.ReadFile(filePath)
	if readErr != nil {
		return nil, readErr
	}
	return parseBytes(contents)
}
// parseBytes tokenizes a definitions file and parses every TABLE declaration
// in it. The punctuation characters "," "(" ")" and ";" are padded with
// spaces first so that they always become tokens of their own, after which
// the text is split on whitespace. Any top-level token other than TABLE is an
// error.
func parseBytes(fileBytes []byte) (*schema, error) {
	spaced := strings.NewReplacer(
		",", " , ",
		"(", " ( ",
		")", " ) ",
		";", " ; ",
	).Replace(string(fileBytes))

	parsed := &schema{}
	for rest := strings.Fields(spaced); len(rest) > 0; {
		if rest[0] != "TABLE" {
			return nil, errors.New("invalid token: " + rest[0])
		}
		var err error
		if rest, err = parseTable(parsed, rest); err != nil {
			return nil, err
		}
	}
	return parsed, nil
}
// parseTable parses one table declaration. tokens must start with the TABLE
// keyword, followed by `<sql-name> OF <go-type>`, any number of the flags
// NoInsert / NoUpdate / NoDelete, and then "(" which opens the column list
// (handled by parseTableBody).
//
// The new table is appended to schema.Tables before its flags and body are
// parsed. On success the tokens remaining after the table are returned.
func parseTable(schema *schema, tokens []string) ([]string, error) {
// Drop the leading TABLE keyword.
tokens = tokens[1:]
if len(tokens) < 3 {
return tokens, errors.New("incomplete table definition")
}
if tokens[1] != "OF" {
return tokens, errors.New("expected OF in table definition")
}
table := &table{
Name: tokens[0],
Type: tokens[2],
}
schema.Tables = append(schema.Tables, table)
// Consume `<sql-name> OF <go-type>`.
tokens = tokens[3:]
if len(tokens) == 0 {
return tokens, errors.New("missing table definition body")
}
// Table-level flags may appear in any order before the opening parenthesis.
for len(tokens) > 0 {
switch tokens[0] {
case "NoInsert":
table.NoInsert = true
tokens = tokens[1:]
case "NoUpdate":
table.NoUpdate = true
tokens = tokens[1:]
case "NoDelete":
table.NoDelete = true
tokens = tokens[1:]
case "(":
return parseTableBody(table, tokens[1:])
default:
return tokens, errors.New("unexpected token in table definition: " + tokens[0])
}
}
// Ran out of tokens without ever seeing "(".
return tokens, errors.New("incomplete table definition")
}
// parseTableBody parses column definitions (the tokens following the opening
// parenthesis of a table) until it reaches the terminating ";". It returns
// the tokens after that ";", or an error if the input ends first or a column
// fails to parse.
func parseTableBody(table *table, tokens []string) ([]string, error) {
	for {
		if len(tokens) == 0 {
			return tokens, errors.New("incomplete table column definitions")
		}
		if tokens[0] == ";" {
			return tokens[1:], nil
		}

		var err error
		if tokens, err = parseTableColumn(table, tokens); err != nil {
			return tokens, err
		}
	}
}
// parseTableColumn parses a single column definition of the form
//
//	<sql-column> <go-type> [AS <go-name>] [PK] [NoInsert] [NoUpdate]
//
// and appends it to table.Columns. Both Name and SqlName start out as the SQL
// column name; an AS clause overrides Name only. Parsing stops at the first
// "," or ")" token, which is consumed, and the remaining tokens are returned.
func parseTableColumn(table *table, tokens []string) ([]string, error) {
if len(tokens) < 2 {
return tokens, errors.New("incomplete column definition")
}
column := &column{
Name: tokens[0],
Type: tokens[1],
SqlName: tokens[0],
}
table.Columns = append(table.Columns, column)
tokens = tokens[2:]
// Column modifiers, in any order, up to the separator.
for len(tokens) > 0 && tokens[0] != "," && tokens[0] != ")" {
switch tokens[0] {
case "AS":
if len(tokens) < 2 {
return tokens, errors.New("incomplete AS clause in column definition")
}
column.Name = tokens[1]
tokens = tokens[2:]
case "PK":
column.PK = true
tokens = tokens[1:]
case "NoInsert":
column.NoInsert = true
tokens = tokens[1:]
case "NoUpdate":
column.NoUpdate = true
tokens = tokens[1:]
default:
return tokens, errors.New("unexpected token in column definition: " + tokens[0])
}
}
// The loop above only exits with tokens left when it hit "," or ")".
if len(tokens) == 0 {
return tokens, errors.New("incomplete column definition")
}
// Skip the "," or ")" that ended this column.
return tokens[1:], nil
}

45
sqlgen/parse_test.go Normal file
View File

@ -0,0 +1,45 @@
package sqlgen
import (
"encoding/json"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
)
// TestParse parses every .def file under test-files/TestParse and compares
// the result with the schema stored in the .json file of the same base name.
func TestParse(t *testing.T) {
	// pretty renders a value as indented JSON for failure messages.
	pretty := func(v any) string {
		out, _ := json.MarshalIndent(v, "", " ")
		return string(out)
	}

	defFiles, err := filepath.Glob("test-files/TestParse/*.def")
	if err != nil {
		t.Fatal(err)
	}

	for _, defFile := range defFiles {
		jsonFile := strings.TrimSuffix(defFile, "def") + "json"

		t.Run(filepath.Base(defFile), func(t *testing.T) {
			got, err := parsePath(defFile)
			if err != nil {
				t.Fatal(err)
			}

			raw, err := os.ReadFile(jsonFile)
			if err != nil {
				t.Fatal(err)
			}

			want := &schema{}
			if err := json.Unmarshal(raw, want); err != nil {
				t.Fatal(err)
			}

			if !reflect.DeepEqual(got, want) {
				t.Fatalf("%s != %s", pretty(got), pretty(want))
			}
		})
	}
}

263
sqlgen/schema.go Normal file
View File

@ -0,0 +1,263 @@
package sqlgen
import (
"fmt"
"strings"
)
// schema is the parsed form of a definitions file: its tables in declaration
// order.
type schema struct {
	Tables []*table
}

// table describes one TABLE declaration and knows how to produce the SQL
// statements and Go argument lists used by the code templates.
type table struct {
	Name     string // Name in SQL
	Type     string // Go type
	NoInsert bool
	NoUpdate bool
	NoDelete bool
	Columns  []*column
}

// column describes one column of a table.
type column struct {
	Name     string // Go field name
	Type     string // Go type
	SqlName  string // Name in SQL; Name defaults to this unless overridden
	PK       bool   // PK won't be updated
	NoInsert bool
	NoUpdate bool // Don't update column in update function
}

// ----------------------------------------------------------------------------

// joinCols renders every column with render and joins the results with sep.
func joinCols(cols []*column, sep string, render func(*column) string) string {
	parts := make([]string, len(cols))
	for i, c := range cols {
		parts[i] = render(c)
	}
	return strings.Join(parts, sep)
}

// sqlAssign renders a column as a `name=?` placeholder expression.
func sqlAssign(c *column) string { return c.SqlName + "=?" }

// rowField renders a column as a field access on the generated `row` value.
func rowField(c *column) string { return "row." + c.Name }

// filterCols returns the table's columns for which keep reports true, in
// declaration order. The result is nil when nothing matches.
func (t *table) filterCols(keep func(*column) bool) (cols []*column) {
	for _, c := range t.Columns {
		if keep(c) {
			cols = append(cols, c)
		}
	}
	return cols
}

// updateQuery builds `UPDATE <table> SET <cols...> WHERE <pk...>` for the
// given set of columns.
func (t *table) updateQuery(cols []*column) string {
	b := &strings.Builder{}
	b.WriteString(`UPDATE ` + t.Name + ` SET `)
	b.WriteString(joinCols(cols, ",", sqlAssign))
	b.WriteString(` WHERE`)
	for i, c := range t.pkCols() {
		if i != 0 {
			b.WriteString(` AND`)
		}
		b.WriteString(` ` + sqlAssign(c))
	}
	return b.String()
}

// updateArgs lists the row fields for cols followed by the primary-key
// fields, matching the placeholder order produced by updateQuery. The list is
// joined as a whole so that an empty cols never yields a leading separator.
func (t *table) updateArgs(cols []*column) string {
	all := make([]*column, 0, len(cols)+len(t.Columns))
	all = append(all, cols...)
	all = append(all, t.pkCols()...)
	return joinCols(all, ", ", rowField)
}

// colSQLNames returns the SQL names of all columns in declaration order.
func (t *table) colSQLNames() []string {
	names := make([]string, len(t.Columns))
	for i, c := range t.Columns {
		names[i] = c.SqlName
	}
	return names
}

// SelectQuery returns `SELECT <all columns> FROM <table>` with no WHERE
// clause.
func (t *table) SelectQuery() string {
	return fmt.Sprintf(`SELECT %s FROM %s`,
		strings.Join(t.colSQLNames(), ","),
		t.Name)
}

// insertCols returns the columns not flagged NoInsert.
func (t *table) insertCols() []*column {
	return t.filterCols(func(c *column) bool { return !c.NoInsert })
}

// InsertQuery returns an INSERT statement with one `?` placeholder per
// insertable column.
func (t *table) InsertQuery() string {
	cols := t.insertCols()
	names := joinCols(cols, ",", func(c *column) string { return c.SqlName })
	marks := joinCols(cols, ",", func(*column) string { return `?` })
	return `INSERT INTO ` + t.Name + `(` + names + `) VALUES(` + marks + `)`
}

// InsertArgs returns the row fields matching InsertQuery's placeholders.
func (t *table) InsertArgs() string {
	return joinCols(t.insertCols(), ", ", rowField)
}

// UpdateCols returns the columns that are neither primary keys nor flagged
// NoUpdate.
func (t *table) UpdateCols() []*column {
	return t.filterCols(func(c *column) bool { return !(c.PK || c.NoUpdate) })
}

// UpdateQuery returns the UPDATE statement that sets UpdateCols, selecting the
// row by primary key.
func (t *table) UpdateQuery() string {
	return t.updateQuery(t.UpdateCols())
}

// UpdateArgs returns the row fields matching UpdateQuery's placeholders.
func (t *table) UpdateArgs() string {
	return t.updateArgs(t.UpdateCols())
}

// UpdateFullCols returns every non-primary-key column, including those
// flagged NoUpdate.
func (t *table) UpdateFullCols() []*column {
	return t.filterCols(func(c *column) bool { return !c.PK })
}

// UpdateFullQuery returns the UPDATE statement that sets UpdateFullCols,
// selecting the row by primary key.
func (t *table) UpdateFullQuery() string {
	return t.updateQuery(t.UpdateFullCols())
}

// UpdateFullArgs returns the row fields matching UpdateFullQuery's
// placeholders.
func (t *table) UpdateFullArgs() string {
	return t.updateArgs(t.UpdateFullCols())
}

// pkCols returns the primary-key columns in declaration order.
func (t *table) pkCols() []*column {
	return t.filterCols(func(c *column) bool { return c.PK })
}

// PKFunctionArgs returns the primary-key columns as Go function parameters,
// one `<name> <type>,` per line.
func (t *table) PKFunctionArgs() string {
	return joinCols(t.pkCols(), "", func(c *column) string {
		return c.Name + ` ` + c.Type + ",\n"
	})
}

// DeleteQuery returns the DELETE statement selecting a row by primary key.
func (t *table) DeleteQuery() string {
	return `DELETE FROM ` + t.Name + ` WHERE ` + joinCols(t.pkCols(), ` AND `, sqlAssign)
}

// DeleteArgs returns the primary-key Go names (the parameter names produced by
// PKFunctionArgs) separated by commas.
func (t *table) DeleteArgs() string {
	return joinCols(t.pkCols(), `,`, func(c *column) string { return c.Name })
}

// GetQuery returns SelectQuery restricted to a single row by primary key.
func (t *table) GetQuery() string {
	return t.SelectQuery() + ` WHERE ` + joinCols(t.pkCols(), ` AND `, sqlAssign)
}

// ScanArgs returns pointers to every row field, in column order, suitable as
// the argument list of a Scan call for SelectQuery.
func (t *table) ScanArgs() string {
	return joinCols(t.Columns, `, `, func(c *column) string { return `&row.` + c.Name })
}

206
sqlgen/sqlite.go.tmpl Normal file
View File

@ -0,0 +1,206 @@
package {{.PackageName}}
import (
"database/sql"
"iter"
)
{{- /* TX lists the query methods the generated functions call on their tx argument. */}}
type TX interface {
Exec(query string, args ...any) (sql.Result, error)
Query(query string, args ...any) (*sql.Rows, error)
QueryRow(query string, args ...any) *sql.Row
}
{{range .Schema.Tables}}
// ----------------------------------------------------------------------------
// Table: {{.Name}}
// ----------------------------------------------------------------------------
{{- /* Row struct: one field per column, using the column's Go name and Go type. */}}
type {{.Type}} struct {
{{- range .Columns}}
{{.Name}} {{.Type}}{{end}}
}
{{- /* SELECT of all columns without a WHERE clause; callers append their own conditions. */}}
const {{.Type}}_SelectQuery = "{{.SelectQuery}}"
{{- /* Insert is emitted unless the table is flagged NoInsert. The row is passed through the
       <Type>_Sanitize and <Type>_Validate functions, which this template calls but does not define. */}}
{{if not .NoInsert -}}
func {{.Type}}_Insert(
tx TX,
row *{{.Type}},
) (err error) {
{{.Type}}_Sanitize(row)
if err = {{.Type}}_Validate(row); err != nil {
return err
}
_, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}})
return err
}
{{- end}} {{/* if not .NoInsert */}}
{{- /* Update and UpdateFull are emitted unless the table is flagged NoUpdate, and each only when
       it has at least one column to set. Update skips PK and NoUpdate columns; UpdateFull skips
       only PK columns. Both report whether a row was affected and panic if more than one was. */}}
{{if not .NoUpdate -}}
{{if .UpdateCols -}}
func {{.Type}}_Update(
tx TX,
row *{{.Type}},
) (found bool, err error) {
{{.Type}}_Sanitize(row)
if err = {{.Type}}_Validate(row); err != nil {
return false, err
}
result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}})
if err != nil {
return false, err
}
n, err := result.RowsAffected()
if err != nil {
panic(err)
}
if n > 1 {
panic("multiple rows updated")
}
return n != 0, nil
}
{{- end}}
{{if .UpdateFullCols -}}
func {{.Type}}_UpdateFull(
tx TX,
row *{{.Type}},
) (found bool, err error) {
{{.Type}}_Sanitize(row)
if err = {{.Type}}_Validate(row); err != nil {
return false, err
}
result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}})
if err != nil {
return false, err
}
n, err := result.RowsAffected()
if err != nil {
panic(err)
}
if n > 1 {
panic("multiple rows updated")
}
return n != 0, nil
}
{{- end}}
{{- end}} {{/* if not .NoUpdate */}}
{{- /* Delete is emitted unless the table is flagged NoDelete. It takes the primary-key columns
       as parameters, reports whether a row was deleted and panics if more than one was. */}}
{{if not .NoDelete -}}
func {{.Type}}_Delete(
tx TX,
{{.PKFunctionArgs -}}
) (found bool, err error) {
result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}})
if err != nil {
return false, err
}
n, err := result.RowsAffected()
if err != nil {
panic(err)
}
if n > 1 {
panic("multiple rows deleted")
}
return n != 0, nil
}
{{- end}}
{{- /* Get loads one row by primary key. DeleteArgs is reused here because it is simply the
       comma-separated list of primary-key parameter names. The Scan error is returned as-is. */}}
func {{.Type}}_Get(
tx TX,
{{.PKFunctionArgs -}}
) (
row *{{.Type}},
err error,
) {
row = &{{.Type}}{}
r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}})
err = r.Scan({{.ScanArgs}})
return
}
{{- /* GetWhere loads one row using a caller-supplied query, which must select the columns in the
       order ScanArgs expects (the column order of the table definition). */}}
func {{.Type}}_GetWhere(
tx TX,
query string,
args ...any,
) (
row *{{.Type}},
err error,
) {
row = &{{.Type}}{}
r := tx.QueryRow(query, args...)
err = r.Scan({{.ScanArgs}})
return
}
{{- /* Iterate runs a caller-supplied query and yields one (row, error) pair per result row. A
       failed query yields a single (nil, err) pair. The rows are closed when iteration stops. */}}
func {{.Type}}_Iterate(
	tx TX,
	query string,
	args ...any,
) (
	iter.Seq2[*{{.Type}}, error],
) {
	rows, err := tx.Query(query, args...)
	if err != nil {
		return func(yield func(*{{.Type}}, error) bool) {
			yield(nil, err)
		}
	}

	return func(yield func(*{{.Type}}, error) bool) {
		defer rows.Close()
		for rows.Next() {
			row := &{{.Type}}{}
			err := rows.Scan({{.ScanArgs}})
			if !yield(row, err) {
				return
			}
		}
		// rows.Next returns false both at the end of the result set and on
		// failure; report the latter instead of ending the sequence silently.
		if err := rows.Err(); err != nil {
			yield(nil, err)
		}
	}
}
{{- /* List collects the rows produced by Iterate into a slice. The first error aborts the
       collection and is returned with a nil slice. */}}
func {{.Type}}_List(
tx TX,
query string,
args ...any,
) (
l []*{{.Type}},
err error,
) {
for row, err := range {{.Type}}_Iterate(tx, query, args...) {
if err != nil {
return nil, err
}
l = append(l, row)
}
return l, nil
}
{{end}} {{/* range .Schema.Tables */}}

36
sqlgen/template.go Normal file
View File

@ -0,0 +1,36 @@
package sqlgen
import (
_ "embed"
"os"
"os/exec"
"path/filepath"
"text/template"
)
// sqliteTemplate holds the contents of sqlite.go.tmpl, embedded at build time.
//
//go:embed sqlite.go.tmpl
var sqliteTemplate string
// render parses the definitions at schemaPath, executes templateStr against
// the resulting schema, writes the output to outputPath and finally formats
// that file in place with gofmt.
//
// The generated package name is the name of the directory containing
// outputPath. It panics if templateStr is not a valid template.
func render(templateStr, schemaPath, outputPath string) error {
	sch, err := parsePath(schemaPath)
	if err != nil {
		return err
	}

	tmpl := template.Must(template.New("").Parse(templateStr))

	// Resolve the output path first so that a bare file name such as "out.go"
	// yields the working directory's name rather than "." as the package name.
	absPath, err := filepath.Abs(outputPath)
	if err != nil {
		return err
	}

	fOut, err := os.Create(outputPath)
	if err != nil {
		return err
	}

	execErr := tmpl.Execute(fOut, struct {
		PackageName string
		Schema      *schema
	}{filepath.Base(filepath.Dir(absPath)), sch})

	// Close before invoking gofmt so it never rewrites a file we still hold
	// open, and so a failed close is reported instead of ignored.
	closeErr := fOut.Close()
	if execErr != nil {
		return execErr
	}
	if closeErr != nil {
		return closeErr
	}

	return exec.Command("gofmt", "-w", outputPath).Run()
}

View File

@ -0,0 +1,3 @@
TABLE users OF User NoDelete (
user_id string AS UserID PK
);

View File

@ -0,0 +1,17 @@
{
"Tables": [
{
"Name": "users",
"Type": "User",
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
}
]
}
]
}

View File

@ -0,0 +1,4 @@
TABLE users OF User NoDelete (
user_id string AS UserID PK,
email string AS Email NoUpdate
);

View File

@ -0,0 +1,22 @@
{
"Tables": [
{
"Name": "users",
"Type": "User",
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
}, {
"Name": "Email",
"Type": "string",
"SqlName": "email",
"NoUpdate": true
}
]
}
]
}

View File

@ -0,0 +1,6 @@
TABLE users OF User NoDelete (
user_id string AS UserID PK,
email string AS Email NoUpdate,
name string AS Name NoInsert,
admin bool AS Admin NoInsert NoUpdate
);

View File

@ -0,0 +1,33 @@
{
"Tables": [
{
"Name": "users",
"Type": "User",
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
}, {
"Name": "Email",
"Type": "string",
"SqlName": "email",
"NoUpdate": true
}, {
"Name": "Name",
"Type": "string",
"SqlName": "name",
"NoInsert": true
}, {
"Name": "Admin",
"Type": "bool",
"SqlName": "admin",
"NoInsert": true,
"NoUpdate": true
}
]
}
]
}

View File

@ -0,0 +1,12 @@
TABLE users OF User NoDelete (
user_id string AS UserID PK,
email string AS Email NoUpdate,
name string AS Name NoInsert,
admin bool AS Admin NoInsert NoUpdate
);
TABLE users_view OF UserView NoInsert NoUpdate NoDelete (
user_id string AS UserID PK,
email string AS Email,
name string AS Name
);

View File

@ -0,0 +1,61 @@
{
"Tables": [
{
"Name": "users",
"Type": "User",
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
},
{
"Name": "Email",
"Type": "string",
"SqlName": "email",
"NoUpdate": true
},
{
"Name": "Name",
"Type": "string",
"SqlName": "name",
"NoInsert": true
},
{
"Name": "Admin",
"Type": "bool",
"SqlName": "admin",
"NoInsert": true,
"NoUpdate": true
}
]
},
{
"Name": "users_view",
"Type": "UserView",
"NoInsert": true,
"NoUpdate": true,
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
},
{
"Name": "Email",
"Type": "string",
"SqlName": "email"
},
{
"Name": "Name",
"Type": "string",
"SqlName": "name"
}
]
}
]
}

View File

@ -0,0 +1,13 @@
TABLE users OF User NoDelete (
user_id string AS UserID PK,
email string AS Email NoUpdate,
name string AS Name NoInsert,
admin bool AS Admin NoInsert NoUpdate,
SSN string NoUpdate
);
TABLE users_view OF UserView NoInsert NoUpdate NoDelete (
user_id string AS UserID PK,
email string AS Email,
name string AS Name
);

View File

@ -0,0 +1,66 @@
{
"Tables": [
{
"Name": "users",
"Type": "User",
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
},
{
"Name": "Email",
"Type": "string",
"SqlName": "email",
"NoUpdate": true
},
{
"Name": "Name",
"Type": "string",
"SqlName": "name",
"NoInsert": true
},
{
"Name": "Admin",
"Type": "bool",
"SqlName": "admin",
"NoInsert": true,
"NoUpdate": true
}, {
"Name": "SSN",
"Type": "string",
"SqlName": "SSN",
"NoUpdate": true
}
]
},
{
"Name": "users_view",
"Type": "UserView",
"NoInsert": true,
"NoUpdate": true,
"NoDelete": true,
"Columns": [
{
"Name": "UserID",
"Type": "string",
"SqlName": "user_id",
"PK": true
},
{
"Name": "Email",
"Type": "string",
"SqlName": "email"
},
{
"Name": "Name",
"Type": "string",
"SqlName": "name"
}
]
}
]
}

45
sqliteutil/README.md Normal file
View File

@ -0,0 +1,45 @@
# sqliteutil
## Transactions
Simplify postgres transactions using `WithTx` for serializable transactions,
or `WithTxDefault` for the default isolation level. Use the `SerialTxRunner`
type to get automatic retries of serialization errors.
## Migrations
Put your migrations into a directory, for example `migrations`, ordered by name
(YYYY-MM-DD prefix, for example). Embed the directory and pass it to the
`Migrate` function:
```Go
//go:embed migrations
var migrations embed.FS
func init() {
Migrate(db, migrations) // Check the error, of course.
}
```
## Testing
In order to test this package, we need to create a test user and database:
```
sudo su postgres
psql
CREATE DATABASE test;
CREATE USER test WITH ENCRYPTED PASSWORD 'test';
GRANT ALL PRIVILEGES ON DATABASE test TO test;
use test
GRANT ALL ON SCHEMA public TO test;
```
Check that you can connect via the command line:
```
psql -h 127.0.0.1 -U test --password test
```

5
sqliteutil/go.mod Normal file
View File

@ -0,0 +1,5 @@
module git.crumpington.com/lib/sqliteutil
go 1.23.2
require github.com/mattn/go-sqlite3 v1.14.24 // indirect

2
sqliteutil/go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=

82
sqliteutil/migrate.go Normal file
View File

@ -0,0 +1,82 @@
package sqliteutil
import (
"database/sql"
"embed"
"errors"
"fmt"
"path/filepath"
"sort"
)
// initMigrationTableQuery creates the bookkeeping table that records which
// migration files have been applied, keyed by file name.
const initMigrationTableQuery = `
CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);`
// insertMigrationQuery records a migration file as applied.
const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)`
// checkMigrationAppliedQuery reports whether a migration file has already
// been recorded as applied.
const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)`
// Migrate applies the SQL files found in migrationFS to db.
//
// migrationFS must contain exactly one top-level directory, holding only
// regular files. The files are applied in order of their names; each applied
// file name is recorded in the migrations table and files already recorded
// there are skipped. Everything runs inside a single WithTx transaction, so a
// failure in any migration returns an error without committing.
func Migrate(db *sql.DB, migrationFS embed.FS) error {
	return WithTx(db, func(tx *sql.Tx) error {
		if _, err := tx.Exec(initMigrationTableQuery); err != nil {
			return err
		}

		dirs, err := migrationFS.ReadDir(".")
		if err != nil {
			return err
		}
		if len(dirs) != 1 {
			return errors.New("expected a single migrations directory")
		}
		if !dirs[0].IsDir() {
			return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name())
		}

		dirName := dirs[0].Name()
		files, err := migrationFS.ReadDir(dirName)
		if err != nil {
			return err
		}

		// Sort sql files by name.
		sort.Slice(files, func(i, j int) bool {
			return files[i].Name() < files[j].Name()
		})

		for _, dirEnt := range files {
			if !dirEnt.Type().IsRegular() {
				return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name())
			}

			var (
				name   = dirEnt.Name()
				exists bool
			)
			err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists)
			if err != nil {
				return err
			}
			if exists {
				continue
			}

			// embed.FS paths always use forward slashes, whereas
			// filepath.Join uses the OS separator; normalize so lookups also
			// work on platforms where the two differ.
			migrationPath := filepath.ToSlash(filepath.Join(dirName, name))
			migration, err := migrationFS.ReadFile(migrationPath)
			if err != nil {
				return err
			}

			if _, err := tx.Exec(string(migration)); err != nil {
				// Wrap with %w so callers can inspect the driver error.
				return fmt.Errorf("migration %s failed: %w", name, err)
			}
			if _, err := tx.Exec(insertMigrationQuery, name); err != nil {
				return err
			}
		}

		return nil
	})
}

View File

@ -0,0 +1,44 @@
package sqliteutil
import (
"database/sql"
"embed"
"testing"
)
// testMigrationFS embeds the test-migrations directory used by TestMigrate.
//
//go:embed test-migrations
var testMigrationFS embed.FS
// TestMigrate applies the embedded test migrations to an in-memory database
// twice (the second run must be a no-op) and then checks the resulting users
// table: user 1 must be absent and user 2 present.
func TestMigrate(t *testing.T) {
	db, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatal(err)
	}

	// The second pass shouldn't have any effect.
	for pass := 0; pass < 2; pass++ {
		if err := Migrate(db, testMigrationFS); err != nil {
			t.Fatal(err)
		}
	}

	const query = `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)`
	userExists := func(id int) bool {
		var found bool
		if err := db.QueryRow(query, id).Scan(&found); err != nil {
			t.Fatal(err)
		}
		return found
	}

	if userExists(1) {
		t.Fatal("1 shouldn't exist")
	}
	if !userExists(2) {
		t.Fatal("2 should exist")
	}
}

View File

@ -0,0 +1,9 @@
CREATE TABLE users(
UserID BIGINT NOT NULL PRIMARY KEY,
Email TEXT NOT NULL UNIQUE);
CREATE TABLE user_notes(
UserID BIGINT NOT NULL REFERENCES users(UserID),
NoteID BIGINT NOT NULL,
Note Text NOT NULL,
PRIMARY KEY(UserID,NoteID));

View File

@ -0,0 +1 @@
INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com');

View File

@ -0,0 +1 @@
DELETE FROM users WHERE UserID=1;

28
sqliteutil/tx.go Normal file
View File

@ -0,0 +1,28 @@
package sqliteutil
import (
"database/sql"
_ "github.com/mattn/go-sqlite3"
)
// WithTx is a convenience function that runs fn within a transaction on db.
//
// The transaction is committed when fn returns nil and rolled back when fn
// returns an error, when the commit itself fails, or when fn panics. The
// error from fn or from Commit is returned; rollback errors are ignored.
func WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
	// Start a transaction.
	tx, err := db.Begin()
	if err != nil {
		return err
	}

	// Roll back on every exit path, including a panic inside fn, so the
	// transaction is never left open. After a successful Commit this call is
	// a no-op that returns sql.ErrTxDone, which is deliberately discarded.
	defer func() { _ = tx.Rollback() }()

	if err := fn(tx); err != nil {
		return err
	}
	return tx.Commit()
}

2
tagengine/README.md Normal file
View File

@ -0,0 +1,2 @@
# tagengine

3
tagengine/go.mod Normal file
View File

@ -0,0 +1,3 @@
module git.crumpington.com/lib/tagengine
go 1.23.2

0
tagengine/go.sum Normal file
View File

30
tagengine/ngram.go Normal file
View File

@ -0,0 +1,30 @@
package tagengine
import "unicode"
// ngramLength returns the number of whitespace-separated words in s.
//
// The string is decoded rune by rune rather than byte by byte: classifying
// individual bytes would treat UTF-8 continuation bytes such as 0x85 or 0xA0
// (the second byte of "à", for instance) as whitespace and split words in the
// middle of a character.
func ngramLength(s string) int {
	count := 0
	inWord := false
	for _, r := range s {
		if unicode.IsSpace(r) {
			inWord = false
			continue
		}
		// First non-space rune after a gap starts a new word.
		if !inWord {
			inWord = true
			count++
		}
	}
	return count
}

31
tagengine/ngram_test.go Normal file
View File

@ -0,0 +1,31 @@
package tagengine
import (
"log"
"testing"
)
// TestNGramLength checks ngramLength against inputs with assorted leading,
// trailing and interior whitespace.
//
// NOTE(review): failures are reported with log.Fatalf, which terminates the
// whole test binary; t.Fatalf would be the usual choice but would leave the
// file's "log" import unused.
func TestNGramLength(t *testing.T) {
	expectations := []struct {
		Input  string
		Length int
	}{
		{"a b c", 3},
		{" xyz\nlkj dflaj a", 4},
		{"a", 1},
		{" a", 1},
		{"a", 1},
		{" a\n", 1},
		{" a ", 1},
		{"\tx\ny\nz q ", 4},
	}

	for _, want := range expectations {
		if got := ngramLength(want.Input); got != want.Length {
			log.Fatalf("%s: %d != %d", want.Input, got, want.Length)
		}
	}
}

79
tagengine/node.go Normal file
View File

@ -0,0 +1,79 @@
package tagengine
import (
"fmt"
"strings"
)
// node is one vertex of the trie used to match token lists against rules.
// Each edge is labelled with one of a rule's Includes tokens.
type node struct {
// Token is the token on the edge leading to this node.
Token string
Matches []*Rule // If a list of tokens reaches this node, it matches these.
// Children maps the next token to the node reached by following it.
Children map[string]*node
}
// AddRule inserts r into the trie rooted at n, following r.Includes from its
// first token.
func (n *node) AddRule(r *Rule) {
n.addRule(r, 0)
}
// addRule inserts r below n, where idx is the position in r.Includes of the
// next token to consume. Once every include has been consumed the rule is
// recorded as a match on the current node; otherwise the child for the next
// token is created on demand and the insertion continues there.
func (n *node) addRule(r *Rule, idx int) {
	if idx == len(r.Includes) {
		n.Matches = append(n.Matches, r)
		return
	}

	key := r.Includes[idx]
	next, found := n.Children[key]
	if !found {
		next = &node{Token: key, Children: map[string]*node{}}
		n.Children[key] = next
	}
	next.addRule(r, idx+1)
}
// Match returns every rule whose includes can all be found, in order, in
// tokens.
//
// Note that tokens must be sorted. This is the case for tokens created from
// the tokenize function.
func (n *node) Match(tokens []string) (rules []*Rule) {
	return n.match(tokens, nil)
}
// match appends to rules the rules stored on n and, for each token that has a
// child under n, the rules matched by that child against the tokens that
// follow it. The extended slice is returned.
func (n *node) match(tokens []string, rules []*Rule) []*Rule {
	// Reaching this node means all of its rules matched.
	if len(n.Matches) > 0 {
		rules = append(rules, n.Matches...)
	}

	// Descend into every child reachable from one of the remaining tokens.
	for i, token := range tokens {
		child, found := n.Children[token]
		if !found {
			continue
		}
		rules = child.match(tokens[i+1:], rules)
	}
	return rules
}
// Dump prints the trie rooted at n to standard output, one node per line.
func (n *node) Dump() {
n.dump(0)
}
// dump prints n indented by two spaces per depth level, followed by the tags
// of the rules matching at n, and then prints its children one level deeper.
// Children are visited in map iteration order.
func (n *node) dump(depth int) {
	var line strings.Builder
	line.WriteString(strings.Repeat(" ", 2*depth))
	line.WriteString(n.Token)
	for _, rule := range n.Matches {
		line.WriteString(" " + rule.Tag)
	}
	fmt.Println(line.String())

	for _, child := range n.Children {
		child.dump(depth + 1)
	}
}

159
tagengine/rule.go Normal file
View File

@ -0,0 +1,159 @@
package tagengine
type Rule struct {
	// The purpose of a Rule is to attach its Tag to matching text.
	Tag string

	// Includes is a list of strings that must be found in the input in order to
	// match.
	Includes []string

	// Excludes is a list of strings that can exclude a match for this rule.
	Excludes []string

	// Blocks: If this rule is matched, then it will block matches of any tags
	// listed here.
	Blocks []string

	// The Score encodes the complexity of the Rule. A higher score indicates a
	// more specific match. A Rule with more includes, or includes with
	// multiple words, should have a higher Score than a Rule with fewer
	// includes or less complex includes.
	Score int

	// excludes is the set form of Excludes, built by normalize.
	excludes map[string]struct{}
}

// NewRule returns a Rule for tag with no includes, excludes or blocks.
func NewRule(tag string) Rule {
	return Rule{Tag: tag}
}

// Inc returns a copy of r with l appended to its Includes.
//
// The receiver's slice is capped at its length before appending so the result
// never writes into a backing array shared with r or with another rule
// derived from r.
func (r Rule) Inc(l ...string) Rule {
	return Rule{
		Tag:      r.Tag,
		Includes: append(r.Includes[:len(r.Includes):len(r.Includes)], l...),
		Excludes: r.Excludes,
		Blocks:   r.Blocks,
	}
}

// Exc returns a copy of r with l appended to its Excludes. As with Inc, the
// result does not write into r's backing array.
func (r Rule) Exc(l ...string) Rule {
	return Rule{
		Tag:      r.Tag,
		Includes: r.Includes,
		Excludes: append(r.Excludes[:len(r.Excludes):len(r.Excludes)], l...),
		Blocks:   r.Blocks,
	}
}

// Block returns a copy of r with l appended to its Blocks. As with Inc, the
// result does not write into r's backing array.
func (r Rule) Block(l ...string) Rule {
	return Rule{
		Tag:      r.Tag,
		Includes: r.Includes,
		Excludes: r.Excludes,
		Blocks:   append(r.Blocks[:len(r.Blocks):len(r.Blocks)], l...),
	}
}
// normalize brings the rule into the canonical form used for matching: every
// include and exclude is passed through sanitize (in place), both lists are
// sorted with sortTokens, the exclude lookup set is rebuilt and the Score is
// recomputed.
func (rule *Rule) normalize(sanitize func(string) string) {
	for i := range rule.Includes {
		rule.Includes[i] = sanitize(rule.Includes[i])
	}
	for i := range rule.Excludes {
		rule.Excludes[i] = sanitize(rule.Excludes[i])
	}

	sortTokens(rule.Includes)
	sortTokens(rule.Excludes)

	set := map[string]struct{}{}
	for _, token := range rule.Excludes {
		set[token] = struct{}{}
	}
	rule.excludes = set

	rule.Score = rule.computeScore()
}
// maxNGram returns the largest word count found among the rule's includes and
// excludes, or 0 when it has neither.
func (r Rule) maxNGram() int {
	longest := 0
	for _, list := range [][]string{r.Includes, r.Excludes} {
		for _, token := range list {
			if n := ngramLength(token); n > longest {
				longest = n
			}
		}
	}
	return longest
}
// isExcluded reports whether any of tokens is in the rule's exclude set.
func (r Rule) isExcluded(tokens []string) bool {
	// Most rules have no excludes, so skip the scan entirely.
	if len(r.excludes) == 0 {
		return false
	}

	for _, token := range tokens {
		if _, hit := r.excludes[token]; hit {
			return true
		}
	}
	return false
}
// computeScore sums n*(n+1)/2 over the rule's includes, where n is the number
// of words in the include. Multi-word includes therefore weigh more than the
// same number of single-word includes.
func (r Rule) computeScore() (score int) {
	total := 0
	for _, include := range r.Includes {
		words := ngramLength(include)
		total += words * (words + 1) / 2
	}
	return total
}
// ruleLess reports whether lhs orders before rhs. Rules are compared by, in
// order: Score, number of includes, number of excludes, the includes
// themselves, the excludes themselves, and finally the Tag.
func ruleLess(lhs, rhs *Rule) bool {
	switch {
	case lhs.Score != rhs.Score:
		return lhs.Score < rhs.Score
	case len(lhs.Includes) != len(rhs.Includes):
		return len(lhs.Includes) < len(rhs.Includes)
	case len(lhs.Excludes) != len(rhs.Excludes):
		return len(lhs.Excludes) < len(rhs.Excludes)
	}

	// Both lists have equal lengths from here on, so indexing rhs is safe.
	for i, inc := range lhs.Includes {
		if other := rhs.Includes[i]; inc != other {
			return inc < other
		}
	}
	for i, exc := range lhs.Excludes {
		if other := rhs.Excludes[i]; exc != other {
			return exc < other
		}
	}

	return lhs.Tag < rhs.Tag
}

58
tagengine/rulegroup.go Normal file
View File

@ -0,0 +1,58 @@
package tagengine
// A RuleGroup can be converted into a list of rules. Each rule will point to
// the same tag, and have the same exclude set and blocks.
type RuleGroup struct {
	Tag      string
	Includes [][]string // One include list per resulting rule.
	Excludes []string
	Blocks   []string
}

// NewRuleGroup returns an empty RuleGroup for tag. Its slices are empty but
// non-nil.
func NewRuleGroup(tag string) RuleGroup {
	return RuleGroup{
		Tag:      tag,
		Includes: [][]string{},
		Excludes: []string{},
		Blocks:   []string{},
	}
}

// Inc returns a copy of g with l added as one more include list.
//
// The receiver's slice is capped at its length before appending so the result
// never writes into a backing array shared with g or with another group
// derived from g.
func (g RuleGroup) Inc(l ...string) RuleGroup {
	return RuleGroup{
		Tag:      g.Tag,
		Includes: append(g.Includes[:len(g.Includes):len(g.Includes)], l),
		Excludes: g.Excludes,
		Blocks:   g.Blocks,
	}
}

// Exc returns a copy of g with l appended to its Excludes. As with Inc, the
// result does not write into g's backing array.
func (g RuleGroup) Exc(l ...string) RuleGroup {
	return RuleGroup{
		Tag:      g.Tag,
		Includes: g.Includes,
		Excludes: append(g.Excludes[:len(g.Excludes):len(g.Excludes)], l...),
		Blocks:   g.Blocks,
	}
}

// Block returns a copy of g with l appended to its Blocks. As with Inc, the
// result does not write into g's backing array.
func (g RuleGroup) Block(l ...string) RuleGroup {
	return RuleGroup{
		Tag:      g.Tag,
		Includes: g.Includes,
		Excludes: g.Excludes,
		Blocks:   append(g.Blocks[:len(g.Blocks):len(g.Blocks)], l...),
	}
}
// ToList expands the group into one Rule per include list. Every rule carries
// the group's Tag and refers to the group's Excludes and Blocks slices. The
// result is nil when the group has no include lists.
func (g RuleGroup) ToList() (l []Rule) {
	var rules []Rule
	for i := range g.Includes {
		rule := Rule{
			Tag:      g.Tag,
			Includes: g.Includes[i],
			Excludes: g.Excludes,
			Blocks:   g.Blocks,
		}
		rules = append(rules, rule)
	}
	return rules
}

162
tagengine/ruleset.go Normal file
View File

@ -0,0 +1,162 @@
package tagengine
import (
"sort"
)
// RuleSet holds a collection of rules indexed for matching against text.
type RuleSet struct {
// root is the trie that indexes the rules by their include tokens.
root *node
// maxNgram is the largest word count of any include or exclude added so far.
maxNgram int
// sanitize is applied to rule tokens when added and to input before matching.
sanitize func(string) string
// rules lists every rule added, in insertion order.
rules []*Rule
}
// NewRuleSet returns an empty RuleSet that sanitizes text with
// BasicSanitizer.
func NewRuleSet() *RuleSet {
	root := &node{
		Token:    "/",
		Children: map[string]*node{},
	}
	return &RuleSet{
		root:     root,
		sanitize: BasicSanitizer,
		rules:    []*Rule{},
	}
}
// NewRuleSetFromList returns a new RuleSet containing the given rules.
func NewRuleSetFromList(rules []Rule) *RuleSet {
rs := NewRuleSet()
rs.AddRule(rules...)
return rs
}
// Add adds any mix of Rule and RuleGroup values to the set. It panics when
// given a value of any other type.
func (t *RuleSet) Add(ruleOrGroup ...interface{}) {
	for _, item := range ruleOrGroup {
		if rule, ok := item.(Rule); ok {
			t.AddRule(rule)
			continue
		}
		if group, ok := item.(RuleGroup); ok {
			t.AddRuleGroup(group)
			continue
		}
		panic("Add expects either Rule or RuleGroup objects.")
	}
}
// AddRule normalizes each rule with the set's sanitizer, raises maxNgram if
// the rule needs longer n-grams, and stores a private copy of the rule in
// both the rule list and the matching trie.
func (t *RuleSet) AddRule(rules ...Rule) {
	for i := range rules {
		// Work on a per-iteration copy so the stored pointer is unique.
		r := rules[i]

		// Make sure rule is well-formed.
		r.normalize(t.sanitize)

		// Update maxNgram.
		if n := r.maxNGram(); n > t.maxNgram {
			t.maxNgram = n
		}

		t.rules = append(t.rules, &r)
		t.root.AddRule(&r)
	}
}
// AddRuleGroup expands each group into its rules and adds them to the set.
func (t *RuleSet) AddRuleGroup(ruleGroups ...RuleGroup) {
	for i := range ruleGroups {
		expanded := ruleGroups[i].ToList()
		t.AddRule(expanded...)
	}
}
// MatchRules will return a list of all matching rules. The rules are sorted by
// the match's Score. The best match will be first.
//
// The input is sanitized and tokenized, looked up in the trie, and rules with
// an exclude present among the tokens are dropped before sorting.
func (t *RuleSet) MatchRules(input string) (rules []*Rule) {
	tokens := Tokenize(t.sanitize(input), t.maxNgram)

	rules = t.root.Match(tokens)
	if len(rules) == 0 {
		return rules
	}

	// Check excludes, filtering in place.
	kept := rules[:0]
	for _, rule := range rules {
		if rule.isExcluded(tokens) {
			continue
		}
		kept = append(kept, rule)
	}

	// Sort rules descending.
	sort.Slice(kept, func(i, j int) bool {
		return ruleLess(kept[j], kept[i])
	})

	return kept
}
// Match pairs a tag with its relative confidence, as returned by
// RuleSet.Match.
type Match struct {
	// Tag is the tag of the matched rule(s).
	Tag string

	// Confidence is used to sort all matches, and is normalized so the sum of
	// Confidence values for all matches is 1. Confidence is relative to the
	// number of matches and the size of matches in terms of number of tokens.
	Confidence float64 // In the range (0,1].
}
// Match returns one Match per distinct tag among the rules matching the input,
// each with a confidence value. This is useful for picking the best tag out of
// everything that matched.
//
// Tags listed in the Blocks of any matching rule are removed. The scores of
// the remaining rules are summed per tag and divided by the total score, so
// the confidences add up to 1. Tags appear in the order in which MatchRules
// first returned a rule for them. A single matching rule is returned directly
// with confidence 1.
//
// If you just want to find all matching rules, then use MatchRules.
func (t *RuleSet) Match(input string) []Match {
	rules := t.MatchRules(input)
	switch len(rules) {
	case 0:
		return []Match{}
	case 1:
		return []Match{{Tag: rules[0].Tag, Confidence: 1}}
	}

	// Collect every tag blocked by any matching rule.
	blocked := map[string]struct{}{}
	for _, r := range rules {
		for _, tag := range r.Blocks {
			blocked[tag] = struct{}{}
		}
	}

	// Keep only rules whose tag is not blocked, compacting in place.
	allowed := rules[:0]
	for _, r := range rules {
		if _, isBlocked := blocked[r.Tag]; !isBlocked {
			allowed = append(allowed, r)
		}
	}

	// Accumulate scores per tag; indexByTag maps a tag to its position in out.
	indexByTag := map[string]int{}
	out := []Match{}
	total := float64(0)
	for _, r := range allowed {
		i, seen := indexByTag[r.Tag]
		if !seen {
			i = len(out)
			indexByTag[r.Tag] = i
			out = append(out, Match{Tag: r.Tag})
		}
		out[i].Confidence += float64(r.Score)
		total += float64(r.Score)
	}

	// Normalize so that all confidences sum to 1.
	for i := range out {
		out[i].Confidence /= total
	}
	return out
}

84
tagengine/ruleset_test.go Normal file
View File

@ -0,0 +1,84 @@
package tagengine
import (
"reflect"
"testing"
)
func TestRulesSet(t *testing.T) {
rs := NewRuleSet()
rs.AddRule(Rule{
Tag: "cc/2",
Includes: []string{"cola", "coca"},
})
rs.AddRule(Rule{
Tag: "cc/0",
Includes: []string{"coca cola"},
})
rs.AddRule(Rule{
Tag: "cz/2",
Includes: []string{"coca", "zero"},
})
rs.AddRule(Rule{
Tag: "cc0/3",
Includes: []string{"zero", "coca", "cola"},
})
rs.AddRule(Rule{
Tag: "cc0/3.1",
Includes: []string{"coca", "cola", "zero"},
Excludes: []string{"pepsi"},
})
rs.AddRule(Rule{
Tag: "spa",
Includes: []string{"spa"},
Blocks: []string{"cc/0", "cc0/3", "cc0/3.1"},
})
type TestCase struct {
Input string
Matches []Match
}
cases := []TestCase{
{
Input: "coca-cola zero",
Matches: []Match{
{"cc0/3.1", 0.3},
{"cc0/3", 0.3},
{"cz/2", 0.2},
{"cc/2", 0.2},
},
}, {
Input: "coca cola",
Matches: []Match{
{"cc/0", 0.6},
{"cc/2", 0.4},
},
}, {
Input: "coca cola zero pepsi",
Matches: []Match{
{"cc0/3", 0.3},
{"cc/0", 0.3},
{"cz/2", 0.2},
{"cc/2", 0.2},
},
}, {
Input: "fanta orange",
Matches: []Match{},
}, {
Input: "coca-cola zero / fanta / spa",
Matches: []Match{
{"cz/2", 0.4},
{"cc/2", 0.4},
{"spa", 0.2},
},
},
}
for _, tc := range cases {
matches := rs.Match(tc.Input)
if !reflect.DeepEqual(matches, tc.Matches) {
t.Fatalf("%v != %v", matches, tc.Matches)
}
}
}

20
tagengine/sanitize.go Normal file
View File

@ -0,0 +1,20 @@
package tagengine
import (
"strings"
"git.crumpington.com/lib/tagengine/sanitize"
)
// BasicSanitizer is the basic sanitizer. It applies, in order:
//   - lower-casing
//   - putting spaces around numbers
//   - putting spaces around punctuation
//   - collapsing multiple spaces
func BasicSanitizer(s string) string {
	spaced := sanitize.SpacePunctuation(sanitize.SpaceNumbers(strings.ToLower(s)))
	return sanitize.CollapseSpaces(spaced)
}

View File

@ -0,0 +1,91 @@
package sanitize
import (
"strings"
"unicode"
)
// SpaceNumbers inserts a single space at every boundary between an ASCII
// digit (0-9) and any other rune, so "abc123xyz" becomes "abc 123 xyz". An
// empty string is returned unchanged.
func SpaceNumbers(s string) string {
	if s == "" {
		return s
	}

	var out strings.Builder
	atStart := true
	prevDigit := false
	for _, r := range s {
		curDigit := '0' <= r && r <= '9'
		// Never emit a separator before the very first rune.
		if !atStart && curDigit != prevDigit {
			out.WriteByte(' ')
		}
		atStart = false
		prevDigit = curDigit
		out.WriteRune(r)
	}
	return out.String()
}
// SpacePunctuation surrounds every occurrence of one of the listed ASCII
// punctuation characters with a space on each side. All other runes are
// copied through unchanged.
func SpacePunctuation(s string) string {
	// The set of characters that get padded with spaces.
	const punctuation = "`~!@#%^&*()-_+=[{]}\\|:;\"',<.>?/"

	var out strings.Builder
	for _, r := range s {
		if !strings.ContainsRune(punctuation, r) {
			out.WriteRune(r)
			continue
		}
		out.WriteByte(' ')
		out.WriteRune(r)
		out.WriteByte(' ')
	}
	return out.String()
}
// CollapseSpaces trims leading and trailing whitespace and replaces every run
// of whitespace inside the string (as defined by unicode.IsSpace) with a
// single ASCII space.
func CollapseSpaces(s string) string {
	var out strings.Builder
	pendingSpace := false
	for _, r := range strings.TrimSpace(s) {
		switch {
		case unicode.IsSpace(r):
			// Defer the separator until the next non-space rune.
			pendingSpace = true
		case pendingSpace:
			pendingSpace = false
			out.WriteRune(' ')
			out.WriteRune(r)
		default:
			out.WriteRune(r)
		}
	}
	return out.String()
}

View File

@ -0,0 +1,30 @@
package tagengine
import "testing"
// TestSanitize checks BasicSanitizer against a table of inputs and their
// expected sanitized forms.
func TestSanitize(t *testing.T) {
	cases := []struct {
		In  string
		Out string
	}{
		{In: "", Out: ""},
		{In: "123abc", Out: "123 abc"},
		{In: "abc123", Out: "abc 123"},
		{In: "abc123xyz", Out: "abc 123 xyz"},
		{In: "1f2", Out: "1 f 2"},
		{In: " abc", Out: "abc"},
		{In: " ; KitKat/m&m's (bottle) @ ", Out: "; kitkat / m & m ' s ( bottle ) @"},
		{In: "€", Out: "€"},
	}

	for _, tc := range cases {
		if got := BasicSanitizer(tc.In); got != tc.Out {
			t.Fatalf("%v != %v", got, tc.Out)
		}
	}
}

63
tagengine/tokenize.go Normal file
View File

@ -0,0 +1,63 @@
package tagengine
import (
"sort"
"strings"
)
// ignoreTokens is the set of single-character punctuation tokens that are
// ignored when they appear on their own.
var ignoreTokens = map[string]struct{}{}

// init fills ignoreTokens with one entry per punctuation character.
func init() {
	// Every character is ASCII, so each rune converts to a one-byte string.
	const punctuation = "`~!@#%^&*()-_+=[{]}\\|:;\"',<.>?/"
	for _, r := range punctuation {
		ignoreTokens[string(r)] = struct{}{}
	}
}
// Tokenize splits input on whitespace and returns every distinct n-gram of 1
// up to maxNgram consecutive words, joined by single spaces. N-grams found in
// ignoreTokens are dropped. The result is ordered by sortTokens and is nil
// when no token is produced.
func Tokenize(input string, maxNgram int) (tokens []string) {
	words := strings.Fields(input)
	if maxNgram > len(words) {
		maxNgram = len(words)
	}

	// Tracks n-grams already emitted so each appears only once.
	seen := map[string]bool{}

	for size := 1; size <= maxNgram; size++ {
		for start := 0; start+size <= len(words); start++ {
			ngram := strings.Join(words[start:start+size], " ")
			if _, skip := ignoreTokens[ngram]; skip {
				continue
			}
			if seen[ngram] {
				continue
			}
			seen[ngram] = true
			tokens = append(tokens, ngram)
		}
	}

	sortTokens(tokens)
	return tokens
}
// sortTokens sorts tokens in place, shortest first by byte length, with
// tokens of equal length in ascending string order.
func sortTokens(tokens []string) {
	sort.Slice(tokens, func(a, b int) bool {
		x, y := tokens[a], tokens[b]
		if len(x) == len(y) {
			return x < y
		}
		return len(x) < len(y)
	})
}

View File

@ -0,0 +1,55 @@
package tagengine
import (
"reflect"
"testing"
)
// TestTokenize checks the n-grams produced by Tokenize, including ordering,
// dropping of stand-alone punctuation, and removal of duplicates.
func TestTokenize(t *testing.T) {
	cases := []struct {
		Input    string
		MaxNgram int
		Output   []string
	}{
		{
			Input:    "a bb c d",
			MaxNgram: 3,
			Output: []string{
				"a", "c", "d", "bb",
				"c d", "a bb", "bb c",
				"a bb c", "bb c d",
			},
		}, {
			Input:    "a b",
			MaxNgram: 3,
			Output: []string{
				"a", "b", "a b",
			},
		}, {
			Input:    "- b c d",
			MaxNgram: 3,
			Output: []string{
				"b", "c", "d",
				"- b", "b c", "c d",
				"- b c", "b c d",
			},
		}, {
			Input:    "a a b c d c d",
			MaxNgram: 3,
			Output: []string{
				"a", "b", "c", "d",
				"a a", "a b", "b c", "c d", "d c",
				"a a b", "a b c", "b c d", "c d c", "d c d",
			},
		},
	}

	for _, tc := range cases {
		got := Tokenize(tc.Input, tc.MaxNgram)
		if !reflect.DeepEqual(got, tc.Output) {
			t.Fatalf("%s: %#v", tc.Input, got)
		}
	}
}

5
webutil/README.md Normal file
View File

@ -0,0 +1,5 @@
# webutil
## Roadmap
* logging middleware

10
webutil/go.mod Normal file
View File

@ -0,0 +1,10 @@
module git.crumpington.com/lib/webutil
go 1.23.2
require golang.org/x/crypto v0.28.0
require (
golang.org/x/net v0.21.0 // indirect
golang.org/x/text v0.19.0 // indirect
)

6
webutil/go.sum Normal file
View File

@ -0,0 +1,6 @@
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=

24
webutil/listenandserve.go Normal file
View File

@ -0,0 +1,24 @@
package webutil
import (
"errors"
"net/http"
"strings"
"golang.org/x/crypto/acme/autocert"
)
// ListenAndServe serves requests using the given http.Server.
//
// If srv.Addr has the form `hostname:https`, the server is run on a listener
// from autocert.NewListener(hostname), so certificates for the domain are
// managed by autocert. An address of just `:https` is rejected with an error
// because a hostname is required. Any other address is passed straight to
// srv.ListenAndServe.
//
// For http on port 80, you can use :http.
func ListenAndServe(srv *http.Server) error {
	hostname, isHTTPS := strings.CutSuffix(srv.Addr, ":https")
	if !isHTTPS {
		return srv.ListenAndServe()
	}
	if hostname == "" {
		return errors.New("https requires a hostname")
	}
	return srv.Serve(autocert.NewListener(hostname))
}

View File

@ -0,0 +1,47 @@
package webutil
import (
"log"
"net/http"
"os"
"time"
)
// _log is the package logger. It writes to standard error with an empty prefix
// and no flags, so entries carry no timestamp.
var _log = log.New(os.Stderr, "", 0)
// responseWriterWrapper wraps an http.ResponseWriter and records the response
// status and body size so they can be reported after the handler returns.
type responseWriterWrapper struct {
	http.ResponseWriter
	httpStatus   int // Status recorded by WriteHeader or Write; 0 until one of them is called.
	responseSize int // Running count of response body bytes, updated by Write.
}
// WriteHeader records the status code and then forwards the call to the
// wrapped ResponseWriter.
func (w *responseWriterWrapper) WriteHeader(status int) {
	w.httpStatus = status
	w.ResponseWriter.WriteHeader(status)
}
// Write forwards b to the wrapped ResponseWriter and adds the number of bytes
// actually written to responseSize. If no status has been recorded yet, it is
// set to 200, mirroring the implicit WriteHeader(http.StatusOK) that net/http
// performs on the first Write.
//
// It returns the byte count and error reported by the wrapped writer.
func (w *responseWriterWrapper) Write(b []byte) (int, error) {
	if w.httpStatus == 0 {
		w.httpStatus = http.StatusOK
	}
	// Count what the underlying writer accepted, not len(b), so that short or
	// failed writes are not over-reported.
	n, err := w.ResponseWriter.Write(b)
	w.responseSize += n
	return n, err
}
// WithLogging wraps inner so that one line is logged per request after the
// handler returns. The line contains the remote address, the method, URL path
// and protocol, the response status, the response size in bytes, and the
// elapsed time.
//
// If the handler returns without calling Write or WriteHeader, the status is
// logged as 200, because net/http sends 200 OK in that case.
func WithLogging(inner http.HandlerFunc) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		wrapper := responseWriterWrapper{ResponseWriter: w}
		inner(&wrapper, r)

		// A zero status means the handler never wrote anything; report the
		// status the server actually sends instead of 0.
		status := wrapper.httpStatus
		if status == 0 {
			status = http.StatusOK
		}

		_log.Printf("%s \"%s %s %s\" %d %d %v\n",
			r.RemoteAddr,
			r.Method,
			r.URL.Path,
			r.Proto,
			status,
			wrapper.responseSize,
			time.Since(start),
		)
	}
}

100
webutil/template.go Normal file
View File

@ -0,0 +1,100 @@
package webutil
import (
"embed"
"html/template"
"io/fs"
"log"
"path"
"strings"
)
// ParseTemplateSet parses sets of templates from an embed.FS and returns them
// keyed by path.
//
// The first entry at the top level of the file system is taken as the root
// directory. Each directory below it (the root included) constitutes one set
// of templates that are parsed together.
//
// Structure (within a directory):
//   - share/* are always parsed.
//   - base.html will be parsed with each other file in the same directory.
//
// Call a template with m[path].Execute(w, data), where path excludes the root
// directory name.
//
// For example, if you have
//   - /user/share/*
//   - /user/base.html
//   - /user/home.html
//
// Then you call m["/user/home.html"].Execute(w, data).
//
// Read and parse failures cause a panic.
func ParseTemplateSet(funcs template.FuncMap, fs embed.FS) map[string]*template.Template {
	topLevel := readDir(fs, ".")
	root := topLevel[0].Name()

	templates := make(map[string]*template.Template)
	loadTemplateDir(fs, funcs, templates, root, root)
	return templates
}
// loadTemplateDir parses the templates in dirPath into m and recurses into
// every sub-directory other than "share". It returns m.
//
// A fresh base template is built for each directory: funcs are registered,
// then share/* is parsed if the share directory can be read, then base.html if
// it can be read. Every other regular file is parsed into its own clone of
// that base and stored under its path with the rootDir prefix removed.
//
// Parse failures, and read failures for the directory listing or an
// individual template file, cause a panic.
func loadTemplateDir(
	fs embed.FS,
	funcs template.FuncMap,
	m map[string]*template.Template,
	dirPath string,
	rootDir string,
) map[string]*template.Template {
	base := template.New("")
	if funcs != nil {
		base = base.Funcs(funcs)
	}

	// Shared templates are optional; a read error means there are none.
	sharePattern := path.Join(dirPath, "share", "*")
	if _, err := fs.ReadDir(path.Join(dirPath, "share")); err == nil {
		log.Printf("Parsing %s...", sharePattern)
		base = template.Must(base.ParseFS(fs, sharePattern))
	}

	// base.html is optional as well; the read error is deliberately ignored
	// and only the presence of data is checked.
	basePath := path.Join(dirPath, "base.html")
	baseData, _ := fs.ReadFile(basePath)
	if baseData != nil {
		log.Printf("Parsing %s...", basePath)
		base = template.Must(base.Parse(string(baseData)))
	}

	for _, entry := range readDir(fs, dirPath) {
		name := entry.Name()
		mode := entry.Type()

		switch {
		case mode.IsDir():
			if name != "share" {
				m = loadTemplateDir(fs, funcs, m, path.Join(dirPath, name), rootDir)
			}
			continue
		case !mode.IsRegular(), name == "base.html":
			continue
		}

		filePath := path.Join(dirPath, name)
		log.Printf("Parsing %s...", filePath)

		// Each page gets its own clone so definitions do not leak between
		// sibling files.
		page := template.Must(base.Clone())
		page = template.Must(page.Parse(readFile(fs, filePath)))
		m[strings.TrimPrefix(filePath, rootDir)] = page
	}

	return m
}
// readDir returns the entries of dirPath within fs and panics if the
// directory cannot be read.
func readDir(fs embed.FS, dirPath string) []fs.DirEntry {
	entries, err := fs.ReadDir(dirPath)
	if err == nil {
		return entries
	}
	panic(err)
}
// readFile returns the contents of the file at path within fs as a string and
// panics if the file cannot be read.
func readFile(fs embed.FS, path string) string {
	contents, err := fs.ReadFile(path)
	if err == nil {
		return string(contents)
	}
	panic(err)
}

49
webutil/template_test.go Normal file
View File

@ -0,0 +1,49 @@
package webutil
import (
"bytes"
"embed"
"html/template"
"strings"
"testing"
)
// testFS embeds the test-templates directory used by TestParseTemplateSet.
//
//go:embed all:test-templates
var testFS embed.FS
// TestParseTemplateSet parses the embedded test templates and checks the
// rendered output of each page. A missing template or an Execute error is
// reported as a test failure rather than causing a nil-pointer panic or being
// silently ignored.
func TestParseTemplateSet(t *testing.T) {
	funcs := template.FuncMap{"join": strings.Join}
	m := ParseTemplateSet(funcs, testFS)

	type TestCase struct {
		Key  string
		Data any
		Out  string
	}

	cases := []TestCase{
		{
			Key:  "/home.html",
			Data: "DATA",
			Out:  "<p>HOME!</p>",
		}, {
			Key:  "/about.html",
			Data: "DATA",
			Out:  "<p><b>DATA</b></p>",
		}, {
			Key:  "/contact.html",
			Data: []string{"a", "b", "c"},
			Out:  "<p>a,b,c</p>",
		},
	}

	for _, tc := range cases {
		tmpl, ok := m[tc.Key]
		if !ok {
			t.Fatalf("%s: template not found", tc.Key)
		}

		b := &bytes.Buffer{}
		if err := tmpl.Execute(b, tc.Data); err != nil {
			t.Fatalf("%s: %v", tc.Key, err)
		}

		out := strings.TrimSpace(b.String())
		if out != tc.Out {
			t.Fatalf("%s != %s", out, tc.Out)
		}
	}
}

View File

@ -0,0 +1 @@
{{define "body"}}{{template "bold" .}}{{end}}

View File

@ -0,0 +1 @@
<p>{{block "body" .}}default{{end}}</p>

View File

@ -0,0 +1 @@
{{define "body"}}{{join . ","}}{{end}}

View File

@ -0,0 +1 @@
{{define "body"}}HOME!{{end}}

Some files were not shown because too many files have changed in this diff Show More