wip
This commit is contained in:
parent
d0587cc585
commit
c5419d662e
2
flock/README.md
Normal file
2
flock/README.md
Normal file
@ -0,0 +1,2 @@
|
||||
# flock
|
||||
|
69
flock/flock.go
Normal file
69
flock/flock.go
Normal file
@ -0,0 +1,69 @@
|
||||
// The flock package provides a file-system mediated locking mechanism on linux
|
||||
// using the `flock` system call.
|
||||
|
||||
package flock
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
// Lock gets an exclusive lock on the file at the given path. If the file
|
||||
// doesn't exist, it's created.
|
||||
func Lock(path string) (*os.File, error) {
|
||||
return lock(path, syscall.LOCK_EX)
|
||||
}
|
||||
|
||||
// TryLock will return a nil file if the file is already locked.
|
||||
func TryLock(path string) (*os.File, error) {
|
||||
return lock(path, syscall.LOCK_EX|syscall.LOCK_NB)
|
||||
}
|
||||
|
||||
func LockFile(f *os.File) error {
|
||||
_, err := lockFile(f, syscall.LOCK_EX)
|
||||
return err
|
||||
}
|
||||
|
||||
// Returns true if the lock was successfully acquired.
|
||||
func TryLockFile(f *os.File) (bool, error) {
|
||||
return lockFile(f, syscall.LOCK_EX|syscall.LOCK_NB)
|
||||
}
|
||||
|
||||
func lockFile(f *os.File, flags int) (bool, error) {
|
||||
|
||||
if err := flock(int(f.Fd()), flags); err != nil {
|
||||
if flags&syscall.LOCK_NB != 0 && errors.Is(err, syscall.EAGAIN) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func flock(fd int, how int) error {
|
||||
_, _, e1 := syscall.Syscall(syscall.SYS_FLOCK, uintptr(fd), uintptr(how), 0)
|
||||
if e1 != 0 {
|
||||
return syscall.Errno(e1)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func lock(path string, flags int) (*os.File, error) {
|
||||
perm := os.O_CREATE | os.O_RDWR
|
||||
f, err := os.OpenFile(path, perm, 0600)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ok, err := lockFile(f, flags)
|
||||
if err != nil || !ok {
|
||||
f.Close()
|
||||
f = nil
|
||||
}
|
||||
return f, err
|
||||
}
|
||||
|
||||
// Unlock releases the lock acquired via the Lock function.
|
||||
func Unlock(f *os.File) error {
|
||||
return f.Close()
|
||||
}
|
66
flock/flock_test.go
Normal file
66
flock/flock_test.go
Normal file
@ -0,0 +1,66 @@
|
||||
package flock
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_Lock_basic(t *testing.T) {
|
||||
ch := make(chan int, 1)
|
||||
f, err := Lock("/tmp/fsutil-test-lock")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
go func() {
|
||||
time.Sleep(time.Second)
|
||||
ch <- 10
|
||||
Unlock(f)
|
||||
}()
|
||||
|
||||
select {
|
||||
case x := <-ch:
|
||||
t.Fatal(x)
|
||||
default:
|
||||
|
||||
}
|
||||
|
||||
f2, _ := Lock("/tmp/fsutil-test-lock")
|
||||
defer Unlock(f2)
|
||||
select {
|
||||
case i := <-ch:
|
||||
if i != 10 {
|
||||
t.Fatal(i)
|
||||
}
|
||||
default:
|
||||
t.Fatal("No value available.")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func Test_Lock_badPath(t *testing.T) {
|
||||
_, err := Lock("./dne/file.lock")
|
||||
if err == nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTryLock(t *testing.T) {
|
||||
lockPath := "/tmp/fsutil-test-lock"
|
||||
f, err := TryLock(lockPath)
|
||||
if err != nil {
|
||||
t.Fatalf("%#v", err)
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
f2, err := TryLock(lockPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if f2 != nil {
|
||||
t.Fatal(f2)
|
||||
}
|
||||
|
||||
if err := Unlock(f); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
3
flock/go.mod
Normal file
3
flock/go.mod
Normal file
@ -0,0 +1,3 @@
|
||||
module git.crumpington.com/lib/flock
|
||||
|
||||
go 1.23.0
|
0
flock/go.sum
Normal file
0
flock/go.sum
Normal file
2
httpconn/README.md
Normal file
2
httpconn/README.md
Normal file
@ -0,0 +1,2 @@
|
||||
# httpconn
|
||||
|
104
httpconn/client.go
Normal file
104
httpconn/client.go
Normal file
@ -0,0 +1,104 @@
|
||||
package httpconn
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrUnknownScheme = errors.New("uknown scheme")
|
||||
)
|
||||
|
||||
type Dialer struct {
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
func NewDialer() *Dialer {
|
||||
return &Dialer{timeout: 10 * time.Second}
|
||||
}
|
||||
|
||||
func (d *Dialer) SetTimeout(timeout time.Duration) {
|
||||
d.timeout = timeout
|
||||
}
|
||||
|
||||
func (d *Dialer) Dial(rawURL string) (net.Conn, error) {
|
||||
u, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch u.Scheme {
|
||||
case "https":
|
||||
return d.DialHTTPS(u.Host+":443", u.Path)
|
||||
case "http":
|
||||
return d.DialHTTP(u.Host, u.Path)
|
||||
default:
|
||||
return nil, ErrUnknownScheme
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Dialer) DialHTTPS(host, path string) (net.Conn, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), d.timeout)
|
||||
dd := tls.Dialer{}
|
||||
conn, err := dd.DialContext(ctx, "tcp", host)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.finishDialing(conn, host, path)
|
||||
|
||||
}
|
||||
|
||||
func (d *Dialer) DialHTTP(host, path string) (net.Conn, error) {
|
||||
conn, err := net.DialTimeout("tcp", host, d.timeout)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d.finishDialing(conn, host, path)
|
||||
}
|
||||
|
||||
func (d *Dialer) finishDialing(conn net.Conn, host, path string) (net.Conn, error) {
|
||||
conn.SetDeadline(time.Now().Add(d.timeout))
|
||||
|
||||
if _, err := io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := io.WriteString(conn, "Host: "+host+"\n\n"); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Require successful HTTP response before using the conn.
|
||||
resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"})
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Status != "200 OK" {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
conn.SetDeadline(time.Time{})
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func Dial(rawURL string) (net.Conn, error) {
|
||||
return NewDialer().Dial(rawURL)
|
||||
}
|
||||
|
||||
func DialHTTPS(host, path string) (net.Conn, error) {
|
||||
return NewDialer().DialHTTPS(host, path)
|
||||
}
|
||||
|
||||
func DialHTTP(host, path string) (net.Conn, error) {
|
||||
return NewDialer().DialHTTP(host, path)
|
||||
}
|
42
httpconn/conn_test.go
Normal file
42
httpconn/conn_test.go
Normal file
@ -0,0 +1,42 @@
|
||||
package httpconn
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/nettest"
|
||||
)
|
||||
|
||||
func TestNetTest_TestConn(t *testing.T) {
|
||||
nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) {
|
||||
|
||||
connCh := make(chan net.Conn, 1)
|
||||
doneCh := make(chan bool)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("/connect", func(w http.ResponseWriter, r *http.Request) {
|
||||
conn, err := Accept(w, r)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
connCh <- conn
|
||||
<-doneCh
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(mux)
|
||||
|
||||
c1, err = DialHTTP(strings.TrimPrefix(srv.URL, "http://"), "/connect")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
c2 = <-connCh
|
||||
|
||||
return c1, c2, func() {
|
||||
doneCh <- true
|
||||
srv.Close()
|
||||
}, nil
|
||||
})
|
||||
}
|
5
httpconn/go.mod
Normal file
5
httpconn/go.mod
Normal file
@ -0,0 +1,5 @@
|
||||
module git.crumpington.com/lib/httpconn
|
||||
|
||||
go 1.23.2
|
||||
|
||||
require golang.org/x/net v0.30.0
|
2
httpconn/go.sum
Normal file
2
httpconn/go.sum
Normal file
@ -0,0 +1,2 @@
|
||||
golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4=
|
||||
golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU=
|
32
httpconn/server.go
Normal file
32
httpconn/server.go
Normal file
@ -0,0 +1,32 @@
|
||||
package httpconn
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Accept(w http.ResponseWriter, r *http.Request) (net.Conn, error) {
|
||||
if r.Method != "CONNECT" {
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.WriteHeader(http.StatusMethodNotAllowed)
|
||||
io.WriteString(w, "405 must CONNECT\n")
|
||||
return nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
hj, ok := w.(http.Hijacker)
|
||||
if !ok {
|
||||
return nil, http.ErrNotSupported
|
||||
}
|
||||
|
||||
conn, _, err := hj.Hijack()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, _ = io.WriteString(conn, "HTTP/1.0 200 OK\n\n")
|
||||
conn.SetDeadline(time.Time{})
|
||||
|
||||
return conn, nil
|
||||
}
|
2
idgen/README.md
Normal file
2
idgen/README.md
Normal file
@ -0,0 +1,2 @@
|
||||
# idgen
|
||||
|
3
idgen/go.mod
Normal file
3
idgen/go.mod
Normal file
@ -0,0 +1,3 @@
|
||||
module git.crumpington.com/lib/idgen
|
||||
|
||||
go 1.23.2
|
53
idgen/idgen.go
Normal file
53
idgen/idgen.go
Normal file
@ -0,0 +1,53 @@
|
||||
package idgen
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base32"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Creates a new, random token.
|
||||
func NewToken() string {
|
||||
buf := make([]byte, 20)
|
||||
_, err := rand.Read(buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return base32.StdEncoding.EncodeToString(buf)
|
||||
}
|
||||
|
||||
var (
|
||||
lock sync.Mutex
|
||||
ts int64 = time.Now().Unix()
|
||||
counter int64 = 1
|
||||
counterMax int64 = 1 << 20
|
||||
)
|
||||
|
||||
// NextID can generate ~1M ints per second for a given nodeID.
|
||||
//
|
||||
// nodeID must be less than 64.
|
||||
func NextID(nodeID int64) int64 {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
|
||||
tt := time.Now().Unix()
|
||||
if tt > ts {
|
||||
ts = tt
|
||||
counter = 1
|
||||
} else {
|
||||
counter++
|
||||
if counter == counterMax {
|
||||
panic("Too many IDs.")
|
||||
}
|
||||
}
|
||||
|
||||
return (ts << 26) + (nodeID << 20) + counter
|
||||
}
|
||||
|
||||
func SplitID(id int64) (unixTime, nodeID, counter int64) {
|
||||
counter = id & (0x00000000000FFFFF)
|
||||
nodeID = (id >> 20) & (0x000000000000003F)
|
||||
unixTime = id >> 26
|
||||
return
|
||||
}
|
19
idgen/idgen_test.go
Normal file
19
idgen/idgen_test.go
Normal file
@ -0,0 +1,19 @@
|
||||
package idgen
|
||||
|
||||
import (
|
||||
"log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func BenchmarkNext(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NextID(0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNextID(t *testing.T) {
|
||||
id := NextID(32)
|
||||
|
||||
a, b, c := SplitID(id)
|
||||
log.Print(a, b, c)
|
||||
}
|
2
keyedmutex/README.md
Normal file
2
keyedmutex/README.md
Normal file
@ -0,0 +1,2 @@
|
||||
# keyedmutex
|
||||
|
3
keyedmutex/go.mod
Normal file
3
keyedmutex/go.mod
Normal file
@ -0,0 +1,3 @@
|
||||
module git.crumpington.com/lib/keyedmutex
|
||||
|
||||
go 1.23.2
|
67
keyedmutex/keyedmutex.go
Normal file
67
keyedmutex/keyedmutex.go
Normal file
@ -0,0 +1,67 @@
|
||||
package keyedmutex
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type KeyedMutex[K comparable] struct {
|
||||
mu *sync.Mutex
|
||||
waitList map[K]*list.List
|
||||
}
|
||||
|
||||
func New[K comparable]() KeyedMutex[K] {
|
||||
return KeyedMutex[K]{
|
||||
mu: new(sync.Mutex),
|
||||
waitList: map[K]*list.List{},
|
||||
}
|
||||
}
|
||||
|
||||
func (m KeyedMutex[K]) Lock(key K) {
|
||||
if ch := m.lock(key); ch != nil {
|
||||
<-ch
|
||||
}
|
||||
}
|
||||
|
||||
func (m KeyedMutex[K]) lock(key K) chan struct{} {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if waitList, ok := m.waitList[key]; ok {
|
||||
ch := make(chan struct{})
|
||||
waitList.PushBack(ch)
|
||||
return ch
|
||||
}
|
||||
|
||||
m.waitList[key] = list.New()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m KeyedMutex[K]) TryLock(key K) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if _, ok := m.waitList[key]; ok {
|
||||
return false
|
||||
}
|
||||
|
||||
m.waitList[key] = list.New()
|
||||
return true
|
||||
}
|
||||
|
||||
func (m KeyedMutex[K]) Unlock(key K) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
waitList, ok := m.waitList[key]
|
||||
if !ok {
|
||||
panic("unlock of unlocked mutex")
|
||||
}
|
||||
|
||||
if waitList.Len() == 0 {
|
||||
delete(m.waitList, key)
|
||||
} else {
|
||||
ch := waitList.Remove(waitList.Front()).(chan struct{})
|
||||
ch <- struct{}{}
|
||||
}
|
||||
}
|
123
keyedmutex/keyedmutex_test.go
Normal file
123
keyedmutex/keyedmutex_test.go
Normal file
@ -0,0 +1,123 @@
|
||||
package keyedmutex
|
||||
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestKeyedMutex(t *testing.T) {
|
||||
checkState := func(t *testing.T, m KeyedMutex[string], keys ...string) {
|
||||
if len(m.waitList) != len(keys) {
|
||||
t.Fatal(m.waitList, keys)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
if _, ok := m.waitList[key]; !ok {
|
||||
t.Fatal(key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
m := New[string]()
|
||||
checkState(t, m)
|
||||
|
||||
m.Lock("a")
|
||||
checkState(t, m, "a")
|
||||
m.Lock("b")
|
||||
checkState(t, m, "a", "b")
|
||||
m.Lock("c")
|
||||
checkState(t, m, "a", "b", "c")
|
||||
|
||||
if m.TryLock("a") {
|
||||
t.Fatal("a")
|
||||
}
|
||||
if m.TryLock("b") {
|
||||
t.Fatal("b")
|
||||
}
|
||||
if m.TryLock("c") {
|
||||
t.Fatal("c")
|
||||
}
|
||||
|
||||
if !m.TryLock("d") {
|
||||
t.Fatal("d")
|
||||
}
|
||||
|
||||
checkState(t, m, "a", "b", "c", "d")
|
||||
|
||||
if !m.TryLock("e") {
|
||||
t.Fatal("e")
|
||||
}
|
||||
checkState(t, m, "a", "b", "c", "d", "e")
|
||||
|
||||
m.Unlock("c")
|
||||
checkState(t, m, "a", "b", "d", "e")
|
||||
m.Unlock("a")
|
||||
checkState(t, m, "b", "d", "e")
|
||||
m.Unlock("e")
|
||||
checkState(t, m, "b", "d")
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 8; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
m.Lock("b")
|
||||
m.Unlock("b")
|
||||
}()
|
||||
}
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
m.Unlock("b")
|
||||
wg.Wait()
|
||||
|
||||
checkState(t, m, "d")
|
||||
|
||||
m.Unlock("d")
|
||||
checkState(t, m)
|
||||
}
|
||||
|
||||
func TestKeyedMutex_unlockUnlocked(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r == nil {
|
||||
t.Fatal(r)
|
||||
}
|
||||
}()
|
||||
|
||||
m := New[string]()
|
||||
m.Unlock("aldkfj")
|
||||
}
|
||||
|
||||
func BenchmarkUncontendedMutex(b *testing.B) {
|
||||
m := New[string]()
|
||||
key := "xyz"
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
m.Lock(key)
|
||||
m.Unlock(key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkContendedMutex(b *testing.B) {
|
||||
m := New[string]()
|
||||
key := "xyz"
|
||||
|
||||
m.Lock(key)
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < b.N; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
m.Lock(key)
|
||||
m.Unlock(key)
|
||||
}()
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
|
||||
b.ResetTimer()
|
||||
m.Unlock(key)
|
||||
wg.Wait()
|
||||
}
|
2
kvmemcache/README.md
Normal file
2
kvmemcache/README.md
Normal file
@ -0,0 +1,2 @@
|
||||
# kvmemcache
|
||||
|
134
kvmemcache/cache.go
Normal file
134
kvmemcache/cache.go
Normal file
@ -0,0 +1,134 @@
|
||||
package kvmemcache
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.crumpington.com/lib/keyedmutex"
|
||||
)
|
||||
|
||||
type Cache[K comparable, V any] struct {
|
||||
updateLock keyedmutex.KeyedMutex[K]
|
||||
src func(K) (V, error)
|
||||
ttl time.Duration
|
||||
maxSize int
|
||||
|
||||
// Lock protects variables below.
|
||||
lock sync.Mutex
|
||||
cache map[K]*list.Element
|
||||
ll *list.List
|
||||
stats Stats
|
||||
}
|
||||
|
||||
type lruItem[K comparable, V any] struct {
|
||||
key K
|
||||
createdAt time.Time
|
||||
value V
|
||||
err error
|
||||
}
|
||||
|
||||
type Config[K comparable, V any] struct {
|
||||
MaxSize int
|
||||
TTL time.Duration // Zero to ignore.
|
||||
Src func(K) (V, error)
|
||||
}
|
||||
|
||||
func New[K comparable, V any](conf Config[K, V]) *Cache[K, V] {
|
||||
return &Cache[K, V]{
|
||||
updateLock: keyedmutex.New[K](),
|
||||
src: conf.Src,
|
||||
ttl: conf.TTL,
|
||||
maxSize: conf.MaxSize,
|
||||
lock: sync.Mutex{},
|
||||
cache: make(map[K]*list.Element, conf.MaxSize+1),
|
||||
ll: list.New(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache[K, V]) Get(key K) (V, error) {
|
||||
ok, val, err := c.get(key)
|
||||
if ok {
|
||||
return val, err
|
||||
}
|
||||
|
||||
return c.load(key)
|
||||
}
|
||||
|
||||
func (c *Cache[K, V]) Evict(key K) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
c.evict(key)
|
||||
}
|
||||
|
||||
func (c *Cache[K, V]) Stats() Stats {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
return c.stats
|
||||
}
|
||||
|
||||
func (c *Cache[K, V]) put(key K, value V, err error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
c.stats.Misses++
|
||||
|
||||
c.cache[key] = c.ll.PushFront(lruItem[K, V]{
|
||||
key: key,
|
||||
createdAt: time.Now(),
|
||||
value: value,
|
||||
err: err,
|
||||
})
|
||||
|
||||
if c.maxSize != 0 && len(c.cache) > c.maxSize {
|
||||
li := c.ll.Back()
|
||||
c.ll.Remove(li)
|
||||
delete(c.cache, li.Value.(lruItem[K, V]).key)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache[K, V]) evict(key K) {
|
||||
elem := c.cache[key]
|
||||
if elem != nil {
|
||||
delete(c.cache, key)
|
||||
c.ll.Remove(elem)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Cache[K, V]) get(key K) (ok bool, val V, err error) {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
li := c.cache[key]
|
||||
if li == nil {
|
||||
return false, val, nil
|
||||
}
|
||||
|
||||
item := li.Value.(lruItem[K, V])
|
||||
// Maybe evict.
|
||||
if c.ttl != 0 && time.Since(item.createdAt) > c.ttl {
|
||||
c.evict(key)
|
||||
return false, val, nil
|
||||
}
|
||||
|
||||
c.stats.Hits++
|
||||
|
||||
c.ll.MoveToFront(li)
|
||||
return true, item.value, item.err
|
||||
}
|
||||
|
||||
func (c *Cache[K, V]) load(key K) (V, error) {
|
||||
c.updateLock.Lock(key)
|
||||
defer c.updateLock.Unlock(key)
|
||||
|
||||
// Check again in case we lost the update race.
|
||||
ok, val, err := c.get(key)
|
||||
if ok {
|
||||
return val, err
|
||||
}
|
||||
|
||||
// Won the update race.
|
||||
val, err = c.src(key)
|
||||
c.put(key, val, err)
|
||||
return val, err
|
||||
}
|
249
kvmemcache/cache_test.go
Normal file
249
kvmemcache/cache_test.go
Normal file
@ -0,0 +1,249 @@
|
||||
package kvmemcache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type State[K comparable] struct {
|
||||
Keys []K
|
||||
Stats Stats
|
||||
}
|
||||
|
||||
func (c *Cache[K,V]) assert(state State[K]) error {
|
||||
c.lock.Lock()
|
||||
defer c.lock.Unlock()
|
||||
|
||||
if len(c.cache) != len(state.Keys) {
|
||||
return fmt.Errorf(
|
||||
"Expected %d keys but found %d.",
|
||||
len(state.Keys),
|
||||
len(c.cache))
|
||||
}
|
||||
|
||||
for _, k := range state.Keys {
|
||||
if _, ok := c.cache[k]; !ok {
|
||||
return fmt.Errorf(
|
||||
"Expected key %v not found.",
|
||||
k)
|
||||
}
|
||||
}
|
||||
|
||||
if c.stats.Hits != state.Stats.Hits {
|
||||
return fmt.Errorf(
|
||||
"Expected %d hits, but found %d.",
|
||||
state.Stats.Hits,
|
||||
c.stats.Hits)
|
||||
}
|
||||
|
||||
if c.stats.Misses != state.Stats.Misses {
|
||||
return fmt.Errorf(
|
||||
"Expected %d misses, but found %d.",
|
||||
state.Stats.Misses,
|
||||
c.stats.Misses)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var ErrTest = errors.New("Hello")
|
||||
|
||||
func TestCache_basic(t *testing.T) {
|
||||
c := New(Config[string, string]{
|
||||
MaxSize: 4,
|
||||
TTL: 50 * time.Millisecond,
|
||||
Src: func(key string) (string, error) {
|
||||
if key == "err" {
|
||||
return "", ErrTest
|
||||
}
|
||||
return key, nil
|
||||
},
|
||||
})
|
||||
|
||||
type testCase struct {
|
||||
name string
|
||||
sleep time.Duration
|
||||
key string
|
||||
evict bool
|
||||
state State[string]
|
||||
}
|
||||
|
||||
cases := []testCase{
|
||||
{
|
||||
name: "get a",
|
||||
key: "a",
|
||||
state: State[string]{
|
||||
Keys: []string{"a"},
|
||||
Stats: Stats{Hits: 0, Misses: 1},
|
||||
},
|
||||
}, {
|
||||
name: "get a again",
|
||||
key: "a",
|
||||
state: State[string]{
|
||||
Keys: []string{"a"},
|
||||
Stats: Stats{Hits: 1, Misses: 1},
|
||||
},
|
||||
}, {
|
||||
name: "sleep, then get a again",
|
||||
sleep: 55 * time.Millisecond,
|
||||
key: "a",
|
||||
state: State[string]{
|
||||
Keys: []string{"a"},
|
||||
Stats: Stats{Hits: 1, Misses: 2},
|
||||
},
|
||||
}, {
|
||||
name: "get b",
|
||||
key: "b",
|
||||
state: State[string]{
|
||||
Keys: []string{"a", "b"},
|
||||
Stats: Stats{Hits: 1, Misses: 3},
|
||||
},
|
||||
}, {
|
||||
name: "get c",
|
||||
key: "c",
|
||||
state: State[string]{
|
||||
Keys: []string{"a", "b", "c"},
|
||||
Stats: Stats{Hits: 1, Misses: 4},
|
||||
},
|
||||
}, {
|
||||
name: "get d",
|
||||
key: "d",
|
||||
state: State[string]{
|
||||
Keys: []string{"a", "b", "c", "d"},
|
||||
Stats: Stats{Hits: 1, Misses: 5},
|
||||
},
|
||||
}, {
|
||||
name: "get e",
|
||||
key: "e",
|
||||
state: State[string]{
|
||||
Keys: []string{"b", "c", "d", "e"},
|
||||
Stats: Stats{Hits: 1, Misses: 6},
|
||||
},
|
||||
}, {
|
||||
name: "get c again",
|
||||
key: "c",
|
||||
state: State[string]{
|
||||
Keys: []string{"b", "c", "d", "e"},
|
||||
Stats: Stats{Hits: 2, Misses: 6},
|
||||
},
|
||||
}, {
|
||||
name: "get err",
|
||||
key: "err",
|
||||
state: State[string]{
|
||||
Keys: []string{"c", "d", "e", "err"},
|
||||
Stats: Stats{Hits: 2, Misses: 7},
|
||||
},
|
||||
}, {
|
||||
name: "get err again",
|
||||
key: "err",
|
||||
state: State[string]{
|
||||
Keys: []string{"c", "d", "e", "err"},
|
||||
Stats: Stats{Hits: 3, Misses: 7},
|
||||
},
|
||||
}, {
|
||||
name: "evict c",
|
||||
key: "c",
|
||||
evict: true,
|
||||
state: State[string]{
|
||||
Keys: []string{"d", "e", "err"},
|
||||
Stats: Stats{Hits: 3, Misses: 7},
|
||||
},
|
||||
}, {
|
||||
name: "reload-all a",
|
||||
key: "a",
|
||||
state: State[string]{
|
||||
Keys: []string{"a", "d", "e", "err"},
|
||||
Stats: Stats{Hits: 3, Misses: 8},
|
||||
},
|
||||
}, {
|
||||
name: "reload-all b",
|
||||
key: "b",
|
||||
state: State[string]{
|
||||
Keys: []string{"a", "b", "e", "err"},
|
||||
Stats: Stats{Hits: 3, Misses: 9},
|
||||
},
|
||||
}, {
|
||||
name: "reload-all c",
|
||||
key: "c",
|
||||
state: State[string]{
|
||||
Keys: []string{"a", "b", "c", "err"},
|
||||
Stats: Stats{Hits: 3, Misses: 10},
|
||||
},
|
||||
}, {
|
||||
name: "reload-all d",
|
||||
key: "d",
|
||||
state: State[string]{
|
||||
Keys: []string{"a", "b", "c", "d"},
|
||||
Stats: Stats{Hits: 3, Misses: 11},
|
||||
},
|
||||
}, {
|
||||
name: "read a again",
|
||||
key: "a",
|
||||
state: State[string]{
|
||||
Keys: []string{"b", "c", "d", "a"},
|
||||
Stats: Stats{Hits: 4, Misses: 11},
|
||||
},
|
||||
}, {
|
||||
name: "read e, evicting b",
|
||||
key: "e",
|
||||
state: State[string]{
|
||||
Keys: []string{"c", "d", "a", "e"},
|
||||
Stats: Stats{Hits: 4, Misses: 12},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
time.Sleep(tc.sleep)
|
||||
if !tc.evict {
|
||||
val, err := c.Get(tc.key)
|
||||
if tc.key == "err" && err != ErrTest {
|
||||
t.Fatal(tc.name, val)
|
||||
}
|
||||
if tc.key != "err" && val != tc.key {
|
||||
t.Fatal(tc.name, tc.key, val)
|
||||
}
|
||||
} else {
|
||||
c.Evict(tc.key)
|
||||
}
|
||||
|
||||
if err := c.assert(tc.state); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCache_thunderingHerd(t *testing.T) {
|
||||
c := New(Config[string,string]{
|
||||
MaxSize: 4,
|
||||
Src: func(key string) (string, error) {
|
||||
time.Sleep(time.Second)
|
||||
return key, nil
|
||||
},
|
||||
})
|
||||
|
||||
wg := sync.WaitGroup{}
|
||||
for i := 0; i < 16384; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
val, err := c.Get("a")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if val != "a" {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
stats := c.Stats()
|
||||
if stats.Hits != 16383 || stats.Misses != 1 {
|
||||
t.Fatal(stats)
|
||||
}
|
||||
}
|
5
kvmemcache/go.mod
Normal file
5
kvmemcache/go.mod
Normal file
@ -0,0 +1,5 @@
|
||||
module git.crumpington.com/lib/kvmemcache
|
||||
|
||||
go 1.23.2
|
||||
|
||||
require git.crumpington.com/lib/keyedmutex v1.0.1
|
2
kvmemcache/go.sum
Normal file
2
kvmemcache/go.sum
Normal file
@ -0,0 +1,2 @@
|
||||
git.crumpington.com/lib/keyedmutex v1.0.1 h1:5ylwGXQzL9ojZIhlqkut6dpa4yt6Wz6bOWbf/tQBAMQ=
|
||||
git.crumpington.com/lib/keyedmutex v1.0.1/go.mod h1:VxxJRU/XvvF61IuJZG7kUIv954Q8+Rh8bnVpEzGYrQ4=
|
6
kvmemcache/stats.go
Normal file
6
kvmemcache/stats.go
Normal file
@ -0,0 +1,6 @@
|
||||
package kvmemcache
|
||||
|
||||
type Stats struct {
|
||||
Hits uint64
|
||||
Misses uint64
|
||||
}
|
2
mmap/README.md
Normal file
2
mmap/README.md
Normal file
@ -0,0 +1,2 @@
|
||||
# mmap
|
||||
|
100
mmap/file.go
Normal file
100
mmap/file.go
Normal file
@ -0,0 +1,100 @@
|
||||
package mmap
|
||||
|
||||
import "os"
|
||||
|
||||
type File struct {
|
||||
f *os.File
|
||||
Map []byte
|
||||
}
|
||||
|
||||
func Create(path string, size int64) (*File, error) {
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := f.Truncate(size); err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
m, err := Map(f, PROT_READ|PROT_WRITE)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &File{f, m}, nil
|
||||
}
|
||||
|
||||
// Opens a mapped file in read-only mode.
|
||||
func Open(path string) (*File, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
m, err := Map(f, PROT_READ)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
return &File{f, m}, nil
|
||||
}
|
||||
|
||||
func OpenFile(
|
||||
path string,
|
||||
fileFlags int,
|
||||
perm os.FileMode,
|
||||
size int64, // -1 for file size.
|
||||
) (*File, error) {
|
||||
f, err := os.OpenFile(path, fileFlags, perm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
writable := fileFlags|os.O_RDWR != 0 || fileFlags|os.O_WRONLY != 0
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if writable && size > 0 && size != fi.Size() {
|
||||
if err := f.Truncate(size); err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
mapFlags := PROT_READ
|
||||
if writable {
|
||||
mapFlags |= PROT_WRITE
|
||||
}
|
||||
|
||||
m, err := Map(f, mapFlags)
|
||||
if err != nil {
|
||||
f.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &File{f, m}, nil
|
||||
}
|
||||
|
||||
func (f *File) Sync() error {
|
||||
return Sync(f.Map)
|
||||
}
|
||||
|
||||
func (f *File) Close() error {
|
||||
if f.Map != nil {
|
||||
if err := Unmap(f.Map); err != nil {
|
||||
return err
|
||||
}
|
||||
f.Map = nil
|
||||
}
|
||||
|
||||
if f.f != nil {
|
||||
if err := f.f.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
f.f = nil
|
||||
}
|
||||
return nil
|
||||
}
|
3
mmap/go.mod
Normal file
3
mmap/go.mod
Normal file
@ -0,0 +1,3 @@
|
||||
module git.crumpington.com/lib/mmap
|
||||
|
||||
go 1.23.2
|
0
mmap/go.sum
Normal file
0
mmap/go.sum
Normal file
59
mmap/mmap.go
Normal file
59
mmap/mmap.go
Normal file
@ -0,0 +1,59 @@
|
||||
package mmap
|
||||
|
||||
import (
|
||||
"os"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
const (
|
||||
PROT_READ = syscall.PROT_READ
|
||||
PROT_WRITE = syscall.PROT_WRITE
|
||||
)
|
||||
|
||||
// Mmap creates a memory map of the given file. The flags argument should be a
|
||||
// combination of PROT_READ and PROT_WRITE. The size of the map will be the
|
||||
// file's size.
|
||||
func Map(f *os.File, flags int) ([]byte, error) {
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
size := fi.Size()
|
||||
|
||||
addr, _, errno := syscall.Syscall6(
|
||||
syscall.SYS_MMAP,
|
||||
0, // addr: 0 => allow kernel to choose
|
||||
uintptr(size),
|
||||
uintptr(flags),
|
||||
uintptr(syscall.MAP_SHARED),
|
||||
f.Fd(),
|
||||
0) // offset: 0 => start of file
|
||||
if errno != 0 {
|
||||
return nil, syscall.Errno(errno)
|
||||
}
|
||||
|
||||
return unsafe.Slice((*byte)(unsafe.Pointer(addr)), size), nil
|
||||
}
|
||||
|
||||
// Munmap unmaps the data obtained by Map.
|
||||
func Unmap(data []byte) error {
|
||||
_, _, errno := syscall.Syscall(
|
||||
syscall.SYS_MUNMAP,
|
||||
uintptr(unsafe.Pointer(&data[:1][0])),
|
||||
uintptr(cap(data)),
|
||||
0)
|
||||
if errno != 0 {
|
||||
return syscall.Errno(errno)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func Sync(b []byte) (err error) {
|
||||
_p0 := unsafe.Pointer(&b[0])
|
||||
_, _, errno := syscall.Syscall(syscall.SYS_MSYNC, uintptr(_p0), uintptr(len(b)), uintptr(syscall.MS_SYNC))
|
||||
if errno != 0 {
|
||||
err = syscall.Errno(errno)
|
||||
}
|
||||
return
|
||||
}
|
45
pgutil/README.md
Normal file
45
pgutil/README.md
Normal file
@ -0,0 +1,45 @@
|
||||
# pgutil
|
||||
|
||||
## Transactions
|
||||
|
||||
Simplify postgres transactions using `WithTx` for serializable transactions,
|
||||
or `WithTxDefault` for the default isolation level. Use the `SerialTxRunner`
|
||||
type to get automatic retries of serialization errors.
|
||||
|
||||
## Migrations
|
||||
|
||||
Put your migrations into a directory, for example `migrations`, ordered by name
|
||||
(YYYY-MM-DD prefix, for example). Embed the directory and pass it to the
|
||||
`Migrate` function:
|
||||
|
||||
```Go
|
||||
//go:embed migrations
|
||||
var migrations embed.FS
|
||||
|
||||
func init() {
|
||||
Migrate(db, migrations) // Check the error, of course.
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
In order to test this packge, we need to create a test user and database:
|
||||
|
||||
```
|
||||
sudo su postgres
|
||||
psql
|
||||
|
||||
CREATE DATABASE test;
|
||||
CREATE USER test WITH ENCRYPTED PASSWORD 'test';
|
||||
GRANT ALL PRIVILEGES ON DATABASE test TO test;
|
||||
|
||||
use test
|
||||
|
||||
GRANT ALL ON SCHEMA public TO test;
|
||||
```
|
||||
|
||||
Check that you can connect via the command line:
|
||||
|
||||
```
|
||||
psql -h 127.0.0.1 -U test --password test
|
||||
```
|
42
pgutil/dropall.go
Normal file
42
pgutil/dropall.go
Normal file
@ -0,0 +1,42 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"log"
|
||||
)
|
||||
|
||||
const dropTablesQueryQuery = `
|
||||
SELECT 'DROP TABLE IF EXISTS "' || tablename || '" CASCADE;'
|
||||
FROM
|
||||
pg_tables
|
||||
WHERE
|
||||
schemaname='public'`
|
||||
|
||||
// Deletes all tables in the database. Useful for testing.
|
||||
func DropAllTables(db *sql.DB) error {
|
||||
rows, err := db.Query(dropTablesQueryQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
queries := []string{}
|
||||
for rows.Next() {
|
||||
var s string
|
||||
if err := rows.Scan(&s); err != nil {
|
||||
return err
|
||||
}
|
||||
queries = append(queries, s)
|
||||
}
|
||||
|
||||
if len(queries) > 0 {
|
||||
log.Printf("DROPPING ALL (%d) TABLES", len(queries))
|
||||
}
|
||||
|
||||
for _, query := range queries {
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
31
pgutil/errors.go
Normal file
31
pgutil/errors.go
Normal file
@ -0,0 +1,31 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
func ErrIsDuplicateKey(err error) bool {
|
||||
return ErrHasCode(err, "23505")
|
||||
}
|
||||
|
||||
func ErrIsForeignKey(err error) bool {
|
||||
return ErrHasCode(err, "23503")
|
||||
}
|
||||
|
||||
func ErrIsSerializationFaiilure(err error) bool {
|
||||
return ErrHasCode(err, "40001")
|
||||
}
|
||||
|
||||
func ErrHasCode(err error, code string) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var pErr *pq.Error
|
||||
if errors.As(err, &pErr) {
|
||||
return pErr.Code == pq.ErrorCode(code)
|
||||
}
|
||||
return false
|
||||
}
|
36
pgutil/errors_test.go
Normal file
36
pgutil/errors_test.go
Normal file
@ -0,0 +1,36 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrors(t *testing.T) {
|
||||
db, err := sql.Open(
|
||||
"postgres",
|
||||
"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := DropAllTables(db); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := Migrate(db, testMigrationFS); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (2, 'q@r.com')`)
|
||||
if !ErrIsDuplicateKey(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (3, 'c@d.com')`)
|
||||
if !ErrIsDuplicateKey(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = db.Exec(`INSERT INTO user_notes(UserID, NoteID, Note) VALUES (4, 1, 'hello')`)
|
||||
if !ErrIsForeignKey(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
5
pgutil/go.mod
Normal file
5
pgutil/go.mod
Normal file
@ -0,0 +1,5 @@
|
||||
module git.crumpington.com/git/pgutil
|
||||
|
||||
go 1.23.2
|
||||
|
||||
require github.com/lib/pq v1.10.9
|
2
pgutil/go.sum
Normal file
2
pgutil/go.sum
Normal file
@ -0,0 +1,2 @@
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
82
pgutil/migrate.go
Normal file
82
pgutil/migrate.go
Normal file
@ -0,0 +1,82 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
)
|
||||
|
||||
const initMigrationTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);`
|
||||
|
||||
const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)`
|
||||
|
||||
const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)`
|
||||
|
||||
func Migrate(db *sql.DB, migrationFS embed.FS) error {
|
||||
return WithTx(db, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(initMigrationTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dirs, err := migrationFS.ReadDir(".")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(dirs) != 1 {
|
||||
return errors.New("expected a single migrations directory")
|
||||
}
|
||||
|
||||
if !dirs[0].IsDir() {
|
||||
return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name())
|
||||
}
|
||||
|
||||
dirName := dirs[0].Name()
|
||||
files, err := migrationFS.ReadDir(dirName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Sort sql files by name.
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].Name() < files[j].Name()
|
||||
})
|
||||
|
||||
for _, dirEnt := range files {
|
||||
if !dirEnt.Type().IsRegular() {
|
||||
return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name())
|
||||
}
|
||||
|
||||
var (
|
||||
name = dirEnt.Name()
|
||||
exists bool
|
||||
)
|
||||
|
||||
err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
|
||||
migration, err := migrationFS.ReadFile(filepath.Join(dirName, name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(string(migration)); err != nil {
|
||||
return fmt.Errorf("migration %s failed: %v", name, err)
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(insertMigrationQuery, name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
50
pgutil/migrate_test.go
Normal file
50
pgutil/migrate_test.go
Normal file
@ -0,0 +1,50 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"testing"
|
||||
)
|
||||
|
||||
//go:embed test-migrations
|
||||
var testMigrationFS embed.FS
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
db, err := sql.Open(
|
||||
"postgres",
|
||||
"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := DropAllTables(db); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := Migrate(db, testMigrationFS); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Shouldn't have any effect.
|
||||
if err := Migrate(db, testMigrationFS); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)`
|
||||
var exists bool
|
||||
|
||||
if err = db.QueryRow(query, 1).Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if exists {
|
||||
t.Fatal("1 shouldn't exist")
|
||||
}
|
||||
|
||||
if err = db.QueryRow(query, 2).Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatal("2 should exist")
|
||||
}
|
||||
|
||||
}
|
9
pgutil/test-migrations/000.sql
Normal file
9
pgutil/test-migrations/000.sql
Normal file
@ -0,0 +1,9 @@
|
||||
CREATE TABLE users(
|
||||
UserID BIGINT NOT NULL PRIMARY KEY,
|
||||
Email TEXT NOT NULL UNIQUE);
|
||||
|
||||
CREATE TABLE user_notes(
|
||||
UserID BIGINT NOT NULL REFERENCES users(UserID),
|
||||
NoteID BIGINT NOT NULL,
|
||||
Note Text NOT NULL,
|
||||
PRIMARY KEY(UserID,NoteID));
|
1
pgutil/test-migrations/001.sql
Normal file
1
pgutil/test-migrations/001.sql
Normal file
@ -0,0 +1 @@
|
||||
INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com');
|
1
pgutil/test-migrations/002.sql
Normal file
1
pgutil/test-migrations/002.sql
Normal file
@ -0,0 +1 @@
|
||||
DELETE FROM users WHERE UserID=1;
|
70
pgutil/tx.go
Normal file
70
pgutil/tx.go
Normal file
@ -0,0 +1,70 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Postgres doesn't use serializable transactions by default. This wrapper will
|
||||
// run the enclosed function within a serializable. Note: this may return an
|
||||
// retriable serialization error (see ErrIsSerializationFaiilure).
|
||||
func WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
|
||||
return WithTxDefault(db, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE"); err != nil {
|
||||
return err
|
||||
}
|
||||
return fn(tx)
|
||||
})
|
||||
}
|
||||
|
||||
// This is a convenience function to provide a transaction wrapper with the
|
||||
// default isolation level.
|
||||
func WithTxDefault(db *sql.DB, fn func(*sql.Tx) error) error {
|
||||
// Start a transaction.
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = fn(tx)
|
||||
|
||||
if err == nil {
|
||||
err = tx.Commit()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SerialTxRunner attempts serializable transactions in a loop. If a
|
||||
// transaction fails due to a serialization error, then the runner will retry
|
||||
// with exponential backoff, until the sleep time reaches MaxTimeout.
|
||||
//
|
||||
// For example, if MinTimeout is 100 ms, and MaxTimeout is 800 ms, it may sleep
|
||||
// for ~100, 200, 400, and 800 ms between retries.
|
||||
//
|
||||
// 10% jitter is added to the sleep time.
|
||||
type SerialTxRunner struct {
|
||||
MinTimeout time.Duration
|
||||
MaxTimeout time.Duration
|
||||
}
|
||||
|
||||
func (r SerialTxRunner) WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
|
||||
timeout := r.MinTimeout
|
||||
for {
|
||||
err := WithTx(db, fn)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if timeout > r.MaxTimeout || !ErrIsSerializationFaiilure(err) {
|
||||
return err
|
||||
}
|
||||
sleepTimeout := timeout + time.Duration(rand.Int63n(int64(timeout/10)))
|
||||
time.Sleep(sleepTimeout)
|
||||
timeout *= 2
|
||||
}
|
||||
}
|
120
pgutil/tx_test.go
Normal file
120
pgutil/tx_test.go
Normal file
@ -0,0 +1,120 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestExecuteTx verifies transaction retry using the classic
|
||||
// example of write skew in bank account balance transfers.
|
||||
func TestWithTx(t *testing.T) {
|
||||
db, err := sql.Open(
|
||||
"postgres",
|
||||
"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := DropAllTables(db); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer db.Close()
|
||||
|
||||
initStmt := `
|
||||
CREATE TABLE t (acct INT PRIMARY KEY, balance INT);
|
||||
INSERT INTO t (acct, balance) VALUES (1, 100), (2, 100);
|
||||
`
|
||||
if _, err := db.Exec(initStmt); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type queryI interface {
|
||||
Query(string, ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
getBalances := func(q queryI) (bal1, bal2 int, err error) {
|
||||
var rows *sql.Rows
|
||||
rows, err = q.Query(`SELECT balance FROM t WHERE acct IN (1, 2);`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
balances := []*int{&bal1, &bal2}
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
if err = rows.Scan(balances[i]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if i != 2 {
|
||||
err = fmt.Errorf("expected two balances; got %d", i)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
txRunner := SerialTxRunner{100 * time.Millisecond, 800 * time.Millisecond}
|
||||
|
||||
runTxn := func(wg *sync.WaitGroup, iter *int) <-chan error {
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
*iter = 0
|
||||
errCh <- txRunner.WithTx(db, func(tx *sql.Tx) error {
|
||||
*iter++
|
||||
bal1, bal2, err := getBalances(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If this is the first iteration, wait for the other tx to
|
||||
// also read.
|
||||
if *iter == 1 {
|
||||
wg.Done()
|
||||
wg.Wait()
|
||||
}
|
||||
// Now, subtract from one account and give to the other.
|
||||
if bal1 > bal2 {
|
||||
if _, err := tx.Exec(`
|
||||
UPDATE t SET balance=balance-100 WHERE acct=1;
|
||||
UPDATE t SET balance=balance+100 WHERE acct=2;
|
||||
`); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := tx.Exec(`
|
||||
UPDATE t SET balance=balance+100 WHERE acct=1;
|
||||
UPDATE t SET balance=balance-100 WHERE acct=2;
|
||||
`); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
return errCh
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
var iters1, iters2 int
|
||||
txn1Err := runTxn(&wg, &iters1)
|
||||
txn2Err := runTxn(&wg, &iters2)
|
||||
if err := <-txn1Err; err != nil {
|
||||
t.Errorf("expected success in txn1; got %s", err)
|
||||
}
|
||||
if err := <-txn2Err; err != nil {
|
||||
t.Errorf("expected success in txn2; got %s", err)
|
||||
}
|
||||
if iters1+iters2 <= 2 {
|
||||
t.Errorf("expected retries between the competing transactions; "+
|
||||
"got txn1=%d, txn2=%d", iters1, iters2)
|
||||
}
|
||||
bal1, bal2, err := getBalances(db)
|
||||
if err != nil || bal1 != 100 || bal2 != 100 {
|
||||
t.Errorf("expected balances to be restored without error; "+
|
||||
"got acct1=%d, acct2=%d: %s", bal1, bal2, err)
|
||||
}
|
||||
}
|
3
ratelimiter/go.mod
Normal file
3
ratelimiter/go.mod
Normal file
@ -0,0 +1,3 @@
|
||||
module git.crumpington.com/lib/ratelimiter
|
||||
|
||||
go 1.23.2
|
86
ratelimiter/ratelimiter.go
Normal file
86
ratelimiter/ratelimiter.go
Normal file
@ -0,0 +1,86 @@
|
||||
package ratelimiter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrBackoff = errors.New("Backoff")
|
||||
|
||||
type Config struct {
|
||||
BurstLimit int64 // Number of requests to allow to burst.
|
||||
FillPeriod time.Duration // Add one per period.
|
||||
MaxWaitCount int64 // Max number of waiting requests. 0 disables.
|
||||
}
|
||||
|
||||
type Limiter struct {
|
||||
lock sync.Mutex
|
||||
|
||||
fillPeriod time.Duration
|
||||
minWaitTime time.Duration
|
||||
maxWaitTime time.Duration
|
||||
|
||||
waitTime time.Duration // If waitTime < 0, no waiting occurs.
|
||||
lastRequest time.Time
|
||||
}
|
||||
|
||||
func New(conf Config) *Limiter {
|
||||
if conf.BurstLimit < 0 {
|
||||
panic(conf.BurstLimit)
|
||||
}
|
||||
if conf.FillPeriod <= 0 {
|
||||
panic(conf.FillPeriod)
|
||||
}
|
||||
if conf.MaxWaitCount < 0 {
|
||||
panic(conf.MaxWaitCount)
|
||||
}
|
||||
|
||||
lim := &Limiter{
|
||||
lastRequest: time.Now(),
|
||||
fillPeriod: conf.FillPeriod,
|
||||
waitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit),
|
||||
minWaitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit),
|
||||
maxWaitTime: conf.FillPeriod * time.Duration(conf.MaxWaitCount),
|
||||
}
|
||||
|
||||
lim.waitTime = lim.minWaitTime
|
||||
|
||||
return lim
|
||||
}
|
||||
|
||||
func (lim *Limiter) limit(count int64) (time.Duration, error) {
|
||||
lim.lock.Lock()
|
||||
defer lim.lock.Unlock()
|
||||
|
||||
dt := time.Since(lim.lastRequest)
|
||||
|
||||
waitTime := lim.waitTime - dt + time.Duration(count)*lim.fillPeriod
|
||||
|
||||
if waitTime < lim.minWaitTime {
|
||||
waitTime = lim.minWaitTime
|
||||
} else if waitTime > lim.maxWaitTime {
|
||||
return 0, ErrBackoff
|
||||
}
|
||||
|
||||
lim.waitTime = waitTime
|
||||
lim.lastRequest = lim.lastRequest.Add(dt)
|
||||
|
||||
return lim.waitTime, nil
|
||||
}
|
||||
|
||||
// Apply the limiter to the calling thread. The function may sleep for up to
|
||||
// maxWaitTime before returning. If the timeout would need to be more than
|
||||
// maxWaitTime to enforce the rate limit, ErrBackoff is returned.
|
||||
func (lim *Limiter) Limit() error {
|
||||
dt, err := lim.limit(1)
|
||||
time.Sleep(dt) // Will return immediately for dt <= 0.
|
||||
return err
|
||||
}
|
||||
|
||||
// Apply the limiter for multiple items at once.
|
||||
func (lim *Limiter) LimitMultiple(count int64) error {
|
||||
dt, err := lim.limit(count)
|
||||
time.Sleep(dt)
|
||||
return err
|
||||
}
|
101
ratelimiter/ratelimiter_test.go
Normal file
101
ratelimiter/ratelimiter_test.go
Normal file
@ -0,0 +1,101 @@
|
||||
package ratelimiter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestRateLimiter_Limit_Errors(t *testing.T) {
|
||||
type TestCase struct {
|
||||
Name string
|
||||
Conf Config
|
||||
N int
|
||||
ErrCount int
|
||||
DT time.Duration
|
||||
}
|
||||
|
||||
cases := []TestCase{
|
||||
{
|
||||
Name: "no burst, no wait",
|
||||
Conf: Config{
|
||||
BurstLimit: 0,
|
||||
FillPeriod: 100 * time.Millisecond,
|
||||
MaxWaitCount: 0,
|
||||
},
|
||||
N: 32,
|
||||
ErrCount: 32,
|
||||
DT: 0,
|
||||
}, {
|
||||
Name: "no wait",
|
||||
Conf: Config{
|
||||
BurstLimit: 10,
|
||||
FillPeriod: 100 * time.Millisecond,
|
||||
MaxWaitCount: 0,
|
||||
},
|
||||
N: 32,
|
||||
ErrCount: 22,
|
||||
DT: 0,
|
||||
}, {
|
||||
Name: "no burst",
|
||||
Conf: Config{
|
||||
BurstLimit: 0,
|
||||
FillPeriod: 10 * time.Millisecond,
|
||||
MaxWaitCount: 10,
|
||||
},
|
||||
N: 32,
|
||||
ErrCount: 22,
|
||||
DT: 100 * time.Millisecond,
|
||||
}, {
|
||||
Name: "burst and wait",
|
||||
Conf: Config{
|
||||
BurstLimit: 10,
|
||||
FillPeriod: 10 * time.Millisecond,
|
||||
MaxWaitCount: 10,
|
||||
},
|
||||
N: 32,
|
||||
ErrCount: 12,
|
||||
DT: 100 * time.Millisecond,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
wg := sync.WaitGroup{}
|
||||
l := New(tc.Conf)
|
||||
errs := make([]error, tc.N)
|
||||
|
||||
t0 := time.Now()
|
||||
|
||||
for i := 0; i < tc.N; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
errs[i] = l.Limit()
|
||||
wg.Done()
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
dt := time.Since(t0)
|
||||
|
||||
errCount := 0
|
||||
for _, err := range errs {
|
||||
if err != nil {
|
||||
errCount++
|
||||
}
|
||||
}
|
||||
|
||||
if errCount != tc.ErrCount {
|
||||
t.Fatalf("%s: Expected %d errors but got %d.",
|
||||
tc.Name, tc.ErrCount, errCount)
|
||||
}
|
||||
|
||||
if dt < tc.DT {
|
||||
t.Fatal(tc.Name, dt, tc.DT)
|
||||
}
|
||||
|
||||
if dt > tc.DT+10*time.Millisecond {
|
||||
t.Fatal(tc.Name, dt, tc.DT)
|
||||
}
|
||||
}
|
||||
}
|
22
sqlgen/README.md
Normal file
22
sqlgen/README.md
Normal file
@ -0,0 +1,22 @@
|
||||
# sqlgen
|
||||
|
||||
## Installing
|
||||
|
||||
```
|
||||
go install git.crumpington.com/lib/sqlgen/cmd/sqlgen@latest
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
```
|
||||
sqlgen [driver] [defs-path] [output-path]
|
||||
```
|
||||
|
||||
## File Format
|
||||
|
||||
```
|
||||
TABLE [sql-name] OF [go-type] <NoInsert> <NoUpdate> <NoDelete> (
|
||||
[sql-column] [go-type] <AS go-name> <PK> <NoInsert> <NoUpdate>,
|
||||
...
|
||||
);
|
||||
```
|
7
sqlgen/cmd/sqlgen/main.go
Normal file
7
sqlgen/cmd/sqlgen/main.go
Normal file
@ -0,0 +1,7 @@
|
||||
package main
|
||||
|
||||
import "git.crumpington.com/lib/sqlgen"
|
||||
|
||||
func main() {
|
||||
sqlgen.Main()
|
||||
}
|
3
sqlgen/go.mod
Normal file
3
sqlgen/go.mod
Normal file
@ -0,0 +1,3 @@
|
||||
module git.crumpington.com/lib/sqlgen
|
||||
|
||||
go 1.23.2
|
43
sqlgen/main.go
Normal file
43
sqlgen/main.go
Normal file
@ -0,0 +1,43 @@
|
||||
package sqlgen
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
)
|
||||
|
||||
func Main() {
|
||||
usage := func() {
|
||||
fmt.Fprintf(os.Stderr, `
|
||||
%s DRIVER DEFS_PATH OUTPUT_PATH
|
||||
|
||||
Drivers are one of: sqlite, postgres
|
||||
`,
|
||||
os.Args[0])
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
if len(os.Args) != 4 {
|
||||
usage()
|
||||
}
|
||||
|
||||
var (
|
||||
template string
|
||||
driver = os.Args[1]
|
||||
defsPath = os.Args[2]
|
||||
outputPath = os.Args[3]
|
||||
)
|
||||
|
||||
switch driver {
|
||||
case "sqlite":
|
||||
template = sqliteTemplate
|
||||
default:
|
||||
fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver)
|
||||
usage()
|
||||
}
|
||||
|
||||
err := render(template, defsPath, outputPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
143
sqlgen/parse.go
Normal file
143
sqlgen/parse.go
Normal file
@ -0,0 +1,143 @@
|
||||
package sqlgen
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func parsePath(filePath string) (*schema, error) {
|
||||
fileBytes, err := os.ReadFile(filePath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return parseBytes(fileBytes)
|
||||
}
|
||||
|
||||
func parseBytes(fileBytes []byte) (*schema, error) {
|
||||
s := string(fileBytes)
|
||||
for _, c := range []string{",", "(", ")", ";"} {
|
||||
s = strings.ReplaceAll(s, c, " "+c+" ")
|
||||
}
|
||||
|
||||
var (
|
||||
tokens = strings.Fields(s)
|
||||
schema = &schema{}
|
||||
err error
|
||||
)
|
||||
|
||||
for len(tokens) > 0 {
|
||||
switch tokens[0] {
|
||||
case "TABLE":
|
||||
tokens, err = parseTable(schema, tokens)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, errors.New("invalid token: " + tokens[0])
|
||||
}
|
||||
}
|
||||
|
||||
return schema, nil
|
||||
}
|
||||
|
||||
func parseTable(schema *schema, tokens []string) ([]string, error) {
|
||||
tokens = tokens[1:]
|
||||
if len(tokens) < 3 {
|
||||
return tokens, errors.New("incomplete table definition")
|
||||
}
|
||||
if tokens[1] != "OF" {
|
||||
return tokens, errors.New("expected OF in table definition")
|
||||
}
|
||||
|
||||
table := &table{
|
||||
Name: tokens[0],
|
||||
Type: tokens[2],
|
||||
}
|
||||
schema.Tables = append(schema.Tables, table)
|
||||
|
||||
tokens = tokens[3:]
|
||||
|
||||
if len(tokens) == 0 {
|
||||
return tokens, errors.New("missing table definition body")
|
||||
}
|
||||
|
||||
for len(tokens) > 0 {
|
||||
switch tokens[0] {
|
||||
case "NoInsert":
|
||||
table.NoInsert = true
|
||||
tokens = tokens[1:]
|
||||
case "NoUpdate":
|
||||
table.NoUpdate = true
|
||||
tokens = tokens[1:]
|
||||
case "NoDelete":
|
||||
table.NoDelete = true
|
||||
tokens = tokens[1:]
|
||||
case "(":
|
||||
return parseTableBody(table, tokens[1:])
|
||||
default:
|
||||
return tokens, errors.New("unexpected token in table definition: " + tokens[0])
|
||||
}
|
||||
}
|
||||
|
||||
return tokens, errors.New("incomplete table definition")
|
||||
}
|
||||
|
||||
func parseTableBody(table *table, tokens []string) ([]string, error) {
|
||||
var err error
|
||||
|
||||
for len(tokens) > 0 && tokens[0] != ";" {
|
||||
tokens, err = parseTableColumn(table, tokens)
|
||||
if err != nil {
|
||||
return tokens, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(tokens) < 1 || tokens[0] != ";" {
|
||||
return tokens, errors.New("incomplete table column definitions")
|
||||
}
|
||||
|
||||
return tokens[1:], nil
|
||||
}
|
||||
|
||||
func parseTableColumn(table *table, tokens []string) ([]string, error) {
|
||||
if len(tokens) < 2 {
|
||||
return tokens, errors.New("incomplete column definition")
|
||||
}
|
||||
column := &column{
|
||||
Name: tokens[0],
|
||||
Type: tokens[1],
|
||||
SqlName: tokens[0],
|
||||
}
|
||||
table.Columns = append(table.Columns, column)
|
||||
|
||||
tokens = tokens[2:]
|
||||
for len(tokens) > 0 && tokens[0] != "," && tokens[0] != ")" {
|
||||
switch tokens[0] {
|
||||
case "AS":
|
||||
if len(tokens) < 2 {
|
||||
return tokens, errors.New("incomplete AS clause in column definition")
|
||||
}
|
||||
column.Name = tokens[1]
|
||||
tokens = tokens[2:]
|
||||
case "PK":
|
||||
column.PK = true
|
||||
tokens = tokens[1:]
|
||||
case "NoInsert":
|
||||
column.NoInsert = true
|
||||
tokens = tokens[1:]
|
||||
case "NoUpdate":
|
||||
column.NoUpdate = true
|
||||
tokens = tokens[1:]
|
||||
default:
|
||||
return tokens, errors.New("unexpected token in column definition: " + tokens[0])
|
||||
}
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
return tokens, errors.New("incomplete column definition")
|
||||
}
|
||||
|
||||
return tokens[1:], nil
|
||||
}
|
45
sqlgen/parse_test.go
Normal file
45
sqlgen/parse_test.go
Normal file
@ -0,0 +1,45 @@
|
||||
package sqlgen
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
toString := func(v any) string {
|
||||
txt, _ := json.MarshalIndent(v, "", " ")
|
||||
return string(txt)
|
||||
}
|
||||
|
||||
paths, err := filepath.Glob("test-files/TestParse/*.def")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, defPath := range paths {
|
||||
t.Run(filepath.Base(defPath), func(t *testing.T) {
|
||||
parsed, err := parsePath(defPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
b, err := os.ReadFile(strings.TrimSuffix(defPath, "def") + "json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expected := &schema{}
|
||||
if err := json.Unmarshal(b, expected); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(parsed, expected) {
|
||||
t.Fatalf("%s != %s", toString(parsed), toString(expected))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
263
sqlgen/schema.go
Normal file
263
sqlgen/schema.go
Normal file
@ -0,0 +1,263 @@
|
||||
package sqlgen
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type schema struct {
|
||||
Tables []*table
|
||||
}
|
||||
|
||||
type table struct {
|
||||
Name string // Name in SQL
|
||||
Type string // Go type
|
||||
NoInsert bool
|
||||
NoUpdate bool
|
||||
NoDelete bool
|
||||
Columns []*column
|
||||
}
|
||||
|
||||
type column struct {
|
||||
Name string
|
||||
Type string
|
||||
SqlName string // Defaults to Name
|
||||
PK bool // PK won't be updated
|
||||
NoInsert bool
|
||||
NoUpdate bool // Don't update column in update function
|
||||
}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
func (t *table) colSQLNames() []string {
|
||||
names := make([]string, len(t.Columns))
|
||||
for i := range names {
|
||||
names[i] = t.Columns[i].SqlName
|
||||
}
|
||||
return names
|
||||
}
|
||||
|
||||
func (t *table) SelectQuery() string {
|
||||
return fmt.Sprintf(`SELECT %s FROM %s`,
|
||||
strings.Join(t.colSQLNames(), ","),
|
||||
t.Name)
|
||||
}
|
||||
|
||||
func (t *table) insertCols() (cols []*column) {
|
||||
for _, c := range t.Columns {
|
||||
if !c.NoInsert {
|
||||
cols = append(cols, c)
|
||||
}
|
||||
}
|
||||
return cols
|
||||
}
|
||||
|
||||
func (t *table) InsertQuery() string {
|
||||
cols := t.insertCols()
|
||||
b := &strings.Builder{}
|
||||
b.WriteString(`INSERT INTO `)
|
||||
b.WriteString(t.Name)
|
||||
b.WriteString(`(`)
|
||||
for i, c := range cols {
|
||||
if i != 0 {
|
||||
b.WriteString(`,`)
|
||||
}
|
||||
b.WriteString(c.SqlName)
|
||||
}
|
||||
b.WriteString(`) VALUES(`)
|
||||
for i := range cols {
|
||||
if i != 0 {
|
||||
b.WriteString(`,`)
|
||||
}
|
||||
b.WriteString(`?`)
|
||||
}
|
||||
b.WriteString(`)`)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (t *table) InsertArgs() string {
|
||||
args := []string{}
|
||||
for i, col := range t.Columns {
|
||||
if !col.NoInsert {
|
||||
args = append(args, "row."+t.Columns[i].Name)
|
||||
}
|
||||
}
|
||||
return strings.Join(args, ", ")
|
||||
}
|
||||
|
||||
func (t *table) UpdateCols() (cols []*column) {
|
||||
for _, c := range t.Columns {
|
||||
if !(c.PK || c.NoUpdate) {
|
||||
cols = append(cols, c)
|
||||
}
|
||||
}
|
||||
return cols
|
||||
}
|
||||
|
||||
func (t *table) UpdateQuery() string {
|
||||
cols := t.UpdateCols()
|
||||
|
||||
b := &strings.Builder{}
|
||||
b.WriteString(`UPDATE `)
|
||||
b.WriteString(t.Name + ` SET `)
|
||||
for i, col := range cols {
|
||||
if i != 0 {
|
||||
b.WriteByte(',')
|
||||
}
|
||||
b.WriteString(col.SqlName + `=?`)
|
||||
}
|
||||
|
||||
b.WriteString(` WHERE`)
|
||||
|
||||
for i, c := range t.pkCols() {
|
||||
if i != 0 {
|
||||
b.WriteString(` AND`)
|
||||
}
|
||||
b.WriteString(` ` + c.SqlName + `=?`)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (t *table) UpdateArgs() string {
|
||||
cols := t.UpdateCols()
|
||||
|
||||
b := &strings.Builder{}
|
||||
for i, col := range cols {
|
||||
if i != 0 {
|
||||
b.WriteString(`, `)
|
||||
}
|
||||
b.WriteString("row." + col.Name)
|
||||
}
|
||||
for _, col := range t.pkCols() {
|
||||
b.WriteString(", row." + col.Name)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (t *table) UpdateFullCols() (cols []*column) {
|
||||
for _, c := range t.Columns {
|
||||
if !c.PK {
|
||||
cols = append(cols, c)
|
||||
}
|
||||
}
|
||||
return cols
|
||||
}
|
||||
|
||||
func (t *table) UpdateFullQuery() string {
|
||||
cols := t.UpdateFullCols()
|
||||
|
||||
b := &strings.Builder{}
|
||||
b.WriteString(`UPDATE `)
|
||||
b.WriteString(t.Name + ` SET `)
|
||||
for i, col := range cols {
|
||||
if i != 0 {
|
||||
b.WriteByte(',')
|
||||
}
|
||||
b.WriteString(col.SqlName + `=?`)
|
||||
}
|
||||
|
||||
b.WriteString(` WHERE`)
|
||||
|
||||
for i, c := range t.pkCols() {
|
||||
if i != 0 {
|
||||
b.WriteString(` AND`)
|
||||
}
|
||||
b.WriteString(` ` + c.SqlName + `=?`)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (t *table) UpdateFullArgs() string {
|
||||
cols := t.UpdateFullCols()
|
||||
|
||||
b := &strings.Builder{}
|
||||
for i, col := range cols {
|
||||
if i != 0 {
|
||||
b.WriteString(`, `)
|
||||
}
|
||||
b.WriteString("row." + col.Name)
|
||||
}
|
||||
|
||||
for _, col := range t.pkCols() {
|
||||
b.WriteString(", row." + col.Name)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (t *table) pkCols() (cols []*column) {
|
||||
for _, c := range t.Columns {
|
||||
if c.PK {
|
||||
cols = append(cols, c)
|
||||
}
|
||||
}
|
||||
return cols
|
||||
}
|
||||
|
||||
func (t *table) PKFunctionArgs() string {
|
||||
b := &strings.Builder{}
|
||||
for _, col := range t.pkCols() {
|
||||
b.WriteString(col.Name)
|
||||
b.WriteString(` `)
|
||||
b.WriteString(col.Type)
|
||||
b.WriteString(",\n")
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (t *table) DeleteQuery() string {
|
||||
cols := t.pkCols()
|
||||
|
||||
b := &strings.Builder{}
|
||||
b.WriteString(`DELETE FROM `)
|
||||
b.WriteString(t.Name)
|
||||
b.WriteString(` WHERE `)
|
||||
|
||||
for i, col := range cols {
|
||||
if i != 0 {
|
||||
b.WriteString(` AND `)
|
||||
}
|
||||
b.WriteString(col.SqlName)
|
||||
b.WriteString(`=?`)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (t *table) DeleteArgs() string {
|
||||
cols := t.pkCols()
|
||||
b := &strings.Builder{}
|
||||
|
||||
for i, col := range cols {
|
||||
if i != 0 {
|
||||
b.WriteString(`,`)
|
||||
}
|
||||
b.WriteString(col.Name)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (t *table) GetQuery() string {
|
||||
b := &strings.Builder{}
|
||||
b.WriteString(t.SelectQuery())
|
||||
b.WriteString(` WHERE `)
|
||||
for i, col := range t.pkCols() {
|
||||
if i != 0 {
|
||||
b.WriteString(` AND `)
|
||||
}
|
||||
b.WriteString(col.SqlName + `=?`)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func (t *table) ScanArgs() string {
|
||||
b := &strings.Builder{}
|
||||
for i, col := range t.Columns {
|
||||
if i != 0 {
|
||||
b.WriteString(`, `)
|
||||
}
|
||||
b.WriteString(`&row.` + col.Name)
|
||||
}
|
||||
return b.String()
|
||||
}
|
206
sqlgen/sqlite.go.tmpl
Normal file
206
sqlgen/sqlite.go.tmpl
Normal file
@ -0,0 +1,206 @@
|
||||
package {{.PackageName}}
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"iter"
|
||||
)
|
||||
|
||||
type TX interface {
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...any) *sql.Row
|
||||
}
|
||||
|
||||
{{range .Schema.Tables}}
|
||||
|
||||
// ----------------------------------------------------------------------------
|
||||
// Table: {{.Name}}
|
||||
// ----------------------------------------------------------------------------
|
||||
|
||||
type {{.Type}} struct {
|
||||
{{- range .Columns}}
|
||||
{{.Name}} {{.Type}}{{end}}
|
||||
}
|
||||
|
||||
const {{.Type}}_SelectQuery = "{{.SelectQuery}}"
|
||||
|
||||
{{if not .NoInsert -}}
|
||||
|
||||
func {{.Type}}_Insert(
|
||||
tx TX,
|
||||
row *{{.Type}},
|
||||
) (err error) {
|
||||
{{.Type}}_Sanitize(row)
|
||||
if err = {{.Type}}_Validate(row); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}})
|
||||
return err
|
||||
}
|
||||
|
||||
{{- end}} {{/* if not .NoInsert */}}
|
||||
|
||||
{{if not .NoUpdate -}}
|
||||
|
||||
{{if .UpdateCols -}}
|
||||
|
||||
func {{.Type}}_Update(
|
||||
tx TX,
|
||||
row *{{.Type}},
|
||||
) (found bool, err error) {
|
||||
{{.Type}}_Sanitize(row)
|
||||
if err = {{.Type}}_Validate(row); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
n, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if n > 1 {
|
||||
panic("multiple rows updated")
|
||||
}
|
||||
|
||||
return n != 0, nil
|
||||
}
|
||||
{{- end}}
|
||||
|
||||
{{if .UpdateFullCols -}}
|
||||
|
||||
func {{.Type}}_UpdateFull(
|
||||
tx TX,
|
||||
row *{{.Type}},
|
||||
) (found bool, err error) {
|
||||
{{.Type}}_Sanitize(row)
|
||||
if err = {{.Type}}_Validate(row); err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
n, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if n > 1 {
|
||||
panic("multiple rows updated")
|
||||
}
|
||||
|
||||
return n != 0, nil
|
||||
}
|
||||
{{- end}}
|
||||
|
||||
|
||||
{{- end}} {{/* if not .NoUpdate */}}
|
||||
|
||||
{{if not .NoDelete -}}
|
||||
|
||||
func {{.Type}}_Delete(
|
||||
tx TX,
|
||||
{{.PKFunctionArgs -}}
|
||||
) (found bool, err error) {
|
||||
result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}})
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
n, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if n > 1 {
|
||||
panic("multiple rows deleted")
|
||||
}
|
||||
|
||||
return n != 0, nil
|
||||
}
|
||||
|
||||
{{- end}}
|
||||
|
||||
func {{.Type}}_Get(
|
||||
tx TX,
|
||||
{{.PKFunctionArgs -}}
|
||||
) (
|
||||
row *{{.Type}},
|
||||
err error,
|
||||
) {
|
||||
row = &{{.Type}}{}
|
||||
r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}})
|
||||
err = r.Scan({{.ScanArgs}})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
func {{.Type}}_GetWhere(
|
||||
tx TX,
|
||||
query string,
|
||||
args ...any,
|
||||
) (
|
||||
row *{{.Type}},
|
||||
err error,
|
||||
) {
|
||||
row = &{{.Type}}{}
|
||||
r := tx.QueryRow(query, args...)
|
||||
err = r.Scan({{.ScanArgs}})
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
|
||||
func {{.Type}}_Iterate(
|
||||
tx TX,
|
||||
query string,
|
||||
args ...any,
|
||||
) (
|
||||
iter.Seq2[*{{.Type}}, error],
|
||||
) {
|
||||
rows, err := tx.Query(query, args...)
|
||||
if err != nil {
|
||||
return func(yield func(*{{.Type}}, error) bool) {
|
||||
yield(nil, err)
|
||||
}
|
||||
}
|
||||
|
||||
return func(yield func(*{{.Type}}, error) bool) {
|
||||
defer rows.Close()
|
||||
for rows.Next() {
|
||||
row := &{{.Type}}{}
|
||||
err := rows.Scan({{.ScanArgs}})
|
||||
if !yield(row, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func {{.Type}}_List(
|
||||
tx TX,
|
||||
query string,
|
||||
args ...any,
|
||||
) (
|
||||
l []*{{.Type}},
|
||||
err error,
|
||||
) {
|
||||
for row, err := range {{.Type}}_Iterate(tx, query, args...) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l = append(l, row)
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
|
||||
{{end}} {{/* range .Schema.Tables */}}
|
36
sqlgen/template.go
Normal file
36
sqlgen/template.go
Normal file
@ -0,0 +1,36 @@
|
||||
package sqlgen
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
//go:embed sqlite.go.tmpl
|
||||
var sqliteTemplate string
|
||||
|
||||
func render(templateStr, schemaPath, outputPath string) error {
|
||||
sch, err := parsePath(schemaPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tmpl := template.Must(template.New("").Parse(templateStr))
|
||||
fOut, err := os.Create(outputPath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer fOut.Close()
|
||||
|
||||
err = tmpl.Execute(fOut, struct {
|
||||
PackageName string
|
||||
Schema *schema
|
||||
}{filepath.Base(filepath.Dir(outputPath)), sch})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return exec.Command("gofmt", "-w", outputPath).Run()
|
||||
}
|
3
sqlgen/test-files/TestParse/000.def
Normal file
3
sqlgen/test-files/TestParse/000.def
Normal file
@ -0,0 +1,3 @@
|
||||
TABLE users OF User NoDelete (
|
||||
user_id string AS UserID PK
|
||||
);
|
17
sqlgen/test-files/TestParse/000.json
Normal file
17
sqlgen/test-files/TestParse/000.json
Normal file
@ -0,0 +1,17 @@
|
||||
{
|
||||
"Tables": [
|
||||
{
|
||||
"Name": "users",
|
||||
"Type": "User",
|
||||
"NoDelete": true,
|
||||
"Columns": [
|
||||
{
|
||||
"Name": "UserID",
|
||||
"Type": "string",
|
||||
"SqlName": "user_id",
|
||||
"PK": true
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
4
sqlgen/test-files/TestParse/001.def
Normal file
4
sqlgen/test-files/TestParse/001.def
Normal file
@ -0,0 +1,4 @@
|
||||
TABLE users OF User NoDelete (
|
||||
user_id string AS UserID PK,
|
||||
email string AS Email NoUpdate
|
||||
);
|
22
sqlgen/test-files/TestParse/001.json
Normal file
22
sqlgen/test-files/TestParse/001.json
Normal file
@ -0,0 +1,22 @@
|
||||
{
|
||||
"Tables": [
|
||||
{
|
||||
"Name": "users",
|
||||
"Type": "User",
|
||||
"NoDelete": true,
|
||||
"Columns": [
|
||||
{
|
||||
"Name": "UserID",
|
||||
"Type": "string",
|
||||
"SqlName": "user_id",
|
||||
"PK": true
|
||||
}, {
|
||||
"Name": "Email",
|
||||
"Type": "string",
|
||||
"SqlName": "email",
|
||||
"NoUpdate": true
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
6
sqlgen/test-files/TestParse/002.def
Normal file
6
sqlgen/test-files/TestParse/002.def
Normal file
@ -0,0 +1,6 @@
|
||||
TABLE users OF User NoDelete (
|
||||
user_id string AS UserID PK,
|
||||
email string AS Email NoUpdate,
|
||||
name string AS Name NoInsert,
|
||||
admin bool AS Admin NoInsert NoUpdate
|
||||
);
|
33
sqlgen/test-files/TestParse/002.json
Normal file
33
sqlgen/test-files/TestParse/002.json
Normal file
@ -0,0 +1,33 @@
|
||||
{
|
||||
"Tables": [
|
||||
{
|
||||
"Name": "users",
|
||||
"Type": "User",
|
||||
"NoDelete": true,
|
||||
"Columns": [
|
||||
{
|
||||
"Name": "UserID",
|
||||
"Type": "string",
|
||||
"SqlName": "user_id",
|
||||
"PK": true
|
||||
}, {
|
||||
"Name": "Email",
|
||||
"Type": "string",
|
||||
"SqlName": "email",
|
||||
"NoUpdate": true
|
||||
}, {
|
||||
"Name": "Name",
|
||||
"Type": "string",
|
||||
"SqlName": "name",
|
||||
"NoInsert": true
|
||||
}, {
|
||||
"Name": "Admin",
|
||||
"Type": "bool",
|
||||
"SqlName": "admin",
|
||||
"NoInsert": true,
|
||||
"NoUpdate": true
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
12
sqlgen/test-files/TestParse/003.def
Normal file
12
sqlgen/test-files/TestParse/003.def
Normal file
@ -0,0 +1,12 @@
|
||||
TABLE users OF User NoDelete (
|
||||
user_id string AS UserID PK,
|
||||
email string AS Email NoUpdate,
|
||||
name string AS Name NoInsert,
|
||||
admin bool AS Admin NoInsert NoUpdate
|
||||
);
|
||||
|
||||
TABLE users_view OF UserView NoInsert NoUpdate NoDelete (
|
||||
user_id string AS UserID PK,
|
||||
email string AS Email,
|
||||
name string AS Name
|
||||
);
|
61
sqlgen/test-files/TestParse/003.json
Normal file
61
sqlgen/test-files/TestParse/003.json
Normal file
@ -0,0 +1,61 @@
|
||||
{
|
||||
"Tables": [
|
||||
{
|
||||
"Name": "users",
|
||||
"Type": "User",
|
||||
"NoDelete": true,
|
||||
"Columns": [
|
||||
{
|
||||
"Name": "UserID",
|
||||
"Type": "string",
|
||||
"SqlName": "user_id",
|
||||
"PK": true
|
||||
},
|
||||
{
|
||||
"Name": "Email",
|
||||
"Type": "string",
|
||||
"SqlName": "email",
|
||||
"NoUpdate": true
|
||||
},
|
||||
{
|
||||
"Name": "Name",
|
||||
"Type": "string",
|
||||
"SqlName": "name",
|
||||
"NoInsert": true
|
||||
},
|
||||
{
|
||||
"Name": "Admin",
|
||||
"Type": "bool",
|
||||
"SqlName": "admin",
|
||||
"NoInsert": true,
|
||||
"NoUpdate": true
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"Name": "users_view",
|
||||
"Type": "UserView",
|
||||
"NoInsert": true,
|
||||
"NoUpdate": true,
|
||||
"NoDelete": true,
|
||||
"Columns": [
|
||||
{
|
||||
"Name": "UserID",
|
||||
"Type": "string",
|
||||
"SqlName": "user_id",
|
||||
"PK": true
|
||||
},
|
||||
{
|
||||
"Name": "Email",
|
||||
"Type": "string",
|
||||
"SqlName": "email"
|
||||
},
|
||||
{
|
||||
"Name": "Name",
|
||||
"Type": "string",
|
||||
"SqlName": "name"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
13
sqlgen/test-files/TestParse/004.def
Normal file
13
sqlgen/test-files/TestParse/004.def
Normal file
@ -0,0 +1,13 @@
|
||||
TABLE users OF User NoDelete (
|
||||
user_id string AS UserID PK,
|
||||
email string AS Email NoUpdate,
|
||||
name string AS Name NoInsert,
|
||||
admin bool AS Admin NoInsert NoUpdate,
|
||||
SSN string NoUpdate
|
||||
);
|
||||
|
||||
TABLE users_view OF UserView NoInsert NoUpdate NoDelete (
|
||||
user_id string AS UserID PK,
|
||||
email string AS Email,
|
||||
name string AS Name
|
||||
);
|
66
sqlgen/test-files/TestParse/004.json
Normal file
66
sqlgen/test-files/TestParse/004.json
Normal file
@ -0,0 +1,66 @@
|
||||
{
|
||||
"Tables": [
|
||||
{
|
||||
"Name": "users",
|
||||
"Type": "User",
|
||||
"NoDelete": true,
|
||||
"Columns": [
|
||||
{
|
||||
"Name": "UserID",
|
||||
"Type": "string",
|
||||
"SqlName": "user_id",
|
||||
"PK": true
|
||||
},
|
||||
{
|
||||
"Name": "Email",
|
||||
"Type": "string",
|
||||
"SqlName": "email",
|
||||
"NoUpdate": true
|
||||
},
|
||||
{
|
||||
"Name": "Name",
|
||||
"Type": "string",
|
||||
"SqlName": "name",
|
||||
"NoInsert": true
|
||||
},
|
||||
{
|
||||
"Name": "Admin",
|
||||
"Type": "bool",
|
||||
"SqlName": "admin",
|
||||
"NoInsert": true,
|
||||
"NoUpdate": true
|
||||
}, {
|
||||
"Name": "SSN",
|
||||
"Type": "string",
|
||||
"SqlName": "SSN",
|
||||
"NoUpdate": true
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"Name": "users_view",
|
||||
"Type": "UserView",
|
||||
"NoInsert": true,
|
||||
"NoUpdate": true,
|
||||
"NoDelete": true,
|
||||
"Columns": [
|
||||
{
|
||||
"Name": "UserID",
|
||||
"Type": "string",
|
||||
"SqlName": "user_id",
|
||||
"PK": true
|
||||
},
|
||||
{
|
||||
"Name": "Email",
|
||||
"Type": "string",
|
||||
"SqlName": "email"
|
||||
},
|
||||
{
|
||||
"Name": "Name",
|
||||
"Type": "string",
|
||||
"SqlName": "name"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
45
sqliteutil/README.md
Normal file
45
sqliteutil/README.md
Normal file
@ -0,0 +1,45 @@
|
||||
# sqliteutil
|
||||
|
||||
## Transactions
|
||||
|
||||
Simplify postgres transactions using `WithTx` for serializable transactions,
|
||||
or `WithTxDefault` for the default isolation level. Use the `SerialTxRunner`
|
||||
type to get automatic retries of serialization errors.
|
||||
|
||||
## Migrations
|
||||
|
||||
Put your migrations into a directory, for example `migrations`, ordered by name
|
||||
(YYYY-MM-DD prefix, for example). Embed the directory and pass it to the
|
||||
`Migrate` function:
|
||||
|
||||
```Go
|
||||
//go:embed migrations
|
||||
var migrations embed.FS
|
||||
|
||||
func init() {
|
||||
Migrate(db, migrations) // Check the error, of course.
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
In order to test this packge, we need to create a test user and database:
|
||||
|
||||
```
|
||||
sudo su postgres
|
||||
psql
|
||||
|
||||
CREATE DATABASE test;
|
||||
CREATE USER test WITH ENCRYPTED PASSWORD 'test';
|
||||
GRANT ALL PRIVILEGES ON DATABASE test TO test;
|
||||
|
||||
use test
|
||||
|
||||
GRANT ALL ON SCHEMA public TO test;
|
||||
```
|
||||
|
||||
Check that you can connect via the command line:
|
||||
|
||||
```
|
||||
psql -h 127.0.0.1 -U test --password test
|
||||
```
|
5
sqliteutil/go.mod
Normal file
5
sqliteutil/go.mod
Normal file
@ -0,0 +1,5 @@
|
||||
module git.crumpington.com/lib/sqliteutil
|
||||
|
||||
go 1.23.2
|
||||
|
||||
require github.com/mattn/go-sqlite3 v1.14.24 // indirect
|
2
sqliteutil/go.sum
Normal file
2
sqliteutil/go.sum
Normal file
@ -0,0 +1,2 @@
|
||||
github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM=
|
||||
github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
|
82
sqliteutil/migrate.go
Normal file
82
sqliteutil/migrate.go
Normal file
@ -0,0 +1,82 @@
|
||||
package sqliteutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
)
|
||||
|
||||
const initMigrationTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);`
|
||||
|
||||
const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)`
|
||||
|
||||
const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)`
|
||||
|
||||
func Migrate(db *sql.DB, migrationFS embed.FS) error {
|
||||
return WithTx(db, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(initMigrationTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dirs, err := migrationFS.ReadDir(".")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(dirs) != 1 {
|
||||
return errors.New("expected a single migrations directory")
|
||||
}
|
||||
|
||||
if !dirs[0].IsDir() {
|
||||
return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name())
|
||||
}
|
||||
|
||||
dirName := dirs[0].Name()
|
||||
files, err := migrationFS.ReadDir(dirName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Sort sql files by name.
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].Name() < files[j].Name()
|
||||
})
|
||||
|
||||
for _, dirEnt := range files {
|
||||
if !dirEnt.Type().IsRegular() {
|
||||
return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name())
|
||||
}
|
||||
|
||||
var (
|
||||
name = dirEnt.Name()
|
||||
exists bool
|
||||
)
|
||||
|
||||
err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
|
||||
migration, err := migrationFS.ReadFile(filepath.Join(dirName, name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(string(migration)); err != nil {
|
||||
return fmt.Errorf("migration %s failed: %v", name, err)
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(insertMigrationQuery, name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
44
sqliteutil/migrate_test.go
Normal file
44
sqliteutil/migrate_test.go
Normal file
@ -0,0 +1,44 @@
|
||||
package sqliteutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"testing"
|
||||
)
|
||||
|
||||
//go:embed test-migrations
|
||||
var testMigrationFS embed.FS
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
db, err := sql.Open("sqlite3", ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := Migrate(db, testMigrationFS); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Shouldn't have any effect.
|
||||
if err := Migrate(db, testMigrationFS); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)`
|
||||
var exists bool
|
||||
|
||||
if err = db.QueryRow(query, 1).Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if exists {
|
||||
t.Fatal("1 shouldn't exist")
|
||||
}
|
||||
|
||||
if err = db.QueryRow(query, 2).Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatal("2 should exist")
|
||||
}
|
||||
|
||||
}
|
9
sqliteutil/test-migrations/000.sql
Normal file
9
sqliteutil/test-migrations/000.sql
Normal file
@ -0,0 +1,9 @@
|
||||
CREATE TABLE users(
|
||||
UserID BIGINT NOT NULL PRIMARY KEY,
|
||||
Email TEXT NOT NULL UNIQUE);
|
||||
|
||||
CREATE TABLE user_notes(
|
||||
UserID BIGINT NOT NULL REFERENCES users(UserID),
|
||||
NoteID BIGINT NOT NULL,
|
||||
Note Text NOT NULL,
|
||||
PRIMARY KEY(UserID,NoteID));
|
1
sqliteutil/test-migrations/001.sql
Normal file
1
sqliteutil/test-migrations/001.sql
Normal file
@ -0,0 +1 @@
|
||||
INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com');
|
1
sqliteutil/test-migrations/002.sql
Normal file
1
sqliteutil/test-migrations/002.sql
Normal file
@ -0,0 +1 @@
|
||||
DELETE FROM users WHERE UserID=1;
|
28
sqliteutil/tx.go
Normal file
28
sqliteutil/tx.go
Normal file
@ -0,0 +1,28 @@
|
||||
package sqliteutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
// This is a convenience function to run a function within a transaction.
|
||||
func WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
|
||||
// Start a transaction.
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = fn(tx)
|
||||
|
||||
if err == nil {
|
||||
err = tx.Commit()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
2
tagengine/README.md
Normal file
2
tagengine/README.md
Normal file
@ -0,0 +1,2 @@
|
||||
# tagengine
|
||||
|
3
tagengine/go.mod
Normal file
3
tagengine/go.mod
Normal file
@ -0,0 +1,3 @@
|
||||
module git.crumpington.com/lib/tagengine
|
||||
|
||||
go 1.23.2
|
0
tagengine/go.sum
Normal file
0
tagengine/go.sum
Normal file
30
tagengine/ngram.go
Normal file
30
tagengine/ngram.go
Normal file
@ -0,0 +1,30 @@
|
||||
package tagengine
|
||||
|
||||
import "unicode"
|
||||
|
||||
func ngramLength(s string) int {
|
||||
N := len(s)
|
||||
i := 0
|
||||
count := 0
|
||||
|
||||
for {
|
||||
// Eat spaces.
|
||||
for i < N && unicode.IsSpace(rune(s[i])) {
|
||||
i++
|
||||
}
|
||||
|
||||
// Done?
|
||||
if i == N {
|
||||
break
|
||||
}
|
||||
|
||||
// Non-space!
|
||||
count++
|
||||
|
||||
// Eat non-spaces.
|
||||
for i < N && !unicode.IsSpace(rune(s[i])) {
|
||||
i++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
31
tagengine/ngram_test.go
Normal file
31
tagengine/ngram_test.go
Normal file
@ -0,0 +1,31 @@
|
||||
package tagengine
|
||||
|
||||
import (
|
||||
"log"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNGramLength(t *testing.T) {
|
||||
type Case struct {
|
||||
Input string
|
||||
Length int
|
||||
}
|
||||
|
||||
cases := []Case{
|
||||
{"a b c", 3},
|
||||
{" xyz\nlkj dflaj a", 4},
|
||||
{"a", 1},
|
||||
{" a", 1},
|
||||
{"a", 1},
|
||||
{" a\n", 1},
|
||||
{" a ", 1},
|
||||
{"\tx\ny\nz q ", 4},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
length := ngramLength(tc.Input)
|
||||
if length != tc.Length {
|
||||
log.Fatalf("%s: %d != %d", tc.Input, length, tc.Length)
|
||||
}
|
||||
}
|
||||
}
|
79
tagengine/node.go
Normal file
79
tagengine/node.go
Normal file
@ -0,0 +1,79 @@
|
||||
package tagengine
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type node struct {
|
||||
Token string
|
||||
Matches []*Rule // If a list of tokens reaches this node, it matches these.
|
||||
Children map[string]*node
|
||||
}
|
||||
|
||||
func (n *node) AddRule(r *Rule) {
|
||||
n.addRule(r, 0)
|
||||
}
|
||||
|
||||
func (n *node) addRule(r *Rule, idx int) {
|
||||
if len(r.Includes) == idx {
|
||||
n.Matches = append(n.Matches, r)
|
||||
return
|
||||
}
|
||||
|
||||
token := r.Includes[idx]
|
||||
|
||||
child, ok := n.Children[token]
|
||||
if !ok {
|
||||
child = &node{
|
||||
Token: token,
|
||||
Children: map[string]*node{},
|
||||
}
|
||||
n.Children[token] = child
|
||||
}
|
||||
|
||||
child.addRule(r, idx+1)
|
||||
}
|
||||
|
||||
// Note that tokens must be sorted. This is the case for tokens created from
|
||||
// the tokenize function.
|
||||
func (n *node) Match(tokens []string) (rules []*Rule) {
|
||||
return n.match(tokens, rules)
|
||||
}
|
||||
|
||||
func (n *node) match(tokens []string, rules []*Rule) []*Rule {
|
||||
// Check for a match.
|
||||
if n.Matches != nil {
|
||||
rules = append(rules, n.Matches...)
|
||||
}
|
||||
|
||||
if len(tokens) == 0 {
|
||||
return rules
|
||||
}
|
||||
|
||||
// Attempt to match children.
|
||||
for i := 0; i < len(tokens); i++ {
|
||||
token := tokens[i]
|
||||
if child, ok := n.Children[token]; ok {
|
||||
rules = child.match(tokens[i+1:], rules)
|
||||
}
|
||||
}
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
func (n *node) Dump() {
|
||||
n.dump(0)
|
||||
}
|
||||
|
||||
func (n *node) dump(depth int) {
|
||||
indent := strings.Repeat(" ", 2*depth)
|
||||
tag := ""
|
||||
for _, m := range n.Matches {
|
||||
tag += " " + m.Tag
|
||||
}
|
||||
fmt.Printf("%s%s%s\n", indent, n.Token, tag)
|
||||
for _, child := range n.Children {
|
||||
child.dump(depth + 1)
|
||||
}
|
||||
}
|
159
tagengine/rule.go
Normal file
159
tagengine/rule.go
Normal file
@ -0,0 +1,159 @@
|
||||
package tagengine
|
||||
|
||||
type Rule struct {
|
||||
// The purpose of a Rule is to attach it's Tag to matching text.
|
||||
Tag string
|
||||
|
||||
// Includes is a list of strings that must be found in the input in order to
|
||||
// match.
|
||||
Includes []string
|
||||
|
||||
// Excludes is a list of strings that can exclude a match for this rule.
|
||||
Excludes []string
|
||||
|
||||
// Blocks: If this rule is matched, then it will block matches of any tags
|
||||
// listed here.
|
||||
Blocks []string
|
||||
|
||||
// The Score encodes the complexity of the Rule. A higher score indicates a
|
||||
// more specific match. A Rule more includes, or includes with multiple words
|
||||
// should havee a higher Score than a Rule with fewer includes or less
|
||||
// complex includes.
|
||||
Score int
|
||||
|
||||
excludes map[string]struct{}
|
||||
}
|
||||
|
||||
func NewRule(tag string) Rule {
|
||||
return Rule{Tag: tag}
|
||||
}
|
||||
|
||||
func (r Rule) Inc(l ...string) Rule {
|
||||
return Rule{
|
||||
Tag: r.Tag,
|
||||
Includes: append(r.Includes, l...),
|
||||
Excludes: r.Excludes,
|
||||
Blocks: r.Blocks,
|
||||
}
|
||||
}
|
||||
|
||||
func (r Rule) Exc(l ...string) Rule {
|
||||
return Rule{
|
||||
Tag: r.Tag,
|
||||
Includes: r.Includes,
|
||||
Excludes: append(r.Excludes, l...),
|
||||
Blocks: r.Blocks,
|
||||
}
|
||||
}
|
||||
|
||||
func (r Rule) Block(l ...string) Rule {
|
||||
return Rule{
|
||||
Tag: r.Tag,
|
||||
Includes: r.Includes,
|
||||
Excludes: r.Excludes,
|
||||
Blocks: append(r.Blocks, l...),
|
||||
}
|
||||
}
|
||||
|
||||
func (rule *Rule) normalize(sanitize func(string) string) {
|
||||
for i, token := range rule.Includes {
|
||||
rule.Includes[i] = sanitize(token)
|
||||
}
|
||||
for i, token := range rule.Excludes {
|
||||
rule.Excludes[i] = sanitize(token)
|
||||
}
|
||||
|
||||
sortTokens(rule.Includes)
|
||||
sortTokens(rule.Excludes)
|
||||
|
||||
rule.excludes = map[string]struct{}{}
|
||||
for _, s := range rule.Excludes {
|
||||
rule.excludes[s] = struct{}{}
|
||||
}
|
||||
|
||||
rule.Score = rule.computeScore()
|
||||
}
|
||||
|
||||
func (r Rule) maxNGram() int {
|
||||
max := 0
|
||||
for _, s := range r.Includes {
|
||||
n := ngramLength(s)
|
||||
if n > max {
|
||||
max = n
|
||||
}
|
||||
}
|
||||
for _, s := range r.Excludes {
|
||||
n := ngramLength(s)
|
||||
if n > max {
|
||||
max = n
|
||||
}
|
||||
}
|
||||
|
||||
return max
|
||||
}
|
||||
|
||||
func (r Rule) isExcluded(tokens []string) bool {
|
||||
// This is most often the case.
|
||||
if len(r.excludes) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
for _, s := range tokens {
|
||||
if _, ok := r.excludes[s]; ok {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (r Rule) computeScore() (score int) {
|
||||
for _, token := range r.Includes {
|
||||
n := ngramLength(token)
|
||||
score += n * (n + 1) / 2
|
||||
}
|
||||
return score
|
||||
}
|
||||
|
||||
func ruleLess(lhs, rhs *Rule) bool {
|
||||
// If scores differ, sort by score.
|
||||
if lhs.Score != rhs.Score {
|
||||
return lhs.Score < rhs.Score
|
||||
}
|
||||
|
||||
// If include depth differs, sort by depth.
|
||||
lDepth := len(lhs.Includes)
|
||||
rDepth := len(rhs.Includes)
|
||||
|
||||
if lDepth != rDepth {
|
||||
return lDepth < rDepth
|
||||
}
|
||||
|
||||
// If exclude depth differs, sort by depth.
|
||||
lDepth = len(lhs.Excludes)
|
||||
rDepth = len(rhs.Excludes)
|
||||
|
||||
if lDepth != rDepth {
|
||||
return lDepth < rDepth
|
||||
}
|
||||
|
||||
// Sort alphabetically by includes.
|
||||
for i := range lhs.Includes {
|
||||
if lhs.Includes[i] != rhs.Includes[i] {
|
||||
return lhs.Includes[i] < rhs.Includes[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by alphabetically by excludes.
|
||||
for i := range lhs.Excludes {
|
||||
if lhs.Excludes[i] != rhs.Excludes[i] {
|
||||
return lhs.Excludes[i] < rhs.Excludes[i]
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by tag.
|
||||
if lhs.Tag != rhs.Tag {
|
||||
return lhs.Tag < rhs.Tag
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
58
tagengine/rulegroup.go
Normal file
58
tagengine/rulegroup.go
Normal file
@ -0,0 +1,58 @@
|
||||
package tagengine
|
||||
|
||||
// A RuleGroup can be converted into a list of rules. Each rule will point to
|
||||
// the same tag, and have the same exclude set and blocks.
|
||||
type RuleGroup struct {
|
||||
Tag string
|
||||
Includes [][]string
|
||||
Excludes []string
|
||||
Blocks []string
|
||||
}
|
||||
|
||||
func NewRuleGroup(tag string) RuleGroup {
|
||||
return RuleGroup{
|
||||
Tag: tag,
|
||||
Includes: [][]string{},
|
||||
Excludes: []string{},
|
||||
Blocks: []string{},
|
||||
}
|
||||
}
|
||||
|
||||
func (g RuleGroup) Inc(l ...string) RuleGroup {
|
||||
return RuleGroup{
|
||||
Tag: g.Tag,
|
||||
Includes: append(g.Includes, l),
|
||||
Excludes: g.Excludes,
|
||||
Blocks: g.Blocks,
|
||||
}
|
||||
}
|
||||
|
||||
func (g RuleGroup) Exc(l ...string) RuleGroup {
|
||||
return RuleGroup{
|
||||
Tag: g.Tag,
|
||||
Includes: g.Includes,
|
||||
Excludes: append(g.Excludes, l...),
|
||||
Blocks: g.Blocks,
|
||||
}
|
||||
}
|
||||
|
||||
func (g RuleGroup) Block(l ...string) RuleGroup {
|
||||
return RuleGroup{
|
||||
Tag: g.Tag,
|
||||
Includes: g.Includes,
|
||||
Excludes: g.Excludes,
|
||||
Blocks: append(g.Blocks, l...),
|
||||
}
|
||||
}
|
||||
|
||||
func (g RuleGroup) ToList() (l []Rule) {
|
||||
for _, includes := range g.Includes {
|
||||
l = append(l, Rule{
|
||||
Tag: g.Tag,
|
||||
Excludes: g.Excludes,
|
||||
Includes: includes,
|
||||
Blocks: g.Blocks,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
162
tagengine/ruleset.go
Normal file
162
tagengine/ruleset.go
Normal file
@ -0,0 +1,162 @@
|
||||
package tagengine
|
||||
|
||||
import (
|
||||
"sort"
|
||||
)
|
||||
|
||||
type RuleSet struct {
|
||||
root *node
|
||||
maxNgram int
|
||||
sanitize func(string) string
|
||||
rules []*Rule
|
||||
}
|
||||
|
||||
func NewRuleSet() *RuleSet {
|
||||
return &RuleSet{
|
||||
root: &node{
|
||||
Token: "/",
|
||||
Children: map[string]*node{},
|
||||
},
|
||||
sanitize: BasicSanitizer,
|
||||
rules: []*Rule{},
|
||||
}
|
||||
}
|
||||
|
||||
func NewRuleSetFromList(rules []Rule) *RuleSet {
|
||||
rs := NewRuleSet()
|
||||
rs.AddRule(rules...)
|
||||
return rs
|
||||
}
|
||||
|
||||
func (t *RuleSet) Add(ruleOrGroup ...interface{}) {
|
||||
for _, ix := range ruleOrGroup {
|
||||
switch x := ix.(type) {
|
||||
case Rule:
|
||||
t.AddRule(x)
|
||||
case RuleGroup:
|
||||
t.AddRuleGroup(x)
|
||||
default:
|
||||
panic("Add expects either Rule or RuleGroup objects.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (t *RuleSet) AddRule(rules ...Rule) {
|
||||
for _, rule := range rules {
|
||||
rule := rule
|
||||
|
||||
// Make sure rule is well-formed.
|
||||
rule.normalize(t.sanitize)
|
||||
|
||||
// Update maxNgram.
|
||||
N := rule.maxNGram()
|
||||
if N > t.maxNgram {
|
||||
t.maxNgram = N
|
||||
}
|
||||
|
||||
t.rules = append(t.rules, &rule)
|
||||
t.root.AddRule(&rule)
|
||||
}
|
||||
}
|
||||
|
||||
func (t *RuleSet) AddRuleGroup(ruleGroups ...RuleGroup) {
|
||||
for _, rg := range ruleGroups {
|
||||
t.AddRule(rg.ToList()...)
|
||||
}
|
||||
}
|
||||
|
||||
// MatchRules will return a list of all matching rules. The rules are sorted by
|
||||
// the match's Score. The best match will be first.
|
||||
func (t *RuleSet) MatchRules(input string) (rules []*Rule) {
|
||||
input = t.sanitize(input)
|
||||
tokens := Tokenize(input, t.maxNgram)
|
||||
|
||||
rules = t.root.Match(tokens)
|
||||
if len(rules) == 0 {
|
||||
return rules
|
||||
}
|
||||
|
||||
// Check excludes.
|
||||
l := rules[:0]
|
||||
for _, r := range rules {
|
||||
if !r.isExcluded(tokens) {
|
||||
l = append(l, r)
|
||||
}
|
||||
}
|
||||
|
||||
rules = l
|
||||
|
||||
// Sort rules descending.
|
||||
sort.Slice(rules, func(i, j int) bool {
|
||||
return ruleLess(rules[j], rules[i])
|
||||
})
|
||||
|
||||
return rules
|
||||
}
|
||||
|
||||
type Match struct {
|
||||
Tag string
|
||||
|
||||
// Confidence is used to sort all matches, and is normalized so the sum of
|
||||
// Confidence values for all matches is 1. Confidence is relative to the
|
||||
// number of matches and the size of matches in terms of number of tokens.
|
||||
Confidence float64 // In the range (0,1].
|
||||
}
|
||||
|
||||
// Return a list of matches with confidence. This is useful if you'd like to
|
||||
// find the best matching rule out of all the matched rules.
|
||||
//
|
||||
// If you just want to find all matching rules, then use MatchRules.
|
||||
func (t *RuleSet) Match(input string) []Match {
|
||||
rules := t.MatchRules(input)
|
||||
if len(rules) == 0 {
|
||||
return []Match{}
|
||||
}
|
||||
if len(rules) == 1 {
|
||||
return []Match{{
|
||||
Tag: rules[0].Tag,
|
||||
Confidence: 1,
|
||||
}}
|
||||
}
|
||||
|
||||
// Create list of blocked tags.
|
||||
blocks := map[string]struct{}{}
|
||||
for _, rule := range rules {
|
||||
for _, tag := range rule.Blocks {
|
||||
blocks[tag] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove rules for blocked tags.
|
||||
iOut := 0
|
||||
for _, rule := range rules {
|
||||
if _, ok := blocks[rule.Tag]; ok {
|
||||
continue
|
||||
}
|
||||
rules[iOut] = rule
|
||||
iOut++
|
||||
}
|
||||
rules = rules[:iOut]
|
||||
|
||||
// Matches by index.
|
||||
matches := map[string]int{}
|
||||
out := []Match{}
|
||||
sum := float64(0)
|
||||
|
||||
for _, rule := range rules {
|
||||
idx, ok := matches[rule.Tag]
|
||||
if !ok {
|
||||
idx = len(matches)
|
||||
matches[rule.Tag] = idx
|
||||
out = append(out, Match{Tag: rule.Tag})
|
||||
}
|
||||
out[idx].Confidence += float64(rule.Score)
|
||||
sum += float64(rule.Score)
|
||||
}
|
||||
|
||||
for i := range out {
|
||||
out[i].Confidence /= sum
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
84
tagengine/ruleset_test.go
Normal file
84
tagengine/ruleset_test.go
Normal file
@ -0,0 +1,84 @@
|
||||
package tagengine
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRulesSet(t *testing.T) {
|
||||
rs := NewRuleSet()
|
||||
rs.AddRule(Rule{
|
||||
Tag: "cc/2",
|
||||
Includes: []string{"cola", "coca"},
|
||||
})
|
||||
rs.AddRule(Rule{
|
||||
Tag: "cc/0",
|
||||
Includes: []string{"coca cola"},
|
||||
})
|
||||
rs.AddRule(Rule{
|
||||
Tag: "cz/2",
|
||||
Includes: []string{"coca", "zero"},
|
||||
})
|
||||
rs.AddRule(Rule{
|
||||
Tag: "cc0/3",
|
||||
Includes: []string{"zero", "coca", "cola"},
|
||||
})
|
||||
rs.AddRule(Rule{
|
||||
Tag: "cc0/3.1",
|
||||
Includes: []string{"coca", "cola", "zero"},
|
||||
Excludes: []string{"pepsi"},
|
||||
})
|
||||
rs.AddRule(Rule{
|
||||
Tag: "spa",
|
||||
Includes: []string{"spa"},
|
||||
Blocks: []string{"cc/0", "cc0/3", "cc0/3.1"},
|
||||
})
|
||||
|
||||
type TestCase struct {
|
||||
Input string
|
||||
Matches []Match
|
||||
}
|
||||
|
||||
cases := []TestCase{
|
||||
{
|
||||
Input: "coca-cola zero",
|
||||
Matches: []Match{
|
||||
{"cc0/3.1", 0.3},
|
||||
{"cc0/3", 0.3},
|
||||
{"cz/2", 0.2},
|
||||
{"cc/2", 0.2},
|
||||
},
|
||||
}, {
|
||||
Input: "coca cola",
|
||||
Matches: []Match{
|
||||
{"cc/0", 0.6},
|
||||
{"cc/2", 0.4},
|
||||
},
|
||||
}, {
|
||||
Input: "coca cola zero pepsi",
|
||||
Matches: []Match{
|
||||
{"cc0/3", 0.3},
|
||||
{"cc/0", 0.3},
|
||||
{"cz/2", 0.2},
|
||||
{"cc/2", 0.2},
|
||||
},
|
||||
}, {
|
||||
Input: "fanta orange",
|
||||
Matches: []Match{},
|
||||
}, {
|
||||
Input: "coca-cola zero / fanta / spa",
|
||||
Matches: []Match{
|
||||
{"cz/2", 0.4},
|
||||
{"cc/2", 0.4},
|
||||
{"spa", 0.2},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
matches := rs.Match(tc.Input)
|
||||
if !reflect.DeepEqual(matches, tc.Matches) {
|
||||
t.Fatalf("%v != %v", matches, tc.Matches)
|
||||
}
|
||||
}
|
||||
}
|
20
tagengine/sanitize.go
Normal file
20
tagengine/sanitize.go
Normal file
@ -0,0 +1,20 @@
|
||||
package tagengine
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"git.crumpington.com/lib/tagengine/sanitize"
|
||||
)
|
||||
|
||||
// The basic sanitizer:
|
||||
// * lower-case
|
||||
// * put spaces around numbers
|
||||
// * put slaces around punctuation
|
||||
// * collapse multiple spaces
|
||||
func BasicSanitizer(s string) string {
|
||||
s = strings.ToLower(s)
|
||||
s = sanitize.SpaceNumbers(s)
|
||||
s = sanitize.SpacePunctuation(s)
|
||||
s = sanitize.CollapseSpaces(s)
|
||||
return s
|
||||
}
|
91
tagengine/sanitize/sanitize.go
Normal file
91
tagengine/sanitize/sanitize.go
Normal file
@ -0,0 +1,91 @@
|
||||
package sanitize
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
func SpaceNumbers(s string) string {
|
||||
if len(s) == 0 {
|
||||
return s
|
||||
}
|
||||
|
||||
isDigit := func(b rune) bool {
|
||||
switch b {
|
||||
case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9':
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
b := strings.Builder{}
|
||||
|
||||
var first rune
|
||||
for _, c := range s {
|
||||
first = c
|
||||
break
|
||||
}
|
||||
|
||||
digit := isDigit(first)
|
||||
|
||||
// Range over runes.
|
||||
for _, c := range s {
|
||||
thisDigit := isDigit(c)
|
||||
if thisDigit != digit {
|
||||
b.WriteByte(' ')
|
||||
digit = thisDigit
|
||||
}
|
||||
b.WriteRune(c)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func SpacePunctuation(s string) string {
|
||||
needsSpace := func(r rune) bool {
|
||||
switch r {
|
||||
case '`', '~', '!', '@', '#', '%', '^', '&', '*', '(', ')',
|
||||
'-', '_', '+', '=', '[', '{', ']', '}', '\\', '|',
|
||||
':', ';', '"', '\'', ',', '<', '.', '>', '?', '/':
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
b := strings.Builder{}
|
||||
|
||||
// Range over runes.
|
||||
for _, r := range s {
|
||||
if needsSpace(r) {
|
||||
b.WriteRune(' ')
|
||||
b.WriteRune(r)
|
||||
b.WriteRune(' ')
|
||||
} else {
|
||||
b.WriteRune(r)
|
||||
}
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func CollapseSpaces(s string) string {
|
||||
// Trim leading and trailing spaces.
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
b := strings.Builder{}
|
||||
wasSpace := false
|
||||
|
||||
// Range over runes.
|
||||
for _, c := range s {
|
||||
if unicode.IsSpace(c) {
|
||||
wasSpace = true
|
||||
continue
|
||||
} else if wasSpace {
|
||||
wasSpace = false
|
||||
b.WriteRune(' ')
|
||||
}
|
||||
b.WriteRune(c)
|
||||
}
|
||||
|
||||
return b.String()
|
||||
}
|
30
tagengine/sanitize_test.go
Normal file
30
tagengine/sanitize_test.go
Normal file
@ -0,0 +1,30 @@
|
||||
package tagengine
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestSanitize(t *testing.T) {
|
||||
sanitize := BasicSanitizer
|
||||
|
||||
type Case struct {
|
||||
In string
|
||||
Out string
|
||||
}
|
||||
|
||||
cases := []Case{
|
||||
{"", ""},
|
||||
{"123abc", "123 abc"},
|
||||
{"abc123", "abc 123"},
|
||||
{"abc123xyz", "abc 123 xyz"},
|
||||
{"1f2", "1 f 2"},
|
||||
{" abc", "abc"},
|
||||
{" ; KitKat/m&m's (bottle) @ ", "; kitkat / m & m ' s ( bottle ) @"},
|
||||
{"€", "€"},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
out := sanitize(tc.In)
|
||||
if out != tc.Out {
|
||||
t.Fatalf("%v != %v", out, tc.Out)
|
||||
}
|
||||
}
|
||||
}
|
63
tagengine/tokenize.go
Normal file
63
tagengine/tokenize.go
Normal file
@ -0,0 +1,63 @@
|
||||
package tagengine
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var ignoreTokens = map[string]struct{}{}
|
||||
|
||||
func init() {
|
||||
// These on their own are ignored.
|
||||
tokens := []string{
|
||||
"`", `~`, `!`, `@`, `#`, `%`, `^`, `&`, `*`, `(`, `)`,
|
||||
`-`, `_`, `+`, `=`, `[`, `{`, `]`, `}`, `\`, `|`,
|
||||
`:`, `;`, `"`, `'`, `,`, `<`, `.`, `>`, `?`, `/`,
|
||||
}
|
||||
for _, s := range tokens {
|
||||
ignoreTokens[s] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
func Tokenize(
|
||||
input string,
|
||||
maxNgram int,
|
||||
) (
|
||||
tokens []string,
|
||||
) {
|
||||
// Avoid duplicate ngrams.
|
||||
ignored := map[string]bool{}
|
||||
|
||||
fields := strings.Fields(input)
|
||||
|
||||
if len(fields) < maxNgram {
|
||||
maxNgram = len(fields)
|
||||
}
|
||||
|
||||
for i := 1; i < maxNgram+1; i++ {
|
||||
jMax := len(fields) - i + 1
|
||||
|
||||
for j := 0; j < jMax; j++ {
|
||||
ngram := strings.Join(fields[j:i+j], " ")
|
||||
if _, ok := ignoreTokens[ngram]; !ok {
|
||||
if _, ok := ignored[ngram]; !ok {
|
||||
tokens = append(tokens, ngram)
|
||||
ignored[ngram] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sortTokens(tokens)
|
||||
|
||||
return tokens
|
||||
}
|
||||
|
||||
func sortTokens(tokens []string) {
|
||||
sort.Slice(tokens, func(i, j int) bool {
|
||||
if len(tokens[i]) != len(tokens[j]) {
|
||||
return len(tokens[i]) < len(tokens[j])
|
||||
}
|
||||
return tokens[i] < tokens[j]
|
||||
})
|
||||
}
|
55
tagengine/tokenize_test.go
Normal file
55
tagengine/tokenize_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package tagengine
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestTokenize(t *testing.T) {
|
||||
type Case struct {
|
||||
Input string
|
||||
MaxNgram int
|
||||
Output []string
|
||||
}
|
||||
|
||||
cases := []Case{
|
||||
{
|
||||
Input: "a bb c d",
|
||||
MaxNgram: 3,
|
||||
Output: []string{
|
||||
"a", "c", "d", "bb",
|
||||
"c d", "a bb", "bb c",
|
||||
"a bb c", "bb c d",
|
||||
},
|
||||
}, {
|
||||
Input: "a b",
|
||||
MaxNgram: 3,
|
||||
Output: []string{
|
||||
"a", "b", "a b",
|
||||
},
|
||||
}, {
|
||||
Input: "- b c d",
|
||||
MaxNgram: 3,
|
||||
Output: []string{
|
||||
"b", "c", "d",
|
||||
"- b", "b c", "c d",
|
||||
"- b c", "b c d",
|
||||
},
|
||||
}, {
|
||||
Input: "a a b c d c d",
|
||||
MaxNgram: 3,
|
||||
Output: []string{
|
||||
"a", "b", "c", "d",
|
||||
"a a", "a b", "b c", "c d", "d c",
|
||||
"a a b", "a b c", "b c d", "c d c", "d c d",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
output := Tokenize(tc.Input, tc.MaxNgram)
|
||||
if !reflect.DeepEqual(output, tc.Output) {
|
||||
t.Fatalf("%s: %#v", tc.Input, output)
|
||||
}
|
||||
}
|
||||
}
|
5
webutil/README.md
Normal file
5
webutil/README.md
Normal file
@ -0,0 +1,5 @@
|
||||
# webutil
|
||||
|
||||
## Roadmap
|
||||
|
||||
* logging middleware
|
10
webutil/go.mod
Normal file
10
webutil/go.mod
Normal file
@ -0,0 +1,10 @@
|
||||
module git.crumpington.com/lib/webutil
|
||||
|
||||
go 1.23.2
|
||||
|
||||
require golang.org/x/crypto v0.28.0
|
||||
|
||||
require (
|
||||
golang.org/x/net v0.21.0 // indirect
|
||||
golang.org/x/text v0.19.0 // indirect
|
||||
)
|
6
webutil/go.sum
Normal file
6
webutil/go.sum
Normal file
@ -0,0 +1,6 @@
|
||||
golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw=
|
||||
golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM=
|
||||
golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY=
|
24
webutil/listenandserve.go
Normal file
24
webutil/listenandserve.go
Normal file
@ -0,0 +1,24 @@
|
||||
package webutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/acme/autocert"
|
||||
)
|
||||
|
||||
// Serve requests using the given http.Server. If srv.Addr has the format
|
||||
// `hostname:https`, then use autocert to manage certificates for the domain.
|
||||
//
|
||||
// For http on port 80, you can use :http.
|
||||
func ListenAndServe(srv *http.Server) error {
|
||||
if strings.HasSuffix(srv.Addr, ":https") {
|
||||
hostname := strings.TrimSuffix(srv.Addr, ":https")
|
||||
if len(hostname) == 0 {
|
||||
return errors.New("https requires a hostname")
|
||||
}
|
||||
return srv.Serve(autocert.NewListener(hostname))
|
||||
}
|
||||
return srv.ListenAndServe()
|
||||
}
|
47
webutil/middleware-logging.go
Normal file
47
webutil/middleware-logging.go
Normal file
@ -0,0 +1,47 @@
|
||||
package webutil
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
var _log = log.New(os.Stderr, "", 0)
|
||||
|
||||
type responseWriterWrapper struct {
|
||||
http.ResponseWriter
|
||||
httpStatus int
|
||||
responseSize int
|
||||
}
|
||||
|
||||
func (w *responseWriterWrapper) WriteHeader(status int) {
|
||||
w.httpStatus = status
|
||||
w.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
func (w *responseWriterWrapper) Write(b []byte) (int, error) {
|
||||
if w.httpStatus == 0 {
|
||||
w.httpStatus = 200
|
||||
}
|
||||
w.responseSize += len(b)
|
||||
return w.ResponseWriter.Write(b)
|
||||
}
|
||||
|
||||
func WithLogging(inner http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
t := time.Now()
|
||||
wrapper := responseWriterWrapper{w, 0, 0}
|
||||
|
||||
inner(&wrapper, r)
|
||||
_log.Printf("%s \"%s %s %s\" %d %d %v\n",
|
||||
r.RemoteAddr,
|
||||
r.Method,
|
||||
r.URL.Path,
|
||||
r.Proto,
|
||||
wrapper.httpStatus,
|
||||
wrapper.responseSize,
|
||||
time.Since(t),
|
||||
)
|
||||
}
|
||||
}
|
100
webutil/template.go
Normal file
100
webutil/template.go
Normal file
@ -0,0 +1,100 @@
|
||||
package webutil
|
||||
|
||||
import (
|
||||
"embed"
|
||||
"html/template"
|
||||
"io/fs"
|
||||
"log"
|
||||
"path"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ParseTemplateSet parses sets of templates from an embed.FS.
|
||||
//
|
||||
// Each directory constitutes a set of templates that are parsed together.
|
||||
//
|
||||
// Structure (within a directory):
|
||||
// - share/* are always parsed.
|
||||
// - base.html will be parsed with each other file in same dir
|
||||
//
|
||||
// Call a template with m[path].Execute(w, data) (root dir name is excluded).
|
||||
//
|
||||
// For example, if you have
|
||||
// - /user/share/*
|
||||
// - /user/base.html
|
||||
// - /user/home.html
|
||||
//
|
||||
// Then you call m["/user/home.html"].Execute(w, data).
|
||||
func ParseTemplateSet(funcs template.FuncMap, fs embed.FS) map[string]*template.Template {
|
||||
m := map[string]*template.Template{}
|
||||
rootDir := readDir(fs, ".")[0].Name()
|
||||
loadTemplateDir(fs, funcs, m, rootDir, rootDir)
|
||||
return m
|
||||
}
|
||||
|
||||
func loadTemplateDir(
|
||||
fs embed.FS,
|
||||
funcs template.FuncMap,
|
||||
m map[string]*template.Template,
|
||||
dirPath string,
|
||||
rootDir string,
|
||||
) map[string]*template.Template {
|
||||
t := template.New("")
|
||||
if funcs != nil {
|
||||
t = t.Funcs(funcs)
|
||||
}
|
||||
|
||||
shareDir := path.Join(dirPath, "share")
|
||||
if _, err := fs.ReadDir(shareDir); err == nil {
|
||||
log.Printf("Parsing %s...", path.Join(shareDir, "*"))
|
||||
t = template.Must(t.ParseFS(fs, path.Join(shareDir, "*")))
|
||||
}
|
||||
|
||||
if data, _ := fs.ReadFile(path.Join(dirPath, "base.html")); data != nil {
|
||||
log.Printf("Parsing %s...", path.Join(dirPath, "base.html"))
|
||||
t = template.Must(t.Parse(string(data)))
|
||||
}
|
||||
|
||||
for _, ent := range readDir(fs, dirPath) {
|
||||
if ent.Type().IsDir() {
|
||||
if ent.Name() != "share" {
|
||||
m = loadTemplateDir(fs, funcs, m, path.Join(dirPath, ent.Name()), rootDir)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if !ent.Type().IsRegular() {
|
||||
continue
|
||||
}
|
||||
|
||||
if ent.Name() == "base.html" {
|
||||
continue
|
||||
}
|
||||
|
||||
filePath := path.Join(dirPath, ent.Name())
|
||||
log.Printf("Parsing %s...", filePath)
|
||||
|
||||
key := strings.TrimPrefix(path.Join(dirPath, ent.Name()), rootDir)
|
||||
tt := template.Must(t.Clone())
|
||||
tt = template.Must(tt.Parse(readFile(fs, filePath)))
|
||||
m[key] = tt
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func readDir(fs embed.FS, dirPath string) []fs.DirEntry {
|
||||
ents, err := fs.ReadDir(dirPath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return ents
|
||||
}
|
||||
|
||||
func readFile(fs embed.FS, path string) string {
|
||||
data, err := fs.ReadFile(path)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(data)
|
||||
}
|
49
webutil/template_test.go
Normal file
49
webutil/template_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package webutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"html/template"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
//go:embed all:test-templates
|
||||
var testFS embed.FS
|
||||
|
||||
func TestParseTemplateSet(t *testing.T) {
|
||||
funcs := template.FuncMap{"join": strings.Join}
|
||||
m := ParseTemplateSet(funcs, testFS)
|
||||
|
||||
type TestCase struct {
|
||||
Key string
|
||||
Data any
|
||||
Out string
|
||||
}
|
||||
|
||||
cases := []TestCase{
|
||||
{
|
||||
Key: "/home.html",
|
||||
Data: "DATA",
|
||||
Out: "<p>HOME!</p>",
|
||||
}, {
|
||||
Key: "/about.html",
|
||||
Data: "DATA",
|
||||
Out: "<p><b>DATA</b></p>",
|
||||
}, {
|
||||
Key: "/contact.html",
|
||||
Data: []string{"a", "b", "c"},
|
||||
Out: "<p>a,b,c</p>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
b := &bytes.Buffer{}
|
||||
m[tc.Key].Execute(b, tc.Data)
|
||||
out := strings.TrimSpace(b.String())
|
||||
if out != tc.Out {
|
||||
t.Fatalf("%s != %s", out, tc.Out)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
1
webutil/test-templates/about.html
Normal file
1
webutil/test-templates/about.html
Normal file
@ -0,0 +1 @@
|
||||
{{define "body"}}{{template "bold" .}}{{end}}
|
1
webutil/test-templates/base.html
Normal file
1
webutil/test-templates/base.html
Normal file
@ -0,0 +1 @@
|
||||
<p>{{block "body" .}}default{{end}}</p>
|
1
webutil/test-templates/contact.html
Normal file
1
webutil/test-templates/contact.html
Normal file
@ -0,0 +1 @@
|
||||
{{define "body"}}{{join . ","}}{{end}}
|
1
webutil/test-templates/home.html
Normal file
1
webutil/test-templates/home.html
Normal file
@ -0,0 +1 @@
|
||||
{{define "body"}}HOME!{{end}}
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user