diff --git a/flock/README.md b/flock/README.md new file mode 100644 index 0000000..71f8c32 --- /dev/null +++ b/flock/README.md @@ -0,0 +1,2 @@ +# flock + diff --git a/flock/flock.go b/flock/flock.go new file mode 100644 index 0000000..5062d1e --- /dev/null +++ b/flock/flock.go @@ -0,0 +1,69 @@ +// Package flock provides a file-system-mediated locking mechanism on Linux +// using the `flock` system call. + +package flock + +import ( + "errors" + "os" + "syscall" +) + +// Lock gets an exclusive lock on the file at the given path. If the file +// doesn't exist, it's created. +func Lock(path string) (*os.File, error) { + return lock(path, syscall.LOCK_EX) +} + +// TryLock will return a nil file if the file is already locked. +func TryLock(path string) (*os.File, error) { + return lock(path, syscall.LOCK_EX|syscall.LOCK_NB) +} + +// LockFile takes an exclusive lock on the given file, blocking until the +// lock is available. +func LockFile(f *os.File) error { + _, err := lockFile(f, syscall.LOCK_EX) + return err +} + +// TryLockFile returns true if the lock was successfully acquired. +func TryLockFile(f *os.File) (bool, error) { + return lockFile(f, syscall.LOCK_EX|syscall.LOCK_NB) +} + +func lockFile(f *os.File, flags int) (bool, error) { + if err := flock(int(f.Fd()), flags); err != nil { + if flags&syscall.LOCK_NB != 0 && errors.Is(err, syscall.EAGAIN) { + return false, nil + } + return false, err + } + return true, nil +} + +func flock(fd int, how int) error { + _, _, e1 := syscall.Syscall(syscall.SYS_FLOCK, uintptr(fd), uintptr(how), 0) + if e1 != 0 { + return syscall.Errno(e1) + } + return nil +} + +func lock(path string, flags int) (*os.File, error) { + f, err := os.OpenFile(path, os.O_CREATE|os.O_RDWR, 0600) + if err != nil { + return nil, err + } + ok, err := lockFile(f, flags) + if err != nil || !ok { + f.Close() + f = nil + } + return f, err +} + +// Unlock releases the lock acquired via the Lock function.
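+// Closing the returned file releases the lock, so callers may equivalently +// close the *os.File themselves.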
+func Unlock(f *os.File) error { + return f.Close() +} diff --git a/flock/flock_test.go b/flock/flock_test.go new file mode 100644 index 0000000..fc1e3c9 --- /dev/null +++ b/flock/flock_test.go @@ -0,0 +1,66 @@ +package flock + +import ( + "testing" + "time" +) + +func Test_Lock_basic(t *testing.T) { + ch := make(chan int, 1) + f, err := Lock("/tmp/fsutil-test-lock") + if err != nil { + t.Fatal(err) + } + go func() { + time.Sleep(time.Second) + ch <- 10 + Unlock(f) + }() + + select { + case x := <-ch: + t.Fatal(x) + default: + } + + f2, _ := Lock("/tmp/fsutil-test-lock") + defer Unlock(f2) + select { + case i := <-ch: + if i != 10 { + t.Fatal(i) + } + default: + t.Fatal("No value available.") + } +} + +func Test_Lock_badPath(t *testing.T) { + _, err := Lock("./dne/file.lock") + if err == nil { + t.Fatal("expected an error for a non-existent directory") + } +} + +func TestTryLock(t *testing.T) { + lockPath := "/tmp/fsutil-test-lock" + f, err := TryLock(lockPath) + if err != nil { + t.Fatalf("%#v", err) + } + + f2, err := TryLock(lockPath) + if err != nil { + t.Fatal(err) + } + if f2 != nil { + t.Fatal(f2) + } + + if err := Unlock(f); err != nil { + t.Fatal(err) + } +} diff --git a/flock/go.mod b/flock/go.mod new file mode 100644 index 0000000..d62500e --- /dev/null +++ b/flock/go.mod @@ -0,0 +1,3 @@ +module git.crumpington.com/lib/flock + +go 1.23.0 diff --git a/flock/go.sum b/flock/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/httpconn/README.md b/httpconn/README.md new file mode 100644 index 0000000..5afb337 --- /dev/null +++ b/httpconn/README.md @@ -0,0 +1,2 @@ +# httpconn + diff --git a/httpconn/client.go b/httpconn/client.go new file mode 100644 index 0000000..813a86e --- /dev/null +++ b/httpconn/client.go @@ -0,0 +1,104 @@ +package httpconn + +import ( + "bufio" + "context" + "crypto/tls" + "errors" + "io" + "net" + "net/http" + "net/url" + "time" +) + +var ( + ErrUnknownScheme = errors.New("unknown scheme") +) + +type Dialer struct { + timeout time.Duration +} + +func NewDialer() *Dialer { + return &Dialer{timeout: 10 * time.Second} +} + +func (d *Dialer) SetTimeout(timeout time.Duration) { + d.timeout = timeout +} + +func (d *Dialer) Dial(rawURL string) (net.Conn, error) { + u, err := url.Parse(rawURL) + if err != nil { + return nil, err + } + + switch u.Scheme { + case "https": + host := u.Host + if u.Port() == "" { + host += ":443" + } + return d.DialHTTPS(host, u.Path) + case "http": + host := u.Host + if u.Port() == "" { + host += ":80" + } + return d.DialHTTP(host, u.Path) + default: + return nil, ErrUnknownScheme + } +} + +func (d *Dialer) DialHTTPS(host, path string) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), d.timeout) + dd := tls.Dialer{} + conn, err := dd.DialContext(ctx, "tcp", host) + cancel() + if err != nil { + return nil, err + } + return d.finishDialing(conn, host, path) +} + +func (d *Dialer) DialHTTP(host, path string) (net.Conn, error) { + conn, err := net.DialTimeout("tcp", host, d.timeout) + if err != nil { + return nil, err + } + return d.finishDialing(conn, host, path) +} + +func (d *Dialer) finishDialing(conn net.Conn, host, path string) (net.Conn, error) { + conn.SetDeadline(time.Now().Add(d.timeout)) + + if _, err := io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n"); err != nil { + conn.Close() + return nil, err + } + if _, err := io.WriteString(conn, "Host: "+host+"\n\n"); err != nil { + conn.Close() + return nil, err + } + + // Require a successful HTTP response before using the conn.
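+ // The server side of this handshake is Accept in server.go, which replies + // with "HTTP/1.0 200 OK" after hijacking the connection.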
+ resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) + if err != nil { + conn.Close() + return nil, err + } + + if resp.StatusCode != http.StatusOK { + conn.Close() + return nil, errors.New("CONNECT failed: " + resp.Status) + } + + conn.SetDeadline(time.Time{}) + + return conn, nil +} + +func Dial(rawURL string) (net.Conn, error) { + return NewDialer().Dial(rawURL) +} + +func DialHTTPS(host, path string) (net.Conn, error) { + return NewDialer().DialHTTPS(host, path) +} + +func DialHTTP(host, path string) (net.Conn, error) { + return NewDialer().DialHTTP(host, path) +} diff --git a/httpconn/conn_test.go b/httpconn/conn_test.go new file mode 100644 index 0000000..5982d65 --- /dev/null +++ b/httpconn/conn_test.go @@ -0,0 +1,42 @@ +package httpconn + +import ( + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "golang.org/x/net/nettest" +) + +func TestNetTest_TestConn(t *testing.T) { + nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) { + + connCh := make(chan net.Conn, 1) + doneCh := make(chan bool) + + mux := http.NewServeMux() + mux.HandleFunc("/connect", func(w http.ResponseWriter, r *http.Request) { + conn, err := Accept(w, r) + if err != nil { + panic(err) + } + connCh <- conn + <-doneCh + }) + + srv := httptest.NewServer(mux) + + c1, err = DialHTTP(strings.TrimPrefix(srv.URL, "http://"), "/connect") + if err != nil { + panic(err) + } + c2 = <-connCh + + return c1, c2, func() { + doneCh <- true + srv.Close() + }, nil + }) +} diff --git a/httpconn/go.mod b/httpconn/go.mod new file mode 100644 index 0000000..cf6a02b --- /dev/null +++ b/httpconn/go.mod @@ -0,0 +1,5 @@ +module git.crumpington.com/lib/httpconn + +go 1.23.2 + +require golang.org/x/net v0.30.0 diff --git a/httpconn/go.sum b/httpconn/go.sum new file mode 100644 index 0000000..b338806 --- /dev/null +++ b/httpconn/go.sum @@ -0,0 +1,2 @@ +golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= +golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= diff --git a/httpconn/server.go b/httpconn/server.go new file mode 100644 index 0000000..f2a4dd5 --- /dev/null +++ b/httpconn/server.go @@ -0,0 +1,32 @@ +package httpconn + +import ( + "io" + "net" + "net/http" + "time" +) + +func Accept(w http.ResponseWriter, r *http.Request) (net.Conn, error) { + if r.Method != "CONNECT" { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusMethodNotAllowed) + io.WriteString(w, "405 must CONNECT\n") + return nil, http.ErrNotSupported + } + + hj, ok := w.(http.Hijacker) + if !ok { + return nil, http.ErrNotSupported + } + + conn, _, err := hj.Hijack() + if err != nil { + return nil, err + } + + _, _ = io.WriteString(conn, "HTTP/1.0 200 OK\n\n") + conn.SetDeadline(time.Time{}) + + return conn, nil +} diff --git a/idgen/README.md b/idgen/README.md new file mode 100644 index 0000000..b1daa5a --- /dev/null +++ b/idgen/README.md @@ -0,0 +1,2 @@ +# idgen + diff --git a/idgen/go.mod b/idgen/go.mod new file mode 100644 index 0000000..65e2186 --- /dev/null +++ b/idgen/go.mod @@ -0,0 +1,3 @@ +module git.crumpington.com/lib/idgen + +go 1.23.2 diff --git a/idgen/idgen.go b/idgen/idgen.go new file mode 100644 index 0000000..c95a1ef --- /dev/null +++ b/idgen/idgen.go @@ -0,0 +1,53 @@ +package idgen + +import ( + "crypto/rand" + "encoding/base32" + "sync" + "time" +) + +// NewToken creates a new, random token.
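+// The token is 20 random bytes (160 bits) encoded as a 32-character base32 +// string. It panics only if the system's source of randomness fails.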
+func NewToken() string { + buf := make([]byte, 20) + _, err := rand.Read(buf) + if err != nil { + panic(err) + } + return base32.StdEncoding.EncodeToString(buf) +} + +var ( + lock sync.Mutex + ts int64 = time.Now().Unix() + counter int64 = 1 + counterMax int64 = 1 << 20 +) + +// NextID can generate ~1M ints per second for a given nodeID. +// +// nodeID must be less than 64. +func NextID(nodeID int64) int64 { + lock.Lock() + defer lock.Unlock() + + tt := time.Now().Unix() + if tt > ts { + ts = tt + counter = 1 + } else { + counter++ + if counter == counterMax { + panic("Too many IDs.") + } + } + + return (ts << 26) + (nodeID << 20) + counter +} + +func SplitID(id int64) (unixTime, nodeID, counter int64) { + counter = id & (0x00000000000FFFFF) + nodeID = (id >> 20) & (0x000000000000003F) + unixTime = id >> 26 + return +} diff --git a/idgen/idgen_test.go b/idgen/idgen_test.go new file mode 100644 index 0000000..822f6d7 --- /dev/null +++ b/idgen/idgen_test.go @@ -0,0 +1,19 @@ +package idgen + +import ( + "log" + "testing" +) + +func BenchmarkNext(b *testing.B) { + for i := 0; i < b.N; i++ { + NextID(0) + } +} + +func TestNextID(t *testing.T) { + id := NextID(32) + + a, b, c := SplitID(id) + log.Print(a, b, c) +} diff --git a/keyedmutex/README.md b/keyedmutex/README.md new file mode 100644 index 0000000..cf73f70 --- /dev/null +++ b/keyedmutex/README.md @@ -0,0 +1,2 @@ +# keyedmutex + diff --git a/keyedmutex/go.mod b/keyedmutex/go.mod new file mode 100644 index 0000000..dae5dde --- /dev/null +++ b/keyedmutex/go.mod @@ -0,0 +1,3 @@ +module git.crumpington.com/lib/keyedmutex + +go 1.23.2 diff --git a/keyedmutex/keyedmutex.go b/keyedmutex/keyedmutex.go new file mode 100644 index 0000000..699e4b1 --- /dev/null +++ b/keyedmutex/keyedmutex.go @@ -0,0 +1,67 @@ +package keyedmutex + +import ( + "container/list" + "sync" +) + +type KeyedMutex[K comparable] struct { + mu *sync.Mutex + waitList map[K]*list.List +} + +func New[K comparable]() KeyedMutex[K] { + return KeyedMutex[K]{ + mu: new(sync.Mutex), + waitList: map[K]*list.List{}, + } +} + +func (m KeyedMutex[K]) Lock(key K) { + if ch := m.lock(key); ch != nil { + <-ch + } +} + +func (m KeyedMutex[K]) lock(key K) chan struct{} { + m.mu.Lock() + defer m.mu.Unlock() + + if waitList, ok := m.waitList[key]; ok { + ch := make(chan struct{}) + waitList.PushBack(ch) + return ch + } + + m.waitList[key] = list.New() + return nil +} + +func (m KeyedMutex[K]) TryLock(key K) bool { + m.mu.Lock() + defer m.mu.Unlock() + + if _, ok := m.waitList[key]; ok { + return false + } + + m.waitList[key] = list.New() + return true +} + +func (m KeyedMutex[K]) Unlock(key K) { + m.mu.Lock() + defer m.mu.Unlock() + + waitList, ok := m.waitList[key] + if !ok { + panic("unlock of unlocked mutex") + } + + if waitList.Len() == 0 { + delete(m.waitList, key) + } else { + ch := waitList.Remove(waitList.Front()).(chan struct{}) + ch <- struct{}{} + } +} diff --git a/keyedmutex/keyedmutex_test.go b/keyedmutex/keyedmutex_test.go new file mode 100644 index 0000000..14fdaf0 --- /dev/null +++ b/keyedmutex/keyedmutex_test.go @@ -0,0 +1,123 @@ +package keyedmutex + + +import ( + "sync" + "testing" + "time" +) + +func TestKeyedMutex(t *testing.T) { + checkState := func(t *testing.T, m KeyedMutex[string], keys ...string) { + if len(m.waitList) != len(keys) { + t.Fatal(m.waitList, keys) + } + + for _, key := range keys { + if _, ok := m.waitList[key]; !ok { + t.Fatal(key) + } + } + } + + m := New[string]() + checkState(t, m) + + m.Lock("a") + checkState(t, m, "a") + m.Lock("b") + checkState(t, 
m, "a", "b") + m.Lock("c") + checkState(t, m, "a", "b", "c") + + if m.TryLock("a") { + t.Fatal("a") + } + if m.TryLock("b") { + t.Fatal("b") + } + if m.TryLock("c") { + t.Fatal("c") + } + + if !m.TryLock("d") { + t.Fatal("d") + } + + checkState(t, m, "a", "b", "c", "d") + + if !m.TryLock("e") { + t.Fatal("e") + } + checkState(t, m, "a", "b", "c", "d", "e") + + m.Unlock("c") + checkState(t, m, "a", "b", "d", "e") + m.Unlock("a") + checkState(t, m, "b", "d", "e") + m.Unlock("e") + checkState(t, m, "b", "d") + + wg := sync.WaitGroup{} + for i := 0; i < 8; i++ { + wg.Add(1) + go func() { + defer wg.Done() + m.Lock("b") + m.Unlock("b") + }() + } + + time.Sleep(100 * time.Millisecond) + m.Unlock("b") + wg.Wait() + + checkState(t, m, "d") + + m.Unlock("d") + checkState(t, m) +} + +func TestKeyedMutex_unlockUnlocked(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Fatal(r) + } + }() + + m := New[string]() + m.Unlock("aldkfj") +} + +func BenchmarkUncontendedMutex(b *testing.B) { + m := New[string]() + key := "xyz" + + for i := 0; i < b.N; i++ { + m.Lock(key) + m.Unlock(key) + } +} + +func BenchmarkContendedMutex(b *testing.B) { + m := New[string]() + key := "xyz" + + m.Lock(key) + + wg := sync.WaitGroup{} + for i := 0; i < b.N; i++ { + wg.Add(1) + go func() { + defer wg.Done() + m.Lock(key) + m.Unlock(key) + }() + } + + time.Sleep(time.Second) + + b.ResetTimer() + m.Unlock(key) + wg.Wait() +} diff --git a/kvmemcache/README.md b/kvmemcache/README.md new file mode 100644 index 0000000..cd8ea48 --- /dev/null +++ b/kvmemcache/README.md @@ -0,0 +1,2 @@ +# kvmemcache + diff --git a/kvmemcache/cache.go b/kvmemcache/cache.go new file mode 100644 index 0000000..1648047 --- /dev/null +++ b/kvmemcache/cache.go @@ -0,0 +1,134 @@ +package kvmemcache + +import ( + "container/list" + "sync" + "time" + + "git.crumpington.com/lib/keyedmutex" +) + +type Cache[K comparable, V any] struct { + updateLock keyedmutex.KeyedMutex[K] + src func(K) (V, error) + ttl time.Duration + maxSize int + + // Lock protects variables below. + lock sync.Mutex + cache map[K]*list.Element + ll *list.List + stats Stats +} + +type lruItem[K comparable, V any] struct { + key K + createdAt time.Time + value V + err error +} + +type Config[K comparable, V any] struct { + MaxSize int + TTL time.Duration // Zero to ignore. 
+ Src func(K) (V, error) +} + +func New[K comparable, V any](conf Config[K, V]) *Cache[K, V] { + return &Cache[K, V]{ + updateLock: keyedmutex.New[K](), + src: conf.Src, + ttl: conf.TTL, + maxSize: conf.MaxSize, + lock: sync.Mutex{}, + cache: make(map[K]*list.Element, conf.MaxSize+1), + ll: list.New(), + } +} + +func (c *Cache[K, V]) Get(key K) (V, error) { + ok, val, err := c.get(key) + if ok { + return val, err + } + + return c.load(key) +} + +func (c *Cache[K, V]) Evict(key K) { + c.lock.Lock() + defer c.lock.Unlock() + c.evict(key) +} + +func (c *Cache[K, V]) Stats() Stats { + c.lock.Lock() + defer c.lock.Unlock() + return c.stats +} + +func (c *Cache[K, V]) put(key K, value V, err error) { + c.lock.Lock() + defer c.lock.Unlock() + + c.stats.Misses++ + + c.cache[key] = c.ll.PushFront(lruItem[K, V]{ + key: key, + createdAt: time.Now(), + value: value, + err: err, + }) + + if c.maxSize != 0 && len(c.cache) > c.maxSize { + li := c.ll.Back() + c.ll.Remove(li) + delete(c.cache, li.Value.(lruItem[K, V]).key) + } +} + +func (c *Cache[K, V]) evict(key K) { + elem := c.cache[key] + if elem != nil { + delete(c.cache, key) + c.ll.Remove(elem) + } +} + +func (c *Cache[K, V]) get(key K) (ok bool, val V, err error) { + c.lock.Lock() + defer c.lock.Unlock() + + li := c.cache[key] + if li == nil { + return false, val, nil + } + + item := li.Value.(lruItem[K, V]) + // Maybe evict. + if c.ttl != 0 && time.Since(item.createdAt) > c.ttl { + c.evict(key) + return false, val, nil + } + + c.stats.Hits++ + + c.ll.MoveToFront(li) + return true, item.value, item.err +} + +func (c *Cache[K, V]) load(key K) (V, error) { + c.updateLock.Lock(key) + defer c.updateLock.Unlock(key) + + // Check again in case we lost the update race. + ok, val, err := c.get(key) + if ok { + return val, err + } + + // Won the update race. 
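+ // Load from the source and cache the result, even if it's an error.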
+ val, err = c.src(key) + c.put(key, val, err) + return val, err +} diff --git a/kvmemcache/cache_test.go b/kvmemcache/cache_test.go new file mode 100644 index 0000000..bc16943 --- /dev/null +++ b/kvmemcache/cache_test.go @@ -0,0 +1,249 @@ +package kvmemcache + +import ( + "errors" + "fmt" + "sync" + "testing" + "time" +) + +type State[K comparable] struct { + Keys []K + Stats Stats +} + +func (c *Cache[K,V]) assert(state State[K]) error { + c.lock.Lock() + defer c.lock.Unlock() + + if len(c.cache) != len(state.Keys) { + return fmt.Errorf( + "Expected %d keys but found %d.", + len(state.Keys), + len(c.cache)) + } + + for _, k := range state.Keys { + if _, ok := c.cache[k]; !ok { + return fmt.Errorf( + "Expected key %v not found.", + k) + } + } + + if c.stats.Hits != state.Stats.Hits { + return fmt.Errorf( + "Expected %d hits, but found %d.", + state.Stats.Hits, + c.stats.Hits) + } + + if c.stats.Misses != state.Stats.Misses { + return fmt.Errorf( + "Expected %d misses, but found %d.", + state.Stats.Misses, + c.stats.Misses) + } + + return nil +} + +var ErrTest = errors.New("Hello") + +func TestCache_basic(t *testing.T) { + c := New(Config[string, string]{ + MaxSize: 4, + TTL: 50 * time.Millisecond, + Src: func(key string) (string, error) { + if key == "err" { + return "", ErrTest + } + return key, nil + }, + }) + + type testCase struct { + name string + sleep time.Duration + key string + evict bool + state State[string] + } + + cases := []testCase{ + { + name: "get a", + key: "a", + state: State[string]{ + Keys: []string{"a"}, + Stats: Stats{Hits: 0, Misses: 1}, + }, + }, { + name: "get a again", + key: "a", + state: State[string]{ + Keys: []string{"a"}, + Stats: Stats{Hits: 1, Misses: 1}, + }, + }, { + name: "sleep, then get a again", + sleep: 55 * time.Millisecond, + key: "a", + state: State[string]{ + Keys: []string{"a"}, + Stats: Stats{Hits: 1, Misses: 2}, + }, + }, { + name: "get b", + key: "b", + state: State[string]{ + Keys: []string{"a", "b"}, + Stats: Stats{Hits: 1, Misses: 3}, + }, + }, { + name: "get c", + key: "c", + state: State[string]{ + Keys: []string{"a", "b", "c"}, + Stats: Stats{Hits: 1, Misses: 4}, + }, + }, { + name: "get d", + key: "d", + state: State[string]{ + Keys: []string{"a", "b", "c", "d"}, + Stats: Stats{Hits: 1, Misses: 5}, + }, + }, { + name: "get e", + key: "e", + state: State[string]{ + Keys: []string{"b", "c", "d", "e"}, + Stats: Stats{Hits: 1, Misses: 6}, + }, + }, { + name: "get c again", + key: "c", + state: State[string]{ + Keys: []string{"b", "c", "d", "e"}, + Stats: Stats{Hits: 2, Misses: 6}, + }, + }, { + name: "get err", + key: "err", + state: State[string]{ + Keys: []string{"c", "d", "e", "err"}, + Stats: Stats{Hits: 2, Misses: 7}, + }, + }, { + name: "get err again", + key: "err", + state: State[string]{ + Keys: []string{"c", "d", "e", "err"}, + Stats: Stats{Hits: 3, Misses: 7}, + }, + }, { + name: "evict c", + key: "c", + evict: true, + state: State[string]{ + Keys: []string{"d", "e", "err"}, + Stats: Stats{Hits: 3, Misses: 7}, + }, + }, { + name: "reload-all a", + key: "a", + state: State[string]{ + Keys: []string{"a", "d", "e", "err"}, + Stats: Stats{Hits: 3, Misses: 8}, + }, + }, { + name: "reload-all b", + key: "b", + state: State[string]{ + Keys: []string{"a", "b", "e", "err"}, + Stats: Stats{Hits: 3, Misses: 9}, + }, + }, { + name: "reload-all c", + key: "c", + state: State[string]{ + Keys: []string{"a", "b", "c", "err"}, + Stats: Stats{Hits: 3, Misses: 10}, + }, + }, { + name: "reload-all d", + key: "d", + state: State[string]{ + Keys: 
[]string{"a", "b", "c", "d"}, + Stats: Stats{Hits: 3, Misses: 11}, + }, + }, { + name: "read a again", + key: "a", + state: State[string]{ + Keys: []string{"b", "c", "d", "a"}, + Stats: Stats{Hits: 4, Misses: 11}, + }, + }, { + name: "read e, evicting b", + key: "e", + state: State[string]{ + Keys: []string{"c", "d", "a", "e"}, + Stats: Stats{Hits: 4, Misses: 12}, + }, + }, + } + + for _, tc := range cases { + time.Sleep(tc.sleep) + if !tc.evict { + val, err := c.Get(tc.key) + if tc.key == "err" && err != ErrTest { + t.Fatal(tc.name, val) + } + if tc.key != "err" && val != tc.key { + t.Fatal(tc.name, tc.key, val) + } + } else { + c.Evict(tc.key) + } + + if err := c.assert(tc.state); err != nil { + t.Fatal(err) + } + } +} + +func TestCache_thunderingHerd(t *testing.T) { + c := New(Config[string,string]{ + MaxSize: 4, + Src: func(key string) (string, error) { + time.Sleep(time.Second) + return key, nil + }, + }) + + wg := sync.WaitGroup{} + for i := 0; i < 16384; i++ { + wg.Add(1) + go func() { + defer wg.Done() + val, err := c.Get("a") + if err != nil { + panic(err) + } + if val != "a" { + panic(err) + } + }() + } + + wg.Wait() + + stats := c.Stats() + if stats.Hits != 16383 || stats.Misses != 1 { + t.Fatal(stats) + } +} diff --git a/kvmemcache/go.mod b/kvmemcache/go.mod new file mode 100644 index 0000000..1f3d823 --- /dev/null +++ b/kvmemcache/go.mod @@ -0,0 +1,5 @@ +module git.crumpington.com/lib/kvmemcache + +go 1.23.2 + +require git.crumpington.com/lib/keyedmutex v1.0.1 diff --git a/kvmemcache/go.sum b/kvmemcache/go.sum new file mode 100644 index 0000000..1581bac --- /dev/null +++ b/kvmemcache/go.sum @@ -0,0 +1,2 @@ +git.crumpington.com/lib/keyedmutex v1.0.1 h1:5ylwGXQzL9ojZIhlqkut6dpa4yt6Wz6bOWbf/tQBAMQ= +git.crumpington.com/lib/keyedmutex v1.0.1/go.mod h1:VxxJRU/XvvF61IuJZG7kUIv954Q8+Rh8bnVpEzGYrQ4= diff --git a/kvmemcache/stats.go b/kvmemcache/stats.go new file mode 100644 index 0000000..b6415f9 --- /dev/null +++ b/kvmemcache/stats.go @@ -0,0 +1,6 @@ +package kvmemcache + +type Stats struct { + Hits uint64 + Misses uint64 +} diff --git a/mmap/README.md b/mmap/README.md new file mode 100644 index 0000000..8ae69d4 --- /dev/null +++ b/mmap/README.md @@ -0,0 +1,2 @@ +# mmap + diff --git a/mmap/file.go b/mmap/file.go new file mode 100644 index 0000000..e35676a --- /dev/null +++ b/mmap/file.go @@ -0,0 +1,100 @@ +package mmap + +import "os" + +type File struct { + f *os.File + Map []byte +} + +func Create(path string, size int64) (*File, error) { + f, err := os.Create(path) + if err != nil { + return nil, err + } + if err := f.Truncate(size); err != nil { + f.Close() + return nil, err + } + m, err := Map(f, PROT_READ|PROT_WRITE) + if err != nil { + f.Close() + return nil, err + } + return &File{f, m}, nil +} + +// Opens a mapped file in read-only mode. +func Open(path string) (*File, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + m, err := Map(f, PROT_READ) + if err != nil { + f.Close() + return nil, err + } + return &File{f, m}, nil +} + +func OpenFile( + path string, + fileFlags int, + perm os.FileMode, + size int64, // -1 for file size. 
+) (*File, error) { + f, err := os.OpenFile(path, fileFlags, perm) + if err != nil { + return nil, err + } + + writable := fileFlags&os.O_RDWR != 0 || fileFlags&os.O_WRONLY != 0 + + fi, err := f.Stat() + if err != nil { + f.Close() + return nil, err + } + + if writable && size > 0 && size != fi.Size() { + if err := f.Truncate(size); err != nil { + f.Close() + return nil, err + } + } + + mapFlags := PROT_READ + if writable { + mapFlags |= PROT_WRITE + } + + m, err := Map(f, mapFlags) + if err != nil { + f.Close() + return nil, err + } + + return &File{f, m}, nil +} + +func (f *File) Sync() error { + return Sync(f.Map) +} + +func (f *File) Close() error { + if f.Map != nil { + if err := Unmap(f.Map); err != nil { + return err + } + f.Map = nil + } + + if f.f != nil { + if err := f.f.Close(); err != nil { + return err + } + f.f = nil + } + return nil +} diff --git a/mmap/go.mod b/mmap/go.mod new file mode 100644 index 0000000..c265826 --- /dev/null +++ b/mmap/go.mod @@ -0,0 +1,3 @@ +module git.crumpington.com/lib/mmap + +go 1.23.2 diff --git a/mmap/go.sum b/mmap/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/mmap/mmap.go b/mmap/mmap.go new file mode 100644 index 0000000..21458a0 --- /dev/null +++ b/mmap/mmap.go @@ -0,0 +1,59 @@ +package mmap + +import ( + "os" + "syscall" + "unsafe" +) + +const ( + PROT_READ = syscall.PROT_READ + PROT_WRITE = syscall.PROT_WRITE +) + +// Map creates a memory map of the given file. The flags argument should be a +// combination of PROT_READ and PROT_WRITE. The size of the map will be the +// file's size. +func Map(f *os.File, flags int) ([]byte, error) { + fi, err := f.Stat() + if err != nil { + return nil, err + } + size := fi.Size() + + addr, _, errno := syscall.Syscall6( + syscall.SYS_MMAP, + 0, // addr: 0 => allow kernel to choose + uintptr(size), + uintptr(flags), + uintptr(syscall.MAP_SHARED), + f.Fd(), + 0) // offset: 0 => start of file + if errno != 0 { + return nil, syscall.Errno(errno) + } + + return unsafe.Slice((*byte)(unsafe.Pointer(addr)), size), nil +} + +// Unmap unmaps the data obtained by Map. +func Unmap(data []byte) error { + _, _, errno := syscall.Syscall( + syscall.SYS_MUNMAP, + uintptr(unsafe.Pointer(&data[0])), + uintptr(len(data)), + 0) + if errno != 0 { + return syscall.Errno(errno) + } + return nil +} + +// Sync flushes the mapped region to disk using MS_SYNC. +func Sync(b []byte) (err error) { + _p0 := unsafe.Pointer(&b[0]) + _, _, errno := syscall.Syscall(syscall.SYS_MSYNC, uintptr(_p0), uintptr(len(b)), uintptr(syscall.MS_SYNC)) + if errno != 0 { + err = syscall.Errno(errno) + } + return +} diff --git a/pgutil/README.md b/pgutil/README.md new file mode 100644 index 0000000..595ce78 --- /dev/null +++ b/pgutil/README.md @@ -0,0 +1,45 @@ +# pgutil + +## Transactions + +Simplify Postgres transactions using `WithTx` for serializable transactions, +or `WithTxDefault` for the default isolation level. Use the `SerialTxRunner` +type to get automatic retries of serialization errors. + +## Migrations + +Put your migrations into a directory, for example `migrations`, ordered by name +(YYYY-MM-DD prefix, for example). Embed the directory and pass it to the +`Migrate` function: + +```Go +//go:embed migrations +var migrations embed.FS + +func init() { + Migrate(db, migrations) // Check the error, of course.
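+ // Applied migrations are recorded by filename and skipped on later runs.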
+} +``` + +## Testing + +In order to test this package, we need to create a test user and database: + +``` +sudo su postgres +psql + +CREATE DATABASE test; +CREATE USER test WITH ENCRYPTED PASSWORD 'test'; +GRANT ALL PRIVILEGES ON DATABASE test TO test; + +\c test + +GRANT ALL ON SCHEMA public TO test; +``` + +Check that you can connect via the command line: + +``` +psql -h 127.0.0.1 -U test --password test +``` diff --git a/pgutil/dropall.go b/pgutil/dropall.go new file mode 100644 index 0000000..d4d6bf1 --- /dev/null +++ b/pgutil/dropall.go @@ -0,0 +1,42 @@ +package pgutil + +import ( + "database/sql" + "log" +) + +const dropTablesQuery = ` +SELECT 'DROP TABLE IF EXISTS "' || tablename || '" CASCADE;' +FROM + pg_tables +WHERE + schemaname='public'` + +// DropAllTables deletes all tables in the database. Useful for testing. +func DropAllTables(db *sql.DB) error { + rows, err := db.Query(dropTablesQuery) + if err != nil { + return err + } + defer rows.Close() + + queries := []string{} + for rows.Next() { + var s string + if err := rows.Scan(&s); err != nil { + return err + } + queries = append(queries, s) + } + if err := rows.Err(); err != nil { + return err + } + + if len(queries) > 0 { + log.Printf("DROPPING ALL (%d) TABLES", len(queries)) + } + + for _, query := range queries { + if _, err := db.Exec(query); err != nil { + return err + } + } + + return nil +} diff --git a/pgutil/errors.go b/pgutil/errors.go new file mode 100644 index 0000000..ec26f75 --- /dev/null +++ b/pgutil/errors.go @@ -0,0 +1,31 @@ +package pgutil + +import ( + "errors" + + "github.com/lib/pq" +) + +func ErrIsDuplicateKey(err error) bool { + return ErrHasCode(err, "23505") +} + +func ErrIsForeignKey(err error) bool { + return ErrHasCode(err, "23503") +} + +func ErrIsSerializationFailure(err error) bool { + return ErrHasCode(err, "40001") +} + +func ErrHasCode(err error, code string) bool { + if err == nil { + return false + } + + var pErr *pq.Error + if errors.As(err, &pErr) { + return pErr.Code == pq.ErrorCode(code) + } + return false +} diff --git a/pgutil/errors_test.go b/pgutil/errors_test.go new file mode 100644 index 0000000..581ae94 --- /dev/null +++ b/pgutil/errors_test.go @@ -0,0 +1,36 @@ +package pgutil + +import ( + "database/sql" + "testing" +) + +func TestErrors(t *testing.T) { + db, err := sql.Open( + "postgres", + "host=127.0.0.1 dbname=test sslmode=disable user=test password=test") + if err != nil { + t.Fatal(err) + } + + if err := DropAllTables(db); err != nil { + t.Fatal(err) + } + + if err := Migrate(db, testMigrationFS); err != nil { + t.Fatal(err) + } + + _, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (2, 'q@r.com')`) + if !ErrIsDuplicateKey(err) { + t.Fatal(err) + } + _, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (3, 'c@d.com')`) + if !ErrIsDuplicateKey(err) { + t.Fatal(err) + } + _, err = db.Exec(`INSERT INTO user_notes(UserID, NoteID, Note) VALUES (4, 1, 'hello')`) + if !ErrIsForeignKey(err) { + t.Fatal(err) + } +} diff --git a/pgutil/go.mod b/pgutil/go.mod new file mode 100644 index 0000000..6cab6ef --- /dev/null +++ b/pgutil/go.mod @@ -0,0 +1,5 @@ +module git.crumpington.com/git/pgutil + +go 1.23.2 + +require github.com/lib/pq v1.10.9 diff --git a/pgutil/go.sum b/pgutil/go.sum new file mode 100644 index 0000000..aeddeae --- /dev/null +++ b/pgutil/go.sum @@ -0,0 +1,2 @@ +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= diff --git a/pgutil/migrate.go b/pgutil/migrate.go new file mode 100644 index 0000000..0b26c9e --- /dev/null +++ 
b/pgutil/migrate.go @@ -0,0 +1,82 @@ +package pgutil + +import ( + "database/sql" + "embed" + "errors" + "fmt" + "path" + "sort" +) + +const initMigrationTableQuery = ` +CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);` + +const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)` + +const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)` + +func Migrate(db *sql.DB, migrationFS embed.FS) error { + return WithTx(db, func(tx *sql.Tx) error { + if _, err := tx.Exec(initMigrationTableQuery); err != nil { + return err + } + + dirs, err := migrationFS.ReadDir(".") + if err != nil { + return err + } + if len(dirs) != 1 { + return errors.New("expected a single migrations directory") + } + + if !dirs[0].IsDir() { + return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name()) + } + + dirName := dirs[0].Name() + files, err := migrationFS.ReadDir(dirName) + if err != nil { + return err + } + + // Sort sql files by name. + sort.Slice(files, func(i, j int) bool { + return files[i].Name() < files[j].Name() + }) + + for _, dirEnt := range files { + if !dirEnt.Type().IsRegular() { + return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name()) + } + + var ( + name = dirEnt.Name() + exists bool + ) + + err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists) + if err != nil { + return err + } + + if exists { + continue + } + + migration, err := migrationFS.ReadFile(path.Join(dirName, name)) + if err != nil { + return err + } + + if _, err := tx.Exec(string(migration)); err != nil { + return fmt.Errorf("migration %s failed: %w", name, err) + } + + if _, err := tx.Exec(insertMigrationQuery, name); err != nil { + return err + } + } + return nil + }) +} diff --git a/pgutil/migrate_test.go b/pgutil/migrate_test.go new file mode 100644 index 0000000..5d0136f --- /dev/null +++ b/pgutil/migrate_test.go @@ -0,0 +1,50 @@ +package pgutil + +import ( + "database/sql" + "embed" + "testing" +) + +//go:embed test-migrations +var testMigrationFS embed.FS + +func TestMigrate(t *testing.T) { + db, err := sql.Open( + "postgres", + "host=127.0.0.1 dbname=test sslmode=disable user=test password=test") + if err != nil { + t.Fatal(err) + } + + if err := DropAllTables(db); err != nil { + t.Fatal(err) + } + + if err := Migrate(db, testMigrationFS); err != nil { + t.Fatal(err) + } + + // Shouldn't have any effect.
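+ // Each migration is recorded by filename, so the second run applies nothing.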
+ if err := Migrate(db, testMigrationFS); err != nil { + t.Fatal(err) + } + + query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)` + var exists bool + + if err = db.QueryRow(query, 1).Scan(&exists); err != nil { + t.Fatal(err) + } + if exists { + t.Fatal("1 shouldn't exist") + } + + if err = db.QueryRow(query, 2).Scan(&exists); err != nil { + t.Fatal(err) + } + if !exists { + t.Fatal("2 should exist") + } +} diff --git a/pgutil/test-migrations/000.sql b/pgutil/test-migrations/000.sql new file mode 100644 index 0000000..ecb559c --- /dev/null +++ b/pgutil/test-migrations/000.sql @@ -0,0 +1,9 @@ +CREATE TABLE users( + UserID BIGINT NOT NULL PRIMARY KEY, + Email TEXT NOT NULL UNIQUE); + +CREATE TABLE user_notes( + UserID BIGINT NOT NULL REFERENCES users(UserID), + NoteID BIGINT NOT NULL, + Note TEXT NOT NULL, + PRIMARY KEY(UserID,NoteID)); diff --git a/pgutil/test-migrations/001.sql b/pgutil/test-migrations/001.sql new file mode 100644 index 0000000..e424c57 --- /dev/null +++ b/pgutil/test-migrations/001.sql @@ -0,0 +1 @@ +INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com'); diff --git a/pgutil/test-migrations/002.sql b/pgutil/test-migrations/002.sql new file mode 100644 index 0000000..ba414d2 --- /dev/null +++ b/pgutil/test-migrations/002.sql @@ -0,0 +1 @@ +DELETE FROM users WHERE UserID=1; diff --git a/pgutil/tx.go b/pgutil/tx.go new file mode 100644 index 0000000..9d7156f --- /dev/null +++ b/pgutil/tx.go @@ -0,0 +1,70 @@ +package pgutil + +import ( + "database/sql" + "math/rand" + "time" +) + +// Postgres doesn't use serializable transactions by default. This wrapper +// runs the enclosed function within a serializable transaction. Note: this +// may return a retriable serialization error (see ErrIsSerializationFailure). +func WithTx(db *sql.DB, fn func(*sql.Tx) error) error { + return WithTxDefault(db, func(tx *sql.Tx) error { + if _, err := tx.Exec("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE"); err != nil { + return err + } + return fn(tx) + }) +} + +// This is a convenience function to provide a transaction wrapper with the +// default isolation level. +func WithTxDefault(db *sql.DB, fn func(*sql.Tx) error) error { + // Start a transaction. + tx, err := db.Begin() + if err != nil { + return err + } + + err = fn(tx) + + if err == nil { + err = tx.Commit() + } + + if err != nil { + _ = tx.Rollback() + } + + return err +} + +// SerialTxRunner attempts serializable transactions in a loop. If a +// transaction fails due to a serialization error, then the runner will retry +// with exponential backoff, until the sleep time reaches MaxTimeout. +// +// For example, if MinTimeout is 100 ms, and MaxTimeout is 800 ms, it may sleep +// for ~100, 200, 400, and 800 ms between retries. +// +// Up to 10% jitter is added to the sleep time.
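+// +// A minimal usage sketch (assuming an open *sql.DB named db): +// +// runner := SerialTxRunner{MinTimeout: 100 * time.Millisecond, MaxTimeout: 800 * time.Millisecond} +// err := runner.WithTx(db, func(tx *sql.Tx) error { +// _, err := tx.Exec(`UPDATE t SET balance=balance+100 WHERE acct=1`) +// return err +// })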
+type SerialTxRunner struct { + MinTimeout time.Duration + MaxTimeout time.Duration +} + +func (r SerialTxRunner) WithTx(db *sql.DB, fn func(*sql.Tx) error) error { + timeout := r.MinTimeout + for { + err := WithTx(db, fn) + if err == nil { + return nil + } + if timeout > r.MaxTimeout || !ErrIsSerializationFailure(err) { + return err + } + sleepTimeout := timeout + time.Duration(rand.Int63n(int64(timeout/10))) + time.Sleep(sleepTimeout) + timeout *= 2 + } +} diff --git a/pgutil/tx_test.go b/pgutil/tx_test.go new file mode 100644 index 0000000..5b898c6 --- /dev/null +++ b/pgutil/tx_test.go @@ -0,0 +1,120 @@ +package pgutil + +import ( + "database/sql" + "fmt" + "sync" + "testing" + "time" +) + +// TestWithTx verifies transaction retry using the classic +// example of write skew in bank account balance transfers. +func TestWithTx(t *testing.T) { + db, err := sql.Open( + "postgres", + "host=127.0.0.1 dbname=test sslmode=disable user=test password=test") + if err != nil { + t.Fatal(err) + } + + if err := DropAllTables(db); err != nil { + t.Fatal(err) + } + + defer db.Close() + + initStmt := ` +CREATE TABLE t (acct INT PRIMARY KEY, balance INT); +INSERT INTO t (acct, balance) VALUES (1, 100), (2, 100); +` + if _, err := db.Exec(initStmt); err != nil { + t.Fatal(err) + } + + type queryI interface { + Query(string, ...interface{}) (*sql.Rows, error) + } + + getBalances := func(q queryI) (bal1, bal2 int, err error) { + var rows *sql.Rows + rows, err = q.Query(`SELECT balance FROM t WHERE acct IN (1, 2);`) + if err != nil { + return + } + defer rows.Close() + balances := []*int{&bal1, &bal2} + i := 0 + for ; rows.Next(); i++ { + if err = rows.Scan(balances[i]); err != nil { + return + } + } + if i != 2 { + err = fmt.Errorf("expected two balances; got %d", i) + return + } + return + } + + txRunner := SerialTxRunner{100 * time.Millisecond, 800 * time.Millisecond} + + runTxn := func(wg *sync.WaitGroup, iter *int) <-chan error { + errCh := make(chan error, 1) + go func() { + *iter = 0 + errCh <- txRunner.WithTx(db, func(tx *sql.Tx) error { + *iter++ + bal1, bal2, err := getBalances(tx) + if err != nil { + return err + } + // If this is the first iteration, wait for the other tx to + // also read. + if *iter == 1 { + wg.Done() + wg.Wait() + } + // Now, subtract from one account and give to the other.
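+ // Under SERIALIZABLE, the overlapping reads above force at least one of + // the competing transactions to abort with a serialization error and be + // retried by SerialTxRunner.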
+ if bal1 > bal2 { + if _, err := tx.Exec(` +UPDATE t SET balance=balance-100 WHERE acct=1; +UPDATE t SET balance=balance+100 WHERE acct=2; +`); err != nil { + return err + } + } else { + if _, err := tx.Exec(` +UPDATE t SET balance=balance+100 WHERE acct=1; +UPDATE t SET balance=balance-100 WHERE acct=2; +`); err != nil { + return err + } + } + return nil + }) + }() + return errCh + } + + var wg sync.WaitGroup + wg.Add(2) + var iters1, iters2 int + txn1Err := runTxn(&wg, &iters1) + txn2Err := runTxn(&wg, &iters2) + if err := <-txn1Err; err != nil { + t.Errorf("expected success in txn1; got %s", err) + } + if err := <-txn2Err; err != nil { + t.Errorf("expected success in txn2; got %s", err) + } + if iters1+iters2 <= 2 { + t.Errorf("expected retries between the competing transactions; "+ + "got txn1=%d, txn2=%d", iters1, iters2) + } + bal1, bal2, err := getBalances(db) + if err != nil || bal1 != 100 || bal2 != 100 { + t.Errorf("expected balances to be restored without error; "+ + "got acct1=%d, acct2=%d: %s", bal1, bal2, err) + } +} diff --git a/ratelimiter/go.mod b/ratelimiter/go.mod new file mode 100644 index 0000000..b6a02db --- /dev/null +++ b/ratelimiter/go.mod @@ -0,0 +1,3 @@ +module git.crumpington.com/lib/ratelimiter + +go 1.23.2 diff --git a/ratelimiter/ratelimiter.go b/ratelimiter/ratelimiter.go new file mode 100644 index 0000000..f17b473 --- /dev/null +++ b/ratelimiter/ratelimiter.go @@ -0,0 +1,86 @@ +package ratelimiter + +import ( + "errors" + "sync" + "time" +) + +var ErrBackoff = errors.New("Backoff") + +type Config struct { + BurstLimit int64 // Number of requests to allow to burst. + FillPeriod time.Duration // Add one per period. + MaxWaitCount int64 // Max number of waiting requests. 0 disables. +} + +type Limiter struct { + lock sync.Mutex + + fillPeriod time.Duration + minWaitTime time.Duration + maxWaitTime time.Duration + + waitTime time.Duration // If waitTime < 0, no waiting occurs. + lastRequest time.Time +} + +func New(conf Config) *Limiter { + if conf.BurstLimit < 0 { + panic(conf.BurstLimit) + } + if conf.FillPeriod <= 0 { + panic(conf.FillPeriod) + } + if conf.MaxWaitCount < 0 { + panic(conf.MaxWaitCount) + } + + lim := &Limiter{ + lastRequest: time.Now(), + fillPeriod: conf.FillPeriod, + waitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit), + minWaitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit), + maxWaitTime: conf.FillPeriod * time.Duration(conf.MaxWaitCount), + } + + lim.waitTime = lim.minWaitTime + + return lim +} + +func (lim *Limiter) limit(count int64) (time.Duration, error) { + lim.lock.Lock() + defer lim.lock.Unlock() + + dt := time.Since(lim.lastRequest) + + waitTime := lim.waitTime - dt + time.Duration(count)*lim.fillPeriod + + if waitTime < lim.minWaitTime { + waitTime = lim.minWaitTime + } else if waitTime > lim.maxWaitTime { + return 0, ErrBackoff + } + + lim.waitTime = waitTime + lim.lastRequest = lim.lastRequest.Add(dt) + + return lim.waitTime, nil +} + +// Apply the limiter to the calling thread. The function may sleep for up to +// maxWaitTime before returning. If the timeout would need to be more than +// maxWaitTime to enforce the rate limit, ErrBackoff is returned. +func (lim *Limiter) Limit() error { + dt, err := lim.limit(1) + time.Sleep(dt) // Will return immediately for dt <= 0. + return err +} + +// Apply the limiter for multiple items at once. 
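+// A call with count n is equivalent to n consecutive calls to Limit, except +// that it sleeps at most once.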
+func (lim *Limiter) LimitMultiple(count int64) error { + dt, err := lim.limit(count) + time.Sleep(dt) + return err +} diff --git a/ratelimiter/ratelimiter_test.go b/ratelimiter/ratelimiter_test.go new file mode 100644 index 0000000..d3f0ef6 --- /dev/null +++ b/ratelimiter/ratelimiter_test.go @@ -0,0 +1,101 @@ +package ratelimiter + +import ( + "sync" + "testing" + "time" +) + +func TestRateLimiter_Limit_Errors(t *testing.T) { + type TestCase struct { + Name string + Conf Config + N int + ErrCount int + DT time.Duration + } + + cases := []TestCase{ + { + Name: "no burst, no wait", + Conf: Config{ + BurstLimit: 0, + FillPeriod: 100 * time.Millisecond, + MaxWaitCount: 0, + }, + N: 32, + ErrCount: 32, + DT: 0, + }, { + Name: "no wait", + Conf: Config{ + BurstLimit: 10, + FillPeriod: 100 * time.Millisecond, + MaxWaitCount: 0, + }, + N: 32, + ErrCount: 22, + DT: 0, + }, { + Name: "no burst", + Conf: Config{ + BurstLimit: 0, + FillPeriod: 10 * time.Millisecond, + MaxWaitCount: 10, + }, + N: 32, + ErrCount: 22, + DT: 100 * time.Millisecond, + }, { + Name: "burst and wait", + Conf: Config{ + BurstLimit: 10, + FillPeriod: 10 * time.Millisecond, + MaxWaitCount: 10, + }, + N: 32, + ErrCount: 12, + DT: 100 * time.Millisecond, + }, + } + + for _, tc := range cases { + wg := sync.WaitGroup{} + l := New(tc.Conf) + errs := make([]error, tc.N) + + t0 := time.Now() + + for i := 0; i < tc.N; i++ { + wg.Add(1) + go func(i int) { + errs[i] = l.Limit() + wg.Done() + }(i) + } + + wg.Wait() + + dt := time.Since(t0) + + errCount := 0 + for _, err := range errs { + if err != nil { + errCount++ + } + } + + if errCount != tc.ErrCount { + t.Fatalf("%s: Expected %d errors but got %d.", + tc.Name, tc.ErrCount, errCount) + } + + if dt < tc.DT { + t.Fatal(tc.Name, dt, tc.DT) + } + + if dt > tc.DT+10*time.Millisecond { + t.Fatal(tc.Name, dt, tc.DT) + } + } +} diff --git a/sqlgen/README.md b/sqlgen/README.md new file mode 100644 index 0000000..27e99cf --- /dev/null +++ b/sqlgen/README.md @@ -0,0 +1,22 @@ +# sqlgen + +## Installing + +``` +go install git.crumpington.com/lib/sqlgen/cmd/sqlgen@latest +``` + +## Usage + +``` +sqlgen [driver] [defs-path] [output-path] +``` + +## File Format + +``` +TABLE [sql-name] OF [go-type] ( + [sql-column] [go-type] , + ... 
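+ [sql-column] [go-type] [AS [go-field]] [PK] [NoInsert] [NoUpdate] ,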
+); +``` diff --git a/sqlgen/cmd/sqlgen/main.go b/sqlgen/cmd/sqlgen/main.go new file mode 100644 index 0000000..958cb62 --- /dev/null +++ b/sqlgen/cmd/sqlgen/main.go @@ -0,0 +1,7 @@ +package main + +import "git.crumpington.com/lib/sqlgen" + +func main() { + sqlgen.Main() +} diff --git a/sqlgen/go.mod b/sqlgen/go.mod new file mode 100644 index 0000000..d8275a5 --- /dev/null +++ b/sqlgen/go.mod @@ -0,0 +1,3 @@ +module git.crumpington.com/lib/sqlgen + +go 1.23.2 diff --git a/sqlgen/main.go b/sqlgen/main.go new file mode 100644 index 0000000..170629b --- /dev/null +++ b/sqlgen/main.go @@ -0,0 +1,43 @@ +package sqlgen + +import ( + "fmt" + "os" +) + +func Main() { + usage := func() { + fmt.Fprintf(os.Stderr, ` +%s DRIVER DEFS_PATH OUTPUT_PATH + +Drivers are one of: sqlite, postgres +`, + os.Args[0]) + os.Exit(1) + } + + if len(os.Args) != 4 { + usage() + } + + var ( + template string + driver = os.Args[1] + defsPath = os.Args[2] + outputPath = os.Args[3] + ) + + switch driver { + case "sqlite": + template = sqliteTemplate + default: + fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver) + usage() + } + + err := render(template, defsPath, outputPath) + if err != nil { + fmt.Fprintf(os.Stderr, "Error: %v", err) + os.Exit(1) + } +} diff --git a/sqlgen/parse.go b/sqlgen/parse.go new file mode 100644 index 0000000..b1e9c64 --- /dev/null +++ b/sqlgen/parse.go @@ -0,0 +1,143 @@ +package sqlgen + +import ( + "errors" + "os" + "strings" +) + +func parsePath(filePath string) (*schema, error) { + fileBytes, err := os.ReadFile(filePath) + if err != nil { + return nil, err + } + return parseBytes(fileBytes) +} + +func parseBytes(fileBytes []byte) (*schema, error) { + s := string(fileBytes) + for _, c := range []string{",", "(", ")", ";"} { + s = strings.ReplaceAll(s, c, " "+c+" ") + } + + var ( + tokens = strings.Fields(s) + schema = &schema{} + err error + ) + + for len(tokens) > 0 { + switch tokens[0] { + case "TABLE": + tokens, err = parseTable(schema, tokens) + if err != nil { + return nil, err + } + + default: + return nil, errors.New("invalid token: " + tokens[0]) + } + } + + return schema, nil +} + +func parseTable(schema *schema, tokens []string) ([]string, error) { + tokens = tokens[1:] + if len(tokens) < 3 { + return tokens, errors.New("incomplete table definition") + } + if tokens[1] != "OF" { + return tokens, errors.New("expected OF in table definition") + } + + table := &table{ + Name: tokens[0], + Type: tokens[2], + } + schema.Tables = append(schema.Tables, table) + + tokens = tokens[3:] + + if len(tokens) == 0 { + return tokens, errors.New("missing table definition body") + } + + for len(tokens) > 0 { + switch tokens[0] { + case "NoInsert": + table.NoInsert = true + tokens = tokens[1:] + case "NoUpdate": + table.NoUpdate = true + tokens = tokens[1:] + case "NoDelete": + table.NoDelete = true + tokens = tokens[1:] + case "(": + return parseTableBody(table, tokens[1:]) + default: + return tokens, errors.New("unexpected token in table definition: " + tokens[0]) + } + } + + return tokens, errors.New("incomplete table definition") +} + +func parseTableBody(table *table, tokens []string) ([]string, error) { + var err error + + for len(tokens) > 0 && tokens[0] != ";" { + tokens, err = parseTableColumn(table, tokens) + if err != nil { + return tokens, err + } + } + + if len(tokens) < 1 || tokens[0] != ";" { + return tokens, errors.New("incomplete table column definitions") + } + + return tokens[1:], nil +} + +func parseTableColumn(table *table, tokens []string) ([]string, error) { + if len(tokens) < 
2 { + return tokens, errors.New("incomplete column definition") + } + column := &column{ + Name: tokens[0], + Type: tokens[1], + SqlName: tokens[0], + } + table.Columns = append(table.Columns, column) + + tokens = tokens[2:] + for len(tokens) > 0 && tokens[0] != "," && tokens[0] != ")" { + switch tokens[0] { + case "AS": + if len(tokens) < 2 { + return tokens, errors.New("incomplete AS clause in column definition") + } + column.Name = tokens[1] + tokens = tokens[2:] + case "PK": + column.PK = true + tokens = tokens[1:] + case "NoInsert": + column.NoInsert = true + tokens = tokens[1:] + case "NoUpdate": + column.NoUpdate = true + tokens = tokens[1:] + default: + return tokens, errors.New("unexpected token in column definition: " + tokens[0]) + } + } + + if len(tokens) == 0 { + return tokens, errors.New("incomplete column definition") + } + + return tokens[1:], nil +} diff --git a/sqlgen/parse_test.go b/sqlgen/parse_test.go new file mode 100644 index 0000000..9838b69 --- /dev/null +++ b/sqlgen/parse_test.go @@ -0,0 +1,45 @@ +package sqlgen + +import ( + "encoding/json" + "os" + "path/filepath" + "reflect" + "strings" + "testing" +) + +func TestParse(t *testing.T) { + toString := func(v any) string { + txt, _ := json.MarshalIndent(v, "", " ") + return string(txt) + } + + paths, err := filepath.Glob("test-files/TestParse/*.def") + if err != nil { + t.Fatal(err) + } + + for _, defPath := range paths { + t.Run(filepath.Base(defPath), func(t *testing.T) { + parsed, err := parsePath(defPath) + if err != nil { + t.Fatal(err) + } + + b, err := os.ReadFile(strings.TrimSuffix(defPath, "def") + "json") + if err != nil { + t.Fatal(err) + } + + expected := &schema{} + if err := json.Unmarshal(b, expected); err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(parsed, expected) { + t.Fatalf("%s != %s", toString(parsed), toString(expected)) + } + }) + } +} diff --git a/sqlgen/schema.go b/sqlgen/schema.go new file mode 100644 index 0000000..a75dc7b --- /dev/null +++ b/sqlgen/schema.go @@ -0,0 +1,263 @@ +package sqlgen + +import ( + "fmt" + "strings" +) + +type schema struct { + Tables []*table +} + +type table struct { + Name string // Name in SQL + Type string // Go type + NoInsert bool + NoUpdate bool + NoDelete bool + Columns []*column +} + +type column struct { + Name string + Type string + SqlName string // Defaults to Name + PK bool // PK won't be updated + NoInsert bool + NoUpdate bool // Don't update column in update function +} + +// ---------------------------------------------------------------------------- + +func (t *table) colSQLNames() []string { + names := make([]string, len(t.Columns)) + for i := range names { + names[i] = t.Columns[i].SqlName + } + return names +} + +func (t *table) SelectQuery() string { + return fmt.Sprintf(`SELECT %s FROM %s`, + strings.Join(t.colSQLNames(), ","), + t.Name) +} + +func (t *table) insertCols() (cols []*column) { + for _, c := range t.Columns { + if !c.NoInsert { + cols = append(cols, c) + } + } + return cols +} + +func (t *table) InsertQuery() string { + cols := t.insertCols() + b := &strings.Builder{} + b.WriteString(`INSERT INTO `) + b.WriteString(t.Name) + b.WriteString(`(`) + for i, c := range cols { + if i != 0 { + b.WriteString(`,`) + } + b.WriteString(c.SqlName) + } + b.WriteString(`) VALUES(`) + for i := range cols { + if i != 0 { + b.WriteString(`,`) + } + b.WriteString(`?`) + } + b.WriteString(`)`) + return b.String() +} + +func (t *table) InsertArgs() string { + args := []string{} + for i, col := range t.Columns { + if !col.NoInsert { + 
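+ // Only non-NoInsert columns are appended, matching the placeholder + // order generated by InsertQuery via insertCols.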
args = append(args, "row."+t.Columns[i].Name) + } + } + return strings.Join(args, ", ") +} + +func (t *table) UpdateCols() (cols []*column) { + for _, c := range t.Columns { + if !(c.PK || c.NoUpdate) { + cols = append(cols, c) + } + } + return cols +} + +func (t *table) UpdateQuery() string { + cols := t.UpdateCols() + + b := &strings.Builder{} + b.WriteString(`UPDATE `) + b.WriteString(t.Name + ` SET `) + for i, col := range cols { + if i != 0 { + b.WriteByte(',') + } + b.WriteString(col.SqlName + `=?`) + } + + b.WriteString(` WHERE`) + + for i, c := range t.pkCols() { + if i != 0 { + b.WriteString(` AND`) + } + b.WriteString(` ` + c.SqlName + `=?`) + } + + return b.String() +} + +func (t *table) UpdateArgs() string { + cols := t.UpdateCols() + + b := &strings.Builder{} + for i, col := range cols { + if i != 0 { + b.WriteString(`, `) + } + b.WriteString("row." + col.Name) + } + for _, col := range t.pkCols() { + b.WriteString(", row." + col.Name) + } + return b.String() +} + +func (t *table) UpdateFullCols() (cols []*column) { + for _, c := range t.Columns { + if !c.PK { + cols = append(cols, c) + } + } + return cols +} + +func (t *table) UpdateFullQuery() string { + cols := t.UpdateFullCols() + + b := &strings.Builder{} + b.WriteString(`UPDATE `) + b.WriteString(t.Name + ` SET `) + for i, col := range cols { + if i != 0 { + b.WriteByte(',') + } + b.WriteString(col.SqlName + `=?`) + } + + b.WriteString(` WHERE`) + + for i, c := range t.pkCols() { + if i != 0 { + b.WriteString(` AND`) + } + b.WriteString(` ` + c.SqlName + `=?`) + } + + return b.String() +} + +func (t *table) UpdateFullArgs() string { + cols := t.UpdateFullCols() + + b := &strings.Builder{} + for i, col := range cols { + if i != 0 { + b.WriteString(`, `) + } + b.WriteString("row." + col.Name) + } + + for _, col := range t.pkCols() { + b.WriteString(", row." 
+ col.Name) + } + + return b.String() +} + +func (t *table) pkCols() (cols []*column) { + for _, c := range t.Columns { + if c.PK { + cols = append(cols, c) + } + } + return cols +} + +func (t *table) PKFunctionArgs() string { + b := &strings.Builder{} + for _, col := range t.pkCols() { + b.WriteString(col.Name) + b.WriteString(` `) + b.WriteString(col.Type) + b.WriteString(",\n") + } + return b.String() +} + +func (t *table) DeleteQuery() string { + cols := t.pkCols() + + b := &strings.Builder{} + b.WriteString(`DELETE FROM `) + b.WriteString(t.Name) + b.WriteString(` WHERE `) + + for i, col := range cols { + if i != 0 { + b.WriteString(` AND `) + } + b.WriteString(col.SqlName) + b.WriteString(`=?`) + } + return b.String() +} + +func (t *table) DeleteArgs() string { + cols := t.pkCols() + b := &strings.Builder{} + + for i, col := range cols { + if i != 0 { + b.WriteString(`,`) + } + b.WriteString(col.Name) + } + return b.String() +} + +func (t *table) GetQuery() string { + b := &strings.Builder{} + b.WriteString(t.SelectQuery()) + b.WriteString(` WHERE `) + for i, col := range t.pkCols() { + if i != 0 { + b.WriteString(` AND `) + } + b.WriteString(col.SqlName + `=?`) + } + return b.String() +} + +func (t *table) ScanArgs() string { + b := &strings.Builder{} + for i, col := range t.Columns { + if i != 0 { + b.WriteString(`, `) + } + b.WriteString(`&row.` + col.Name) + } + return b.String() +} diff --git a/sqlgen/sqlite.go.tmpl b/sqlgen/sqlite.go.tmpl new file mode 100644 index 0000000..d7db01c --- /dev/null +++ b/sqlgen/sqlite.go.tmpl @@ -0,0 +1,206 @@ +package {{.PackageName}} + +import ( + "database/sql" + "iter" +) + +type TX interface { + Exec(query string, args ...any) (sql.Result, error) + Query(query string, args ...any) (*sql.Rows, error) + QueryRow(query string, args ...any) *sql.Row +} + +{{range .Schema.Tables}} + +// ---------------------------------------------------------------------------- +// Table: {{.Name}} +// ---------------------------------------------------------------------------- + +type {{.Type}} struct { + {{- range .Columns}} + {{.Name}} {{.Type}}{{end}} +} + +const {{.Type}}_SelectQuery = "{{.SelectQuery}}" + +{{if not .NoInsert -}} + +func {{.Type}}_Insert( + tx TX, + row *{{.Type}}, +) (err error) { + {{.Type}}_Sanitize(row) + if err = {{.Type}}_Validate(row); err != nil { + return err + } + + _, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}}) + return err +} + +{{- end}} {{/* if not .NoInsert */}} + +{{if not .NoUpdate -}} + +{{if .UpdateCols -}} + +func {{.Type}}_Update( + tx TX, + row *{{.Type}}, +) (found bool, err error) { + {{.Type}}_Sanitize(row) + if err = {{.Type}}_Validate(row); err != nil { + return false, err + } + + result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}}) + if err != nil { + return false, err + } + + n, err := result.RowsAffected() + if err != nil { + panic(err) + } + + if n > 1 { + panic("multiple rows updated") + } + + return n != 0, nil +} +{{- end}} + +{{if .UpdateFullCols -}} + +func {{.Type}}_UpdateFull( + tx TX, + row *{{.Type}}, +) (found bool, err error) { + {{.Type}}_Sanitize(row) + if err = {{.Type}}_Validate(row); err != nil { + return false, err + } + + result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}}) + if err != nil { + return false, err + } + + n, err := result.RowsAffected() + if err != nil { + panic(err) + } + + if n > 1 { + panic("multiple rows updated") + } + + return n != 0, nil +} +{{- end}} + + +{{- end}} {{/* if not .NoUpdate */}} + +{{if not .NoDelete -}} + +func {{.Type}}_Delete( 
+ tx TX, + {{.PKFunctionArgs -}} +) (found bool, err error) { + result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}}) + if err != nil { + return false, err + } + + n, err := result.RowsAffected() + if err != nil { + panic(err) + } + + if n > 1 { + panic("multiple rows deleted") + } + + return n != 0, nil +} + +{{- end}} + +func {{.Type}}_Get( + tx TX, + {{.PKFunctionArgs -}} +) ( + row *{{.Type}}, + err error, +) { + row = &{{.Type}}{} + r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}}) + err = r.Scan({{.ScanArgs}}) + return +} + + +func {{.Type}}_GetWhere( + tx TX, + query string, + args ...any, +) ( + row *{{.Type}}, + err error, +) { + row = &{{.Type}}{} + r := tx.QueryRow(query, args...) + err = r.Scan({{.ScanArgs}}) + return +} + + + +func {{.Type}}_Iterate( + tx TX, + query string, + args ...any, +) ( + iter.Seq2[*{{.Type}}, error], +) { + rows, err := tx.Query(query, args...) + if err != nil { + return func(yield func(*{{.Type}}, error) bool) { + yield(nil, err) + } + } + + return func(yield func(*{{.Type}}, error) bool) { + defer rows.Close() + for rows.Next() { + row := &{{.Type}}{} + err := rows.Scan({{.ScanArgs}}) + if !yield(row, err) { + return + } + } + } +} + +func {{.Type}}_List( + tx TX, + query string, + args ...any, +) ( + l []*{{.Type}}, + err error, +) { + for row, err := range {{.Type}}_Iterate(tx, query, args...) { + if err != nil { + return nil, err + } + l = append(l, row) + } + return l, nil +} + + +{{end}} {{/* range .Schema.Tables */}} diff --git a/sqlgen/template.go b/sqlgen/template.go new file mode 100644 index 0000000..5372c84 --- /dev/null +++ b/sqlgen/template.go @@ -0,0 +1,36 @@ +package sqlgen + +import ( + _ "embed" + "os" + "os/exec" + "path/filepath" + "text/template" +) + +//go:embed sqlite.go.tmpl +var sqliteTemplate string + +func render(templateStr, schemaPath, outputPath string) error { + sch, err := parsePath(schemaPath) + if err != nil { + return err + } + + tmpl := template.Must(template.New("").Parse(templateStr)) + fOut, err := os.Create(outputPath) + if err != nil { + return err + } + defer fOut.Close() + + err = tmpl.Execute(fOut, struct { + PackageName string + Schema *schema + }{filepath.Base(filepath.Dir(outputPath)), sch}) + if err != nil { + return err + } + + return exec.Command("gofmt", "-w", outputPath).Run() +} diff --git a/sqlgen/test-files/TestParse/000.def b/sqlgen/test-files/TestParse/000.def new file mode 100644 index 0000000..ca57f1f --- /dev/null +++ b/sqlgen/test-files/TestParse/000.def @@ -0,0 +1,3 @@ +TABLE users OF User NoDelete ( + user_id string AS UserID PK +); \ No newline at end of file diff --git a/sqlgen/test-files/TestParse/000.json b/sqlgen/test-files/TestParse/000.json new file mode 100644 index 0000000..596727c --- /dev/null +++ b/sqlgen/test-files/TestParse/000.json @@ -0,0 +1,17 @@ +{ + "Tables": [ + { + "Name": "users", + "Type": "User", + "NoDelete": true, + "Columns": [ + { + "Name": "UserID", + "Type": "string", + "SqlName": "user_id", + "PK": true + } + ] + } + ] +} diff --git a/sqlgen/test-files/TestParse/001.def b/sqlgen/test-files/TestParse/001.def new file mode 100644 index 0000000..1ddc7b8 --- /dev/null +++ b/sqlgen/test-files/TestParse/001.def @@ -0,0 +1,4 @@ +TABLE users OF User NoDelete ( + user_id string AS UserID PK, + email string AS Email NoUpdate +); \ No newline at end of file diff --git a/sqlgen/test-files/TestParse/001.json b/sqlgen/test-files/TestParse/001.json new file mode 100644 index 0000000..ded27f4 --- /dev/null +++ b/sqlgen/test-files/TestParse/001.json @@ -0,0 +1,22 @@ 
+{ + "Tables": [ + { + "Name": "users", + "Type": "User", + "NoDelete": true, + "Columns": [ + { + "Name": "UserID", + "Type": "string", + "SqlName": "user_id", + "PK": true + }, { + "Name": "Email", + "Type": "string", + "SqlName": "email", + "NoUpdate": true + } + ] + } + ] +} diff --git a/sqlgen/test-files/TestParse/002.def b/sqlgen/test-files/TestParse/002.def new file mode 100644 index 0000000..019e3b5 --- /dev/null +++ b/sqlgen/test-files/TestParse/002.def @@ -0,0 +1,6 @@ +TABLE users OF User NoDelete ( + user_id string AS UserID PK, + email string AS Email NoUpdate, + name string AS Name NoInsert, + admin bool AS Admin NoInsert NoUpdate +); diff --git a/sqlgen/test-files/TestParse/002.json b/sqlgen/test-files/TestParse/002.json new file mode 100644 index 0000000..2693687 --- /dev/null +++ b/sqlgen/test-files/TestParse/002.json @@ -0,0 +1,33 @@ +{ + "Tables": [ + { + "Name": "users", + "Type": "User", + "NoDelete": true, + "Columns": [ + { + "Name": "UserID", + "Type": "string", + "SqlName": "user_id", + "PK": true + }, { + "Name": "Email", + "Type": "string", + "SqlName": "email", + "NoUpdate": true + }, { + "Name": "Name", + "Type": "string", + "SqlName": "name", + "NoInsert": true + }, { + "Name": "Admin", + "Type": "bool", + "SqlName": "admin", + "NoInsert": true, + "NoUpdate": true + } + ] + } + ] +} diff --git a/sqlgen/test-files/TestParse/003.def b/sqlgen/test-files/TestParse/003.def new file mode 100644 index 0000000..f7fa25c --- /dev/null +++ b/sqlgen/test-files/TestParse/003.def @@ -0,0 +1,12 @@ +TABLE users OF User NoDelete ( + user_id string AS UserID PK, + email string AS Email NoUpdate, + name string AS Name NoInsert, + admin bool AS Admin NoInsert NoUpdate +); + +TABLE users_view OF UserView NoInsert NoUpdate NoDelete ( + user_id string AS UserID PK, + email string AS Email, + name string AS Name +); \ No newline at end of file diff --git a/sqlgen/test-files/TestParse/003.json b/sqlgen/test-files/TestParse/003.json new file mode 100644 index 0000000..110d1fb --- /dev/null +++ b/sqlgen/test-files/TestParse/003.json @@ -0,0 +1,61 @@ +{ + "Tables": [ + { + "Name": "users", + "Type": "User", + "NoDelete": true, + "Columns": [ + { + "Name": "UserID", + "Type": "string", + "SqlName": "user_id", + "PK": true + }, + { + "Name": "Email", + "Type": "string", + "SqlName": "email", + "NoUpdate": true + }, + { + "Name": "Name", + "Type": "string", + "SqlName": "name", + "NoInsert": true + }, + { + "Name": "Admin", + "Type": "bool", + "SqlName": "admin", + "NoInsert": true, + "NoUpdate": true + } + ] + }, + { + "Name": "users_view", + "Type": "UserView", + "NoInsert": true, + "NoUpdate": true, + "NoDelete": true, + "Columns": [ + { + "Name": "UserID", + "Type": "string", + "SqlName": "user_id", + "PK": true + }, + { + "Name": "Email", + "Type": "string", + "SqlName": "email" + }, + { + "Name": "Name", + "Type": "string", + "SqlName": "name" + } + ] + } + ] +} diff --git a/sqlgen/test-files/TestParse/004.def b/sqlgen/test-files/TestParse/004.def new file mode 100644 index 0000000..9618d26 --- /dev/null +++ b/sqlgen/test-files/TestParse/004.def @@ -0,0 +1,13 @@ +TABLE users OF User NoDelete ( + user_id string AS UserID PK, + email string AS Email NoUpdate, + name string AS Name NoInsert, + admin bool AS Admin NoInsert NoUpdate, + SSN string NoUpdate +); + +TABLE users_view OF UserView NoInsert NoUpdate NoDelete ( + user_id string AS UserID PK, + email string AS Email, + name string AS Name +); diff --git a/sqlgen/test-files/TestParse/004.json b/sqlgen/test-files/TestParse/004.json new 
file mode 100644 index 0000000..088f931 --- /dev/null +++ b/sqlgen/test-files/TestParse/004.json @@ -0,0 +1,66 @@
+{
+ "Tables": [
+ {
+ "Name": "users",
+ "Type": "User",
+ "NoDelete": true,
+ "Columns": [
+ {
+ "Name": "UserID",
+ "Type": "string",
+ "SqlName": "user_id",
+ "PK": true
+ },
+ {
+ "Name": "Email",
+ "Type": "string",
+ "SqlName": "email",
+ "NoUpdate": true
+ },
+ {
+ "Name": "Name",
+ "Type": "string",
+ "SqlName": "name",
+ "NoInsert": true
+ },
+ {
+ "Name": "Admin",
+ "Type": "bool",
+ "SqlName": "admin",
+ "NoInsert": true,
+ "NoUpdate": true
+ }, {
+ "Name": "SSN",
+ "Type": "string",
+ "SqlName": "SSN",
+ "NoUpdate": true
+ }
+ ]
+ },
+ {
+ "Name": "users_view",
+ "Type": "UserView",
+ "NoInsert": true,
+ "NoUpdate": true,
+ "NoDelete": true,
+ "Columns": [
+ {
+ "Name": "UserID",
+ "Type": "string",
+ "SqlName": "user_id",
+ "PK": true
+ },
+ {
+ "Name": "Email",
+ "Type": "string",
+ "SqlName": "email"
+ },
+ {
+ "Name": "Name",
+ "Type": "string",
+ "SqlName": "name"
+ }
+ ]
+ }
+ ]
+} diff --git a/sqliteutil/README.md b/sqliteutil/README.md new file mode 100644 index 0000000..eb59767 --- /dev/null +++ b/sqliteutil/README.md @@ -0,0 +1,30 @@
+# sqliteutil
+
+## Transactions
+
+Run a function inside a SQLite transaction using `WithTx`: the transaction is
+committed if the function returns nil, and rolled back otherwise.
+
+## Migrations
+
+Put your migrations into a directory, for example `migrations`, ordered by name
+(YYYY-MM-DD prefix, for example). Embed the directory and pass it to the
+`Migrate` function:
+
+```Go
+//go:embed migrations
+var migrations embed.FS
+
+func init() {
+	Migrate(db, migrations) // Check the error, of course.
+}
+```
+
+## Testing
+
+The tests run against an in-memory SQLite database, so no external setup is
+required:
+
+```
+go test ./...
+```
 diff --git a/sqliteutil/go.mod b/sqliteutil/go.mod new file mode 100644 index 0000000..91f0673 --- /dev/null +++ b/sqliteutil/go.mod @@ -0,0 +1,5 @@
+module git.crumpington.com/lib/sqliteutil
+
+go 1.23.2
+
+require github.com/mattn/go-sqlite3 v1.14.24
 diff --git a/sqliteutil/go.sum b/sqliteutil/go.sum new file mode 100644 index 0000000..9dcdc9b --- /dev/null +++ b/sqliteutil/go.sum @@ -0,0 +1,2 @@
+github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
+github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
 diff --git a/sqliteutil/migrate.go b/sqliteutil/migrate.go new file mode 100644 index 0000000..ad0701f --- /dev/null +++ b/sqliteutil/migrate.go @@ -0,0 +1,82 @@
+package sqliteutil
+
+import (
+ "database/sql"
+ "embed"
+ "errors"
+ "fmt"
+ "path"
+ "sort"
+)
+
+const initMigrationTableQuery = `
+CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);`
+
+const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)`
+
+const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)`
+
+func Migrate(db *sql.DB, migrationFS embed.FS) error {
+ return WithTx(db, func(tx *sql.Tx) error {
+ if _, err := tx.Exec(initMigrationTableQuery); err != nil {
+ return err
+ }
+
+ dirs, err := migrationFS.ReadDir(".")
+
if err != nil {
+ return err
+ }
+ if len(dirs) != 1 {
+ return errors.New("expected a single migrations directory")
+ }
+
+ if !dirs[0].IsDir() {
+ return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name())
+ }
+
+ dirName := dirs[0].Name()
+ files, err := migrationFS.ReadDir(dirName)
+ if err != nil {
+ return err
+ }
+
+ // Sort sql files by name.
+ sort.Slice(files, func(i, j int) bool {
+ return files[i].Name() < files[j].Name()
+ })
+
+ for _, dirEnt := range files {
+ if !dirEnt.Type().IsRegular() {
+ return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name())
+ }
+
+ var (
+ name = dirEnt.Name()
+ exists bool
+ )
+
+ err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists)
+ if err != nil {
+ return err
+ }
+
+ if exists {
+ continue
+ }
+
+ migration, err := migrationFS.ReadFile(path.Join(dirName, name))
+ if err != nil {
+ return err
+ }
+
+ if _, err := tx.Exec(string(migration)); err != nil {
+ return fmt.Errorf("migration %s failed: %v", name, err)
+ }
+
+ if _, err := tx.Exec(insertMigrationQuery, name); err != nil {
+ return err
+ }
+ }
+ return nil
+ })
+} diff --git a/sqliteutil/migrate_test.go b/sqliteutil/migrate_test.go new file mode 100644 index 0000000..f78d4af --- /dev/null +++ b/sqliteutil/migrate_test.go @@ -0,0 +1,44 @@
+package sqliteutil
+
+import (
+ "database/sql"
+ "embed"
+ "testing"
+)
+
+//go:embed test-migrations
+var testMigrationFS embed.FS
+
+func TestMigrate(t *testing.T) {
+ db, err := sql.Open("sqlite3", ":memory:")
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if err := Migrate(db, testMigrationFS); err != nil {
+ t.Fatal(err)
+ }
+
+ // Shouldn't have any effect.
+ if err := Migrate(db, testMigrationFS); err != nil {
+ t.Fatal(err)
+ }
+
+ query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)`
+ var exists bool
+
+ if err = db.QueryRow(query, 1).Scan(&exists); err != nil {
+ t.Fatal(err)
+ }
+ if exists {
+ t.Fatal("1 shouldn't exist")
+ }
+
+ if err = db.QueryRow(query, 2).Scan(&exists); err != nil {
+ t.Fatal(err)
+ }
+ if !exists {
+ t.Fatal("2 should exist")
+ }
+
+} diff --git a/sqliteutil/test-migrations/000.sql b/sqliteutil/test-migrations/000.sql new file mode 100644 index 0000000..ecb559c --- /dev/null +++ b/sqliteutil/test-migrations/000.sql @@ -0,0 +1,9 @@
+CREATE TABLE users(
+ UserID BIGINT NOT NULL PRIMARY KEY,
+ Email TEXT NOT NULL UNIQUE);
+
+CREATE TABLE user_notes(
+ UserID BIGINT NOT NULL REFERENCES users(UserID),
+ NoteID BIGINT NOT NULL,
+ Note TEXT NOT NULL,
+ PRIMARY KEY(UserID,NoteID));
 diff --git a/sqliteutil/test-migrations/001.sql b/sqliteutil/test-migrations/001.sql new file mode 100644 index 0000000..e424c57 --- /dev/null +++ b/sqliteutil/test-migrations/001.sql @@ -0,0 +1 @@
+INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com');
 diff --git a/sqliteutil/test-migrations/002.sql b/sqliteutil/test-migrations/002.sql new file mode 100644 index 0000000..ba414d2 --- /dev/null +++ b/sqliteutil/test-migrations/002.sql @@ -0,0 +1 @@
+DELETE FROM users WHERE UserID=1;
 diff --git a/sqliteutil/tx.go b/sqliteutil/tx.go new file mode 100644 index 0000000..37f6f33 --- /dev/null +++ b/sqliteutil/tx.go @@ -0,0 +1,28 @@
+package sqliteutil
+
+import (
+ "database/sql"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+// WithTx is a convenience function that runs fn within a transaction.
+func WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
+ // Start a transaction.
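+ // Any error below, from fn or from Commit, results in a Rollback.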
+ tx, err := db.Begin()
+ if err != nil {
+ return err
+ }
+
+ err = fn(tx)
+
+ if err == nil {
+ err = tx.Commit()
+ }
+
+ if err != nil {
+ _ = tx.Rollback()
+ }
+
+ return err
+} diff --git a/tagengine/README.md b/tagengine/README.md new file mode 100644 index 0000000..2b418ae --- /dev/null +++ b/tagengine/README.md @@ -0,0 +1,2 @@
+# tagengine
+
 diff --git a/tagengine/go.mod b/tagengine/go.mod new file mode 100644 index 0000000..0d28550 --- /dev/null +++ b/tagengine/go.mod @@ -0,0 +1,3 @@
+module git.crumpington.com/lib/tagengine
+
+go 1.23.2
 diff --git a/tagengine/go.sum b/tagengine/go.sum new file mode 100644 index 0000000..e69de29 diff --git a/tagengine/ngram.go b/tagengine/ngram.go new file mode 100644 index 0000000..dbaf65c --- /dev/null +++ b/tagengine/ngram.go @@ -0,0 +1,30 @@
+package tagengine
+
+import "unicode"
+
+func ngramLength(s string) int {
+ N := len(s)
+ i := 0
+ count := 0
+
+ for {
+ // Eat spaces.
+ for i < N && unicode.IsSpace(rune(s[i])) {
+ i++
+ }
+
+ // Done?
+ if i == N {
+ break
+ }
+
+ // Non-space!
+ count++
+
+ // Eat non-spaces.
+ for i < N && !unicode.IsSpace(rune(s[i])) {
+ i++
+ }
+ }
+ return count
+} diff --git a/tagengine/ngram_test.go b/tagengine/ngram_test.go new file mode 100644 index 0000000..82e2304 --- /dev/null +++ b/tagengine/ngram_test.go @@ -0,0 +1,30 @@
+package tagengine
+
+import (
+ "testing"
+)
+
+func TestNGramLength(t *testing.T) {
+ type Case struct {
+ Input string
+ Length int
+ }
+
+ cases := []Case{
+ {"a b c", 3},
+ {" xyz\nlkj dflaj a", 4},
+ {"a", 1},
+ {" a", 1},
+ {"a", 1},
+ {" a\n", 1},
+ {" a ", 1},
+ {"\tx\ny\nz q ", 4},
+ }
+
+ for _, tc := range cases {
+ length := ngramLength(tc.Input)
+ if length != tc.Length {
+ t.Fatalf("%s: %d != %d", tc.Input, length, tc.Length)
+ }
+ }
+} diff --git a/tagengine/node.go b/tagengine/node.go new file mode 100644 index 0000000..c48d982 --- /dev/null +++ b/tagengine/node.go @@ -0,0 +1,79 @@
+package tagengine
+
+import (
+ "fmt"
+ "strings"
+)
+
+type node struct {
+ Token string
+ Matches []*Rule // If a list of tokens reaches this node, it matches these.
+ Children map[string]*node
+}
+
+func (n *node) AddRule(r *Rule) {
+ n.addRule(r, 0)
+}
+
+func (n *node) addRule(r *Rule, idx int) {
+ if len(r.Includes) == idx {
+ n.Matches = append(n.Matches, r)
+ return
+ }
+
+ token := r.Includes[idx]
+
+ child, ok := n.Children[token]
+ if !ok {
+ child = &node{
+ Token: token,
+ Children: map[string]*node{},
+ }
+ n.Children[token] = child
+ }
+
+ child.addRule(r, idx+1)
+}
+
+// Note that tokens must be sorted. This is the case for tokens created by the
+// Tokenize function.
+func (n *node) Match(tokens []string) (rules []*Rule) {
+ return n.match(tokens, rules)
+}
+
+func (n *node) match(tokens []string, rules []*Rule) []*Rule {
+ // Check for a match.
+ if n.Matches != nil {
+ rules = append(rules, n.Matches...)
+ }
+
+ if len(tokens) == 0 {
+ return rules
+ }
+
+ // Attempt to match children.
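+ // Tokens are sorted (see Match), so each remaining token is tried as the
+ // next edge, recursing on the suffix after a hit.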
+ for i := 0; i < len(tokens); i++ {
+ token := tokens[i]
+ if child, ok := n.Children[token]; ok {
+ rules = child.match(tokens[i+1:], rules)
+ }
+ }
+
+ return rules
+}
+
+func (n *node) Dump() {
+ n.dump(0)
+}
+
+func (n *node) dump(depth int) {
+ indent := strings.Repeat(" ", 2*depth)
+ tag := ""
+ for _, m := range n.Matches {
+ tag += " " + m.Tag
+ }
+ fmt.Printf("%s%s%s\n", indent, n.Token, tag)
+ for _, child := range n.Children {
+ child.dump(depth + 1)
+ }
+} diff --git a/tagengine/rule.go b/tagengine/rule.go new file mode 100644 index 0000000..77e742e --- /dev/null +++ b/tagengine/rule.go @@ -0,0 +1,159 @@
+package tagengine
+
+type Rule struct {
+ // The purpose of a Rule is to attach its Tag to matching text.
+ Tag string
+
+ // Includes is a list of strings that must be found in the input in order to
+ // match.
+ Includes []string
+
+ // Excludes is a list of strings that can exclude a match for this rule.
+ Excludes []string
+
+ // Blocks: If this rule is matched, then it will block matches of any tags
+ // listed here.
+ Blocks []string
+
+ // The Score encodes the complexity of the Rule. A higher score indicates a
+ // more specific match. A Rule with more includes, or with multi-word
+ // includes, should have a higher Score than a Rule with fewer or less
+ // complex includes.
+ Score int
+
+ excludes map[string]struct{}
+}
+
+func NewRule(tag string) Rule {
+ return Rule{Tag: tag}
+}
+
+func (r Rule) Inc(l ...string) Rule {
+ return Rule{
+ Tag: r.Tag,
+ Includes: append(r.Includes, l...),
+ Excludes: r.Excludes,
+ Blocks: r.Blocks,
+ }
+}
+
+func (r Rule) Exc(l ...string) Rule {
+ return Rule{
+ Tag: r.Tag,
+ Includes: r.Includes,
+ Excludes: append(r.Excludes, l...),
+ Blocks: r.Blocks,
+ }
+}
+
+func (r Rule) Block(l ...string) Rule {
+ return Rule{
+ Tag: r.Tag,
+ Includes: r.Includes,
+ Excludes: r.Excludes,
+ Blocks: append(r.Blocks, l...),
+ }
+}
+
+func (rule *Rule) normalize(sanitize func(string) string) {
+ for i, token := range rule.Includes {
+ rule.Includes[i] = sanitize(token)
+ }
+ for i, token := range rule.Excludes {
+ rule.Excludes[i] = sanitize(token)
+ }
+
+ sortTokens(rule.Includes)
+ sortTokens(rule.Excludes)
+
+ rule.excludes = map[string]struct{}{}
+ for _, s := range rule.Excludes {
+ rule.excludes[s] = struct{}{}
+ }
+
+ rule.Score = rule.computeScore()
+}
+
+func (r Rule) maxNGram() int {
+ max := 0
+ for _, s := range r.Includes {
+ n := ngramLength(s)
+ if n > max {
+ max = n
+ }
+ }
+ for _, s := range r.Excludes {
+ n := ngramLength(s)
+ if n > max {
+ max = n
+ }
+ }
+
+ return max
+}
+
+func (r Rule) isExcluded(tokens []string) bool {
+ // This is most often the case.
+ if len(r.excludes) == 0 {
+ return false
+ }
+
+ for _, s := range tokens {
+ if _, ok := r.excludes[s]; ok {
+ return true
+ }
+ }
+ return false
+}
+
+func (r Rule) computeScore() (score int) {
+ for _, token := range r.Includes {
+ n := ngramLength(token)
+ score += n * (n + 1) / 2
+ }
+ return score
+}
+
+func ruleLess(lhs, rhs *Rule) bool {
+ // If scores differ, sort by score.
+ if lhs.Score != rhs.Score {
+ return lhs.Score < rhs.Score
+ }
+
+ // If include depth differs, sort by depth.
+ lDepth := len(lhs.Includes)
+ rDepth := len(rhs.Includes)
+
+ if lDepth != rDepth {
+ return lDepth < rDepth
+ }
+
+ // If exclude depth differs, sort by depth.
+ lDepth = len(lhs.Excludes)
+ rDepth = len(rhs.Excludes)
+
+ if lDepth != rDepth {
+ return lDepth < rDepth
+ }
+
+ // Sort alphabetically by includes.
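+ // The include lists are equal in length here; unequal depths were handled
+ // above.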
+ for i := range lhs.Includes {
+ if lhs.Includes[i] != rhs.Includes[i] {
+ return lhs.Includes[i] < rhs.Includes[i]
+ }
+ }
+
+ // Sort alphabetically by excludes.
+ for i := range lhs.Excludes {
+ if lhs.Excludes[i] != rhs.Excludes[i] {
+ return lhs.Excludes[i] < rhs.Excludes[i]
+ }
+ }
+
+ // Sort by tag.
+ if lhs.Tag != rhs.Tag {
+ return lhs.Tag < rhs.Tag
+ }
+
+ return false
+} diff --git a/tagengine/rulegroup.go b/tagengine/rulegroup.go new file mode 100644 index 0000000..3a30657 --- /dev/null +++ b/tagengine/rulegroup.go @@ -0,0 +1,58 @@
+package tagengine
+
+// A RuleGroup can be converted into a list of rules. Each rule will point to
+// the same tag, and have the same exclude set and blocks.
+type RuleGroup struct {
+ Tag string
+ Includes [][]string
+ Excludes []string
+ Blocks []string
+}
+
+func NewRuleGroup(tag string) RuleGroup {
+ return RuleGroup{
+ Tag: tag,
+ Includes: [][]string{},
+ Excludes: []string{},
+ Blocks: []string{},
+ }
+}
+
+func (g RuleGroup) Inc(l ...string) RuleGroup {
+ return RuleGroup{
+ Tag: g.Tag,
+ Includes: append(g.Includes, l),
+ Excludes: g.Excludes,
+ Blocks: g.Blocks,
+ }
+}
+
+func (g RuleGroup) Exc(l ...string) RuleGroup {
+ return RuleGroup{
+ Tag: g.Tag,
+ Includes: g.Includes,
+ Excludes: append(g.Excludes, l...),
+ Blocks: g.Blocks,
+ }
+}
+
+func (g RuleGroup) Block(l ...string) RuleGroup {
+ return RuleGroup{
+ Tag: g.Tag,
+ Includes: g.Includes,
+ Excludes: g.Excludes,
+ Blocks: append(g.Blocks, l...),
+ }
+}
+
+func (g RuleGroup) ToList() (l []Rule) {
+ for _, includes := range g.Includes {
+ l = append(l, Rule{
+ Tag: g.Tag,
+ Excludes: g.Excludes,
+ Includes: includes,
+ Blocks: g.Blocks,
+ })
+ }
+ return
+} diff --git a/tagengine/ruleset.go b/tagengine/ruleset.go new file mode 100644 index 0000000..1efe341 --- /dev/null +++ b/tagengine/ruleset.go @@ -0,0 +1,162 @@
+package tagengine
+
+import (
+ "sort"
+)
+
+type RuleSet struct {
+ root *node
+ maxNgram int
+ sanitize func(string) string
+ rules []*Rule
+}
+
+func NewRuleSet() *RuleSet {
+ return &RuleSet{
+ root: &node{
+ Token: "/",
+ Children: map[string]*node{},
+ },
+ sanitize: BasicSanitizer,
+ rules: []*Rule{},
+ }
+}
+
+func NewRuleSetFromList(rules []Rule) *RuleSet {
+ rs := NewRuleSet()
+ rs.AddRule(rules...)
+ return rs
+}
+
+func (t *RuleSet) Add(ruleOrGroup ...interface{}) {
+ for _, ix := range ruleOrGroup {
+ switch x := ix.(type) {
+ case Rule:
+ t.AddRule(x)
+ case RuleGroup:
+ t.AddRuleGroup(x)
+ default:
+ panic("Add expects either Rule or RuleGroup objects.")
+ }
+ }
+}
+
+func (t *RuleSet) AddRule(rules ...Rule) {
+ for _, rule := range rules {
+ rule := rule
+
+ // Make sure rule is well-formed.
+ rule.normalize(t.sanitize)
+
+ // Update maxNgram.
+ N := rule.maxNGram()
+ if N > t.maxNgram {
+ t.maxNgram = N
+ }
+
+ t.rules = append(t.rules, &rule)
+ t.root.AddRule(&rule)
+ }
+}
+
+func (t *RuleSet) AddRuleGroup(ruleGroups ...RuleGroup) {
+ for _, rg := range ruleGroups {
+ t.AddRule(rg.ToList()...)
+ }
+}
+
+// MatchRules will return a list of all matching rules. The rules are sorted by
+// the match's Score. The best match will be first.
+func (t *RuleSet) MatchRules(input string) (rules []*Rule) {
+ input = t.sanitize(input)
+ tokens := Tokenize(input, t.maxNgram)
+
+ rules = t.root.Match(tokens)
+ if len(rules) == 0 {
+ return rules
+ }
+
+ // Check excludes.
+ l := rules[:0]
+ for _, r := range rules {
+ if !r.isExcluded(tokens) {
+ l = append(l, r)
+ }
+ }
+
+ rules = l
+
+ // Sort rules descending.
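+ // ruleLess sorts ascending, so its arguments are swapped (j before i) to
+ // put the highest-scoring rule first.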
+ sort.Slice(rules, func(i, j int) bool { + return ruleLess(rules[j], rules[i]) + }) + + return rules +} + +type Match struct { + Tag string + + // Confidence is used to sort all matches, and is normalized so the sum of + // Confidence values for all matches is 1. Confidence is relative to the + // number of matches and the size of matches in terms of number of tokens. + Confidence float64 // In the range (0,1]. +} + +// Return a list of matches with confidence. This is useful if you'd like to +// find the best matching rule out of all the matched rules. +// +// If you just want to find all matching rules, then use MatchRules. +func (t *RuleSet) Match(input string) []Match { + rules := t.MatchRules(input) + if len(rules) == 0 { + return []Match{} + } + if len(rules) == 1 { + return []Match{{ + Tag: rules[0].Tag, + Confidence: 1, + }} + } + + // Create list of blocked tags. + blocks := map[string]struct{}{} + for _, rule := range rules { + for _, tag := range rule.Blocks { + blocks[tag] = struct{}{} + } + } + + // Remove rules for blocked tags. + iOut := 0 + for _, rule := range rules { + if _, ok := blocks[rule.Tag]; ok { + continue + } + rules[iOut] = rule + iOut++ + } + rules = rules[:iOut] + + // Matches by index. + matches := map[string]int{} + out := []Match{} + sum := float64(0) + + for _, rule := range rules { + idx, ok := matches[rule.Tag] + if !ok { + idx = len(matches) + matches[rule.Tag] = idx + out = append(out, Match{Tag: rule.Tag}) + } + out[idx].Confidence += float64(rule.Score) + sum += float64(rule.Score) + } + + for i := range out { + out[i].Confidence /= sum + } + + return out +} diff --git a/tagengine/ruleset_test.go b/tagengine/ruleset_test.go new file mode 100644 index 0000000..23ec5cc --- /dev/null +++ b/tagengine/ruleset_test.go @@ -0,0 +1,84 @@ +package tagengine + +import ( + "reflect" + "testing" +) + +func TestRulesSet(t *testing.T) { + rs := NewRuleSet() + rs.AddRule(Rule{ + Tag: "cc/2", + Includes: []string{"cola", "coca"}, + }) + rs.AddRule(Rule{ + Tag: "cc/0", + Includes: []string{"coca cola"}, + }) + rs.AddRule(Rule{ + Tag: "cz/2", + Includes: []string{"coca", "zero"}, + }) + rs.AddRule(Rule{ + Tag: "cc0/3", + Includes: []string{"zero", "coca", "cola"}, + }) + rs.AddRule(Rule{ + Tag: "cc0/3.1", + Includes: []string{"coca", "cola", "zero"}, + Excludes: []string{"pepsi"}, + }) + rs.AddRule(Rule{ + Tag: "spa", + Includes: []string{"spa"}, + Blocks: []string{"cc/0", "cc0/3", "cc0/3.1"}, + }) + + type TestCase struct { + Input string + Matches []Match + } + + cases := []TestCase{ + { + Input: "coca-cola zero", + Matches: []Match{ + {"cc0/3.1", 0.3}, + {"cc0/3", 0.3}, + {"cz/2", 0.2}, + {"cc/2", 0.2}, + }, + }, { + Input: "coca cola", + Matches: []Match{ + {"cc/0", 0.6}, + {"cc/2", 0.4}, + }, + }, { + Input: "coca cola zero pepsi", + Matches: []Match{ + {"cc0/3", 0.3}, + {"cc/0", 0.3}, + {"cz/2", 0.2}, + {"cc/2", 0.2}, + }, + }, { + Input: "fanta orange", + Matches: []Match{}, + }, { + Input: "coca-cola zero / fanta / spa", + Matches: []Match{ + {"cz/2", 0.4}, + {"cc/2", 0.4}, + {"spa", 0.2}, + }, + }, + } + + for _, tc := range cases { + matches := rs.Match(tc.Input) + if !reflect.DeepEqual(matches, tc.Matches) { + t.Fatalf("%v != %v", matches, tc.Matches) + } + } +} diff --git a/tagengine/sanitize.go b/tagengine/sanitize.go new file mode 100644 index 0000000..3928c3b --- /dev/null +++ b/tagengine/sanitize.go @@ -0,0 +1,20 @@ +package tagengine + +import ( + "strings" + + "git.crumpington.com/lib/tagengine/sanitize" +) + +// The basic sanitizer: +// * lower-case 
+// * put spaces around numbers
+// * put spaces around punctuation
+// * collapse multiple spaces
+func BasicSanitizer(s string) string {
+ s = strings.ToLower(s)
+ s = sanitize.SpaceNumbers(s)
+ s = sanitize.SpacePunctuation(s)
+ s = sanitize.CollapseSpaces(s)
+ return s
+} diff --git a/tagengine/sanitize/sanitize.go b/tagengine/sanitize/sanitize.go new file mode 100644 index 0000000..b786eb1 --- /dev/null +++ b/tagengine/sanitize/sanitize.go @@ -0,0 +1,91 @@
+package sanitize
+
+import (
+ "strings"
+ "unicode"
+)
+
+func SpaceNumbers(s string) string {
+ if len(s) == 0 {
+ return s
+ }
+
+ isDigit := func(b rune) bool {
+ switch b {
+ case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
+ return true
+ }
+ return false
+ }
+
+ b := strings.Builder{}
+
+ var first rune
+ for _, c := range s {
+ first = c
+ break
+ }
+
+ digit := isDigit(first)
+
+ // Range over runes.
+ for _, c := range s {
+ thisDigit := isDigit(c)
+ if thisDigit != digit {
+ b.WriteByte(' ')
+ digit = thisDigit
+ }
+ b.WriteRune(c)
+ }
+
+ return b.String()
+}
+
+func SpacePunctuation(s string) string {
+ needsSpace := func(r rune) bool {
+ switch r {
+ case '`', '~', '!', '@', '#', '%', '^', '&', '*', '(', ')',
+ '-', '_', '+', '=', '[', '{', ']', '}', '\\', '|',
+ ':', ';', '"', '\'', ',', '<', '.', '>', '?', '/':
+ return true
+ }
+ return false
+ }
+
+ b := strings.Builder{}
+
+ // Range over runes.
+ for _, r := range s {
+ if needsSpace(r) {
+ b.WriteRune(' ')
+ b.WriteRune(r)
+ b.WriteRune(' ')
+ } else {
+ b.WriteRune(r)
+ }
+ }
+
+ return b.String()
+}
+
+func CollapseSpaces(s string) string {
+ // Trim leading and trailing spaces.
+ s = strings.TrimSpace(s)
+
+ b := strings.Builder{}
+ wasSpace := false
+
+ // Range over runes.
+ for _, c := range s {
+ if unicode.IsSpace(c) {
+ wasSpace = true
+ continue
+ } else if wasSpace {
+ wasSpace = false
+ b.WriteRune(' ')
+ }
+ b.WriteRune(c)
+ }
+
+ return b.String()
+} diff --git a/tagengine/sanitize_test.go b/tagengine/sanitize_test.go new file mode 100644 index 0000000..82aca01 --- /dev/null +++ b/tagengine/sanitize_test.go @@ -0,0 +1,30 @@
+package tagengine
+
+import "testing"
+
+func TestSanitize(t *testing.T) {
+ sanitize := BasicSanitizer
+
+ type Case struct {
+ In string
+ Out string
+ }
+
+ cases := []Case{
+ {"", ""},
+ {"123abc", "123 abc"},
+ {"abc123", "abc 123"},
+ {"abc123xyz", "abc 123 xyz"},
+ {"1f2", "1 f 2"},
+ {" abc", "abc"},
+ {" ; KitKat/m&m's (bottle) @ ", "; kitkat / m & m ' s ( bottle ) @"},
+ {"€", "€"},
+ }
+
+ for _, tc := range cases {
+ out := sanitize(tc.In)
+ if out != tc.Out {
+ t.Fatalf("%v != %v", out, tc.Out)
+ }
+ }
+} diff --git a/tagengine/tokenize.go b/tagengine/tokenize.go new file mode 100644 index 0000000..5dd1bb9 --- /dev/null +++ b/tagengine/tokenize.go @@ -0,0 +1,63 @@
+package tagengine
+
+import (
+ "sort"
+ "strings"
+)
+
+var ignoreTokens = map[string]struct{}{}
+
+func init() {
+ // These on their own are ignored.
+ tokens := []string{
+ "`", `~`, `!`, `@`, `#`, `%`, `^`, `&`, `*`, `(`, `)`,
+ `-`, `_`, `+`, `=`, `[`, `{`, `]`, `}`, `\`, `|`,
+ `:`, `;`, `"`, `'`, `,`, `<`, `.`, `>`, `?`, `/`,
+ }
+ for _, s := range tokens {
+ ignoreTokens[s] = struct{}{}
+ }
+}
+
+func Tokenize(
+ input string,
+ maxNgram int,
+) (
+ tokens []string,
+) {
+ // Avoid duplicate ngrams.
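+ // Each n-gram (n = 1..maxNgram) is a window of n consecutive fields,
+ // joined by single spaces.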
+ ignored := map[string]bool{} + + fields := strings.Fields(input) + + if len(fields) < maxNgram { + maxNgram = len(fields) + } + + for i := 1; i < maxNgram+1; i++ { + jMax := len(fields) - i + 1 + + for j := 0; j < jMax; j++ { + ngram := strings.Join(fields[j:i+j], " ") + if _, ok := ignoreTokens[ngram]; !ok { + if _, ok := ignored[ngram]; !ok { + tokens = append(tokens, ngram) + ignored[ngram] = true + } + } + } + } + + sortTokens(tokens) + + return tokens +} + +func sortTokens(tokens []string) { + sort.Slice(tokens, func(i, j int) bool { + if len(tokens[i]) != len(tokens[j]) { + return len(tokens[i]) < len(tokens[j]) + } + return tokens[i] < tokens[j] + }) +} diff --git a/tagengine/tokenize_test.go b/tagengine/tokenize_test.go new file mode 100644 index 0000000..06c775e --- /dev/null +++ b/tagengine/tokenize_test.go @@ -0,0 +1,55 @@ +package tagengine + +import ( + "reflect" + "testing" +) + +func TestTokenize(t *testing.T) { + type Case struct { + Input string + MaxNgram int + Output []string + } + + cases := []Case{ + { + Input: "a bb c d", + MaxNgram: 3, + Output: []string{ + "a", "c", "d", "bb", + "c d", "a bb", "bb c", + "a bb c", "bb c d", + }, + }, { + Input: "a b", + MaxNgram: 3, + Output: []string{ + "a", "b", "a b", + }, + }, { + Input: "- b c d", + MaxNgram: 3, + Output: []string{ + "b", "c", "d", + "- b", "b c", "c d", + "- b c", "b c d", + }, + }, { + Input: "a a b c d c d", + MaxNgram: 3, + Output: []string{ + "a", "b", "c", "d", + "a a", "a b", "b c", "c d", "d c", + "a a b", "a b c", "b c d", "c d c", "d c d", + }, + }, + } + + for _, tc := range cases { + output := Tokenize(tc.Input, tc.MaxNgram) + if !reflect.DeepEqual(output, tc.Output) { + t.Fatalf("%s: %#v", tc.Input, output) + } + } +} diff --git a/webutil/README.md b/webutil/README.md new file mode 100644 index 0000000..c0f37f0 --- /dev/null +++ b/webutil/README.md @@ -0,0 +1,5 @@ +# webutil + +## Roadmap + +* logging middleware diff --git a/webutil/go.mod b/webutil/go.mod new file mode 100644 index 0000000..3eed7f5 --- /dev/null +++ b/webutil/go.mod @@ -0,0 +1,10 @@ +module git.crumpington.com/lib/webutil + +go 1.23.2 + +require golang.org/x/crypto v0.28.0 + +require ( + golang.org/x/net v0.21.0 // indirect + golang.org/x/text v0.19.0 // indirect +) diff --git a/webutil/go.sum b/webutil/go.sum new file mode 100644 index 0000000..50ea1d3 --- /dev/null +++ b/webutil/go.sum @@ -0,0 +1,6 @@ +golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= +golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= +golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= +golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= diff --git a/webutil/listenandserve.go b/webutil/listenandserve.go new file mode 100644 index 0000000..ebfbee1 --- /dev/null +++ b/webutil/listenandserve.go @@ -0,0 +1,24 @@ +package webutil + +import ( + "errors" + "net/http" + "strings" + + "golang.org/x/crypto/acme/autocert" +) + +// Serve requests using the given http.Server. If srv.Addr has the format +// `hostname:https`, then use autocert to manage certificates for the domain. +// +// For http on port 80, you can use :http. 
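+//
+// A minimal sketch (handler and hostname assumed):
+//
+//	srv := &http.Server{Addr: "example.com:https", Handler: mux}
+//	log.Fatal(ListenAndServe(srv))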
+func ListenAndServe(srv *http.Server) error { + if strings.HasSuffix(srv.Addr, ":https") { + hostname := strings.TrimSuffix(srv.Addr, ":https") + if len(hostname) == 0 { + return errors.New("https requires a hostname") + } + return srv.Serve(autocert.NewListener(hostname)) + } + return srv.ListenAndServe() +} diff --git a/webutil/middleware-logging.go b/webutil/middleware-logging.go new file mode 100644 index 0000000..39a5a30 --- /dev/null +++ b/webutil/middleware-logging.go @@ -0,0 +1,47 @@ +package webutil + +import ( + "log" + "net/http" + "os" + "time" +) + +var _log = log.New(os.Stderr, "", 0) + +type responseWriterWrapper struct { + http.ResponseWriter + httpStatus int + responseSize int +} + +func (w *responseWriterWrapper) WriteHeader(status int) { + w.httpStatus = status + w.ResponseWriter.WriteHeader(status) +} + +func (w *responseWriterWrapper) Write(b []byte) (int, error) { + if w.httpStatus == 0 { + w.httpStatus = 200 + } + w.responseSize += len(b) + return w.ResponseWriter.Write(b) +} + +func WithLogging(inner http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + t := time.Now() + wrapper := responseWriterWrapper{w, 0, 0} + + inner(&wrapper, r) + _log.Printf("%s \"%s %s %s\" %d %d %v\n", + r.RemoteAddr, + r.Method, + r.URL.Path, + r.Proto, + wrapper.httpStatus, + wrapper.responseSize, + time.Since(t), + ) + } +} diff --git a/webutil/template.go b/webutil/template.go new file mode 100644 index 0000000..e18f37d --- /dev/null +++ b/webutil/template.go @@ -0,0 +1,100 @@ +package webutil + +import ( + "embed" + "html/template" + "io/fs" + "log" + "path" + "strings" +) + +// ParseTemplateSet parses sets of templates from an embed.FS. +// +// Each directory constitutes a set of templates that are parsed together. +// +// Structure (within a directory): +// - share/* are always parsed. +// - base.html will be parsed with each other file in same dir +// +// Call a template with m[path].Execute(w, data) (root dir name is excluded). +// +// For example, if you have +// - /user/share/* +// - /user/base.html +// - /user/home.html +// +// Then you call m["/user/home.html"].Execute(w, data). 
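+//
+// A hypothetical call site (template directory name assumed):
+//
+//	//go:embed all:templates
+//	var templatesFS embed.FS
+//
+//	var pages = ParseTemplateSet(template.FuncMap{"join": strings.Join}, templatesFS)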
+func ParseTemplateSet(funcs template.FuncMap, fs embed.FS) map[string]*template.Template { + m := map[string]*template.Template{} + rootDir := readDir(fs, ".")[0].Name() + loadTemplateDir(fs, funcs, m, rootDir, rootDir) + return m +} + +func loadTemplateDir( + fs embed.FS, + funcs template.FuncMap, + m map[string]*template.Template, + dirPath string, + rootDir string, +) map[string]*template.Template { + t := template.New("") + if funcs != nil { + t = t.Funcs(funcs) + } + + shareDir := path.Join(dirPath, "share") + if _, err := fs.ReadDir(shareDir); err == nil { + log.Printf("Parsing %s...", path.Join(shareDir, "*")) + t = template.Must(t.ParseFS(fs, path.Join(shareDir, "*"))) + } + + if data, _ := fs.ReadFile(path.Join(dirPath, "base.html")); data != nil { + log.Printf("Parsing %s...", path.Join(dirPath, "base.html")) + t = template.Must(t.Parse(string(data))) + } + + for _, ent := range readDir(fs, dirPath) { + if ent.Type().IsDir() { + if ent.Name() != "share" { + m = loadTemplateDir(fs, funcs, m, path.Join(dirPath, ent.Name()), rootDir) + } + continue + } + + if !ent.Type().IsRegular() { + continue + } + + if ent.Name() == "base.html" { + continue + } + + filePath := path.Join(dirPath, ent.Name()) + log.Printf("Parsing %s...", filePath) + + key := strings.TrimPrefix(path.Join(dirPath, ent.Name()), rootDir) + tt := template.Must(t.Clone()) + tt = template.Must(tt.Parse(readFile(fs, filePath))) + m[key] = tt + } + + return m +} + +func readDir(fs embed.FS, dirPath string) []fs.DirEntry { + ents, err := fs.ReadDir(dirPath) + if err != nil { + panic(err) + } + return ents +} + +func readFile(fs embed.FS, path string) string { + data, err := fs.ReadFile(path) + if err != nil { + panic(err) + } + return string(data) +} diff --git a/webutil/template_test.go b/webutil/template_test.go new file mode 100644 index 0000000..40c9313 --- /dev/null +++ b/webutil/template_test.go @@ -0,0 +1,49 @@ +package webutil + +import ( + "bytes" + "embed" + "html/template" + "strings" + "testing" +) + +//go:embed all:test-templates +var testFS embed.FS + +func TestParseTemplateSet(t *testing.T) { + funcs := template.FuncMap{"join": strings.Join} + m := ParseTemplateSet(funcs, testFS) + + type TestCase struct { + Key string + Data any + Out string + } + + cases := []TestCase{ + { + Key: "/home.html", + Data: "DATA", + Out: "
<html><body>HOME!</body></html>",
+ }, {
+ Key: "/about.html",
+ Data: "DATA",
+ Out: "<html><body><b>DATA</b></body></html>",
+ }, {
+ Key: "/contact.html",
+ Data: []string{"a", "b", "c"},
+ Out: "<html><body>a,b,c</body></html>",
+ },
+ }
+
+ for _, tc := range cases {
+ b := &bytes.Buffer{}
+ m[tc.Key].Execute(b, tc.Data)
+ out := strings.TrimSpace(b.String())
+ if out != tc.Out {
+ t.Fatalf("%s != %s", out, tc.Out)
+ }
+ }
+
+} diff --git a/webutil/test-templates/about.html b/webutil/test-templates/about.html new file mode 100644 index 0000000..dfaaaa0 --- /dev/null +++ b/webutil/test-templates/about.html @@ -0,0 +1 @@ +{{define "body"}}{{template "bold" .}}{{end}} diff --git a/webutil/test-templates/base.html b/webutil/test-templates/base.html new file mode 100644 index 0000000..f1c87e6 --- /dev/null +++ b/webutil/test-templates/base.html @@ -0,0 +1 @@ +<html><body>{{block "body" .}}default{{end}}</body></html>
diff --git a/webutil/test-templates/contact.html b/webutil/test-templates/contact.html new file mode 100644 index 0000000..a4394c9 --- /dev/null +++ b/webutil/test-templates/contact.html @@ -0,0 +1 @@ +{{define "body"}}{{join . ","}}{{end}} diff --git a/webutil/test-templates/home.html b/webutil/test-templates/home.html new file mode 100644 index 0000000..fc06425 --- /dev/null +++ b/webutil/test-templates/home.html @@ -0,0 +1 @@ +{{define "body"}}HOME!{{end}} diff --git a/webutil/test-templates/share/bold.html b/webutil/test-templates/share/bold.html new file mode 100644 index 0000000..090a255 --- /dev/null +++ b/webutil/test-templates/share/bold.html @@ -0,0 +1 @@ +{{define "bold"}}<b>{{.}}</b>{{end}} diff --git a/webutil/test-templates/share/italic.html b/webutil/test-templates/share/italic.html new file mode 100644 index 0000000..e9502e5 --- /dev/null +++ b/webutil/test-templates/share/italic.html @@ -0,0 +1 @@ +{{define "italic"}}<i>{{.}}</i>{{end}}