wip
This commit is contained in:
		
							
								
								
									
										2
									
								
								flock/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								flock/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | # flock | ||||||
|  |  | ||||||
							
								
								
									
										69
									
								
								flock/flock.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								flock/flock.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | |||||||
|  | // The flock package provides a file-system mediated locking mechanism on linux | ||||||
|  | // using the `flock` system call. | ||||||
|  |  | ||||||
|  | package flock | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"os" | ||||||
|  | 	"syscall" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Lock gets an exclusive lock on the file at the given path. If the file | ||||||
|  | // doesn't exist, it's created. | ||||||
|  | func Lock(path string) (*os.File, error) { | ||||||
|  | 	return lock(path, syscall.LOCK_EX) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // TryLock will return a nil file if the file is already locked. | ||||||
|  | func TryLock(path string) (*os.File, error) { | ||||||
|  | 	return lock(path, syscall.LOCK_EX|syscall.LOCK_NB) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func LockFile(f *os.File) error { | ||||||
|  | 	_, err := lockFile(f, syscall.LOCK_EX) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Returns true if the lock was successfully acquired. | ||||||
|  | func TryLockFile(f *os.File) (bool, error) { | ||||||
|  | 	return lockFile(f, syscall.LOCK_EX|syscall.LOCK_NB) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func lockFile(f *os.File, flags int) (bool, error) { | ||||||
|  |  | ||||||
|  | 	if err := flock(int(f.Fd()), flags); err != nil { | ||||||
|  | 		if flags&syscall.LOCK_NB != 0 && errors.Is(err, syscall.EAGAIN) { | ||||||
|  | 			return false, nil | ||||||
|  | 		} | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  | 	return true, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func flock(fd int, how int) error { | ||||||
|  | 	_, _, e1 := syscall.Syscall(syscall.SYS_FLOCK, uintptr(fd), uintptr(how), 0) | ||||||
|  | 	if e1 != 0 { | ||||||
|  | 		return syscall.Errno(e1) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func lock(path string, flags int) (*os.File, error) { | ||||||
|  | 	perm := os.O_CREATE | os.O_RDWR | ||||||
|  | 	f, err := os.OpenFile(path, perm, 0600) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	ok, err := lockFile(f, flags) | ||||||
|  | 	if err != nil || !ok { | ||||||
|  | 		f.Close() | ||||||
|  | 		f = nil | ||||||
|  | 	} | ||||||
|  | 	return f, err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Unlock releases the lock acquired via the Lock function. | ||||||
|  | func Unlock(f *os.File) error { | ||||||
|  | 	return f.Close() | ||||||
|  | } | ||||||
							
								
								
									
										66
									
								
								flock/flock_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								flock/flock_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | |||||||
|  | package flock | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func Test_Lock_basic(t *testing.T) { | ||||||
|  | 	ch := make(chan int, 1) | ||||||
|  | 	f, err := Lock("/tmp/fsutil-test-lock") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	go func() { | ||||||
|  | 		time.Sleep(time.Second) | ||||||
|  | 		ch <- 10 | ||||||
|  | 		Unlock(f) | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	select { | ||||||
|  | 	case x := <-ch: | ||||||
|  | 		t.Fatal(x) | ||||||
|  | 	default: | ||||||
|  |  | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	f2, _ := Lock("/tmp/fsutil-test-lock") | ||||||
|  | 	defer Unlock(f2) | ||||||
|  | 	select { | ||||||
|  | 	case i := <-ch: | ||||||
|  | 		if i != 10 { | ||||||
|  | 			t.Fatal(i) | ||||||
|  | 		} | ||||||
|  | 	default: | ||||||
|  | 		t.Fatal("No value available.") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Test_Lock_badPath(t *testing.T) { | ||||||
|  | 	_, err := Lock("./dne/file.lock") | ||||||
|  | 	if err == nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestTryLock(t *testing.T) { | ||||||
|  | 	lockPath := "/tmp/fsutil-test-lock" | ||||||
|  | 	f, err := TryLock(lockPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatalf("%#v", err) | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	f2, err := TryLock(lockPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if f2 != nil { | ||||||
|  | 		t.Fatal(f2) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := Unlock(f); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										3
									
								
								flock/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								flock/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | module git.crumpington.com/lib/flock | ||||||
|  |  | ||||||
|  | go 1.23.0 | ||||||
							
								
								
									
										0
									
								
								flock/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								flock/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										2
									
								
								httpconn/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								httpconn/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | # httpconn | ||||||
|  |  | ||||||
							
								
								
									
										104
									
								
								httpconn/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								httpconn/client.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,104 @@ | |||||||
|  | package httpconn | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bufio" | ||||||
|  | 	"context" | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"errors" | ||||||
|  | 	"io" | ||||||
|  | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/url" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	ErrUnknownScheme = errors.New("uknown scheme") | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Dialer struct { | ||||||
|  | 	timeout time.Duration | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewDialer() *Dialer { | ||||||
|  | 	return &Dialer{timeout: 10 * time.Second} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (d *Dialer) SetTimeout(timeout time.Duration) { | ||||||
|  | 	d.timeout = timeout | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (d *Dialer) Dial(rawURL string) (net.Conn, error) { | ||||||
|  | 	u, err := url.Parse(rawURL) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	switch u.Scheme { | ||||||
|  | 	case "https": | ||||||
|  | 		return d.DialHTTPS(u.Host+":443", u.Path) | ||||||
|  | 	case "http": | ||||||
|  | 		return d.DialHTTP(u.Host, u.Path) | ||||||
|  | 	default: | ||||||
|  | 		return nil, ErrUnknownScheme | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (d *Dialer) DialHTTPS(host, path string) (net.Conn, error) { | ||||||
|  | 	ctx, cancel := context.WithTimeout(context.Background(), d.timeout) | ||||||
|  | 	dd := tls.Dialer{} | ||||||
|  | 	conn, err := dd.DialContext(ctx, "tcp", host) | ||||||
|  | 	cancel() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return d.finishDialing(conn, host, path) | ||||||
|  |  | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (d *Dialer) DialHTTP(host, path string) (net.Conn, error) { | ||||||
|  | 	conn, err := net.DialTimeout("tcp", host, d.timeout) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return d.finishDialing(conn, host, path) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (d *Dialer) finishDialing(conn net.Conn, host, path string) (net.Conn, error) { | ||||||
|  | 	conn.SetDeadline(time.Now().Add(d.timeout)) | ||||||
|  |  | ||||||
|  | 	if _, err := io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n"); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if _, err := io.WriteString(conn, "Host: "+host+"\n\n"); err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Require successful HTTP response before using the conn. | ||||||
|  | 	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		conn.Close() | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if resp.Status != "200 OK" { | ||||||
|  | 		conn.Close() | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	conn.SetDeadline(time.Time{}) | ||||||
|  |  | ||||||
|  | 	return conn, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Dial(rawURL string) (net.Conn, error) { | ||||||
|  | 	return NewDialer().Dial(rawURL) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DialHTTPS(host, path string) (net.Conn, error) { | ||||||
|  | 	return NewDialer().DialHTTPS(host, path) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func DialHTTP(host, path string) (net.Conn, error) { | ||||||
|  | 	return NewDialer().DialHTTP(host, path) | ||||||
|  | } | ||||||
							
								
								
									
										42
									
								
								httpconn/conn_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								httpconn/conn_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | |||||||
|  | package httpconn | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
|  | 	"strings" | ||||||
|  | 	"testing" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/net/nettest" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestNetTest_TestConn(t *testing.T) { | ||||||
|  | 	nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) { | ||||||
|  |  | ||||||
|  | 		connCh := make(chan net.Conn, 1) | ||||||
|  | 		doneCh := make(chan bool) | ||||||
|  |  | ||||||
|  | 		mux := http.NewServeMux() | ||||||
|  | 		mux.HandleFunc("/connect", func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 			conn, err := Accept(w, r) | ||||||
|  | 			if err != nil { | ||||||
|  | 				panic(err) | ||||||
|  | 			} | ||||||
|  | 			connCh <- conn | ||||||
|  | 			<-doneCh | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		srv := httptest.NewServer(mux) | ||||||
|  |  | ||||||
|  | 		c1, err = DialHTTP(strings.TrimPrefix(srv.URL, "http://"), "/connect") | ||||||
|  | 		if err != nil { | ||||||
|  | 			panic(err) | ||||||
|  | 		} | ||||||
|  | 		c2 = <-connCh | ||||||
|  |  | ||||||
|  | 		return c1, c2, func() { | ||||||
|  | 			doneCh <- true | ||||||
|  | 			srv.Close() | ||||||
|  | 		}, nil | ||||||
|  | 	}) | ||||||
|  | } | ||||||
							
								
								
									
										5
									
								
								httpconn/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								httpconn/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | module git.crumpington.com/lib/httpconn | ||||||
|  |  | ||||||
|  | go 1.23.2 | ||||||
|  |  | ||||||
|  | require golang.org/x/net v0.30.0 | ||||||
							
								
								
									
										2
									
								
								httpconn/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								httpconn/go.sum
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= | ||||||
|  | golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= | ||||||
							
								
								
									
										32
									
								
								httpconn/server.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								httpconn/server.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | |||||||
|  | package httpconn | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"io" | ||||||
|  | 	"net" | ||||||
|  | 	"net/http" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func Accept(w http.ResponseWriter, r *http.Request) (net.Conn, error) { | ||||||
|  | 	if r.Method != "CONNECT" { | ||||||
|  | 		w.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||||
|  | 		w.WriteHeader(http.StatusMethodNotAllowed) | ||||||
|  | 		io.WriteString(w, "405 must CONNECT\n") | ||||||
|  | 		return nil, http.ErrNotSupported | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	hj, ok := w.(http.Hijacker) | ||||||
|  | 	if !ok { | ||||||
|  | 		return nil, http.ErrNotSupported | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	conn, _, err := hj.Hijack() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, _ = io.WriteString(conn, "HTTP/1.0 200 OK\n\n") | ||||||
|  | 	conn.SetDeadline(time.Time{}) | ||||||
|  |  | ||||||
|  | 	return conn, nil | ||||||
|  | } | ||||||
							
								
								
									
										2
									
								
								idgen/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								idgen/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | # idgen | ||||||
|  |  | ||||||
							
								
								
									
										3
									
								
								idgen/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								idgen/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | module git.crumpington.com/lib/idgen | ||||||
|  |  | ||||||
|  | go 1.23.2 | ||||||
							
								
								
									
										53
									
								
								idgen/idgen.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								idgen/idgen.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | |||||||
|  | package idgen | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"encoding/base32" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Creates a new, random token. | ||||||
|  | func NewToken() string { | ||||||
|  | 	buf := make([]byte, 20) | ||||||
|  | 	_, err := rand.Read(buf) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  | 	return base32.StdEncoding.EncodeToString(buf) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var ( | ||||||
|  | 	lock       sync.Mutex | ||||||
|  | 	ts         int64 = time.Now().Unix() | ||||||
|  | 	counter    int64 = 1 | ||||||
|  | 	counterMax int64 = 1 << 20 | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // NextID can generate ~1M ints per second for a given nodeID. | ||||||
|  | // | ||||||
|  | // nodeID must be less than 64. | ||||||
|  | func NextID(nodeID int64) int64 { | ||||||
|  | 	lock.Lock() | ||||||
|  | 	defer lock.Unlock() | ||||||
|  |  | ||||||
|  | 	tt := time.Now().Unix() | ||||||
|  | 	if tt > ts { | ||||||
|  | 		ts = tt | ||||||
|  | 		counter = 1 | ||||||
|  | 	} else { | ||||||
|  | 		counter++ | ||||||
|  | 		if counter == counterMax { | ||||||
|  | 			panic("Too many IDs.") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return (ts << 26) + (nodeID << 20) + counter | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func SplitID(id int64) (unixTime, nodeID, counter int64) { | ||||||
|  | 	counter = id & (0x00000000000FFFFF) | ||||||
|  | 	nodeID = (id >> 20) & (0x000000000000003F) | ||||||
|  | 	unixTime = id >> 26 | ||||||
|  | 	return | ||||||
|  | } | ||||||
							
								
								
									
										19
									
								
								idgen/idgen_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								idgen/idgen_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | |||||||
|  | package idgen | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"log" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func BenchmarkNext(b *testing.B) { | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		NextID(0) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestNextID(t *testing.T) { | ||||||
|  | 	id := NextID(32) | ||||||
|  |  | ||||||
|  | 	a, b, c := SplitID(id) | ||||||
|  | 	log.Print(a, b, c) | ||||||
|  | } | ||||||
							
								
								
									
										2
									
								
								keyedmutex/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								keyedmutex/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | # keyedmutex | ||||||
|  |  | ||||||
							
								
								
									
										3
									
								
								keyedmutex/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								keyedmutex/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | module git.crumpington.com/lib/keyedmutex | ||||||
|  |  | ||||||
|  | go 1.23.2 | ||||||
							
								
								
									
										67
									
								
								keyedmutex/keyedmutex.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								keyedmutex/keyedmutex.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,67 @@ | |||||||
|  | package keyedmutex | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"container/list" | ||||||
|  | 	"sync" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type KeyedMutex[K comparable] struct { | ||||||
|  | 	mu       *sync.Mutex | ||||||
|  | 	waitList map[K]*list.List | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func New[K comparable]() KeyedMutex[K] { | ||||||
|  | 	return KeyedMutex[K]{ | ||||||
|  | 		mu:       new(sync.Mutex), | ||||||
|  | 		waitList: map[K]*list.List{}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m KeyedMutex[K]) Lock(key K) { | ||||||
|  | 	if ch := m.lock(key); ch != nil { | ||||||
|  | 		<-ch | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m KeyedMutex[K]) lock(key K) chan struct{} { | ||||||
|  | 	m.mu.Lock() | ||||||
|  | 	defer m.mu.Unlock() | ||||||
|  |  | ||||||
|  | 	if waitList, ok := m.waitList[key]; ok { | ||||||
|  | 		ch := make(chan struct{}) | ||||||
|  | 		waitList.PushBack(ch) | ||||||
|  | 		return ch | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	m.waitList[key] = list.New() | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m KeyedMutex[K]) TryLock(key K) bool { | ||||||
|  | 	m.mu.Lock() | ||||||
|  | 	defer m.mu.Unlock() | ||||||
|  |  | ||||||
|  | 	if _, ok := m.waitList[key]; ok { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	m.waitList[key] = list.New() | ||||||
|  | 	return true | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (m KeyedMutex[K]) Unlock(key K) { | ||||||
|  | 	m.mu.Lock() | ||||||
|  | 	defer m.mu.Unlock() | ||||||
|  |  | ||||||
|  | 	waitList, ok := m.waitList[key] | ||||||
|  | 	if !ok { | ||||||
|  | 		panic("unlock of unlocked mutex") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if waitList.Len() == 0 { | ||||||
|  | 		delete(m.waitList, key) | ||||||
|  | 	} else { | ||||||
|  | 		ch := waitList.Remove(waitList.Front()).(chan struct{}) | ||||||
|  | 		ch <- struct{}{} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										123
									
								
								keyedmutex/keyedmutex_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								keyedmutex/keyedmutex_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,123 @@ | |||||||
|  | package keyedmutex | ||||||
|  |  | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"sync" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestKeyedMutex(t *testing.T) { | ||||||
|  | 	checkState := func(t *testing.T, m KeyedMutex[string], keys ...string) { | ||||||
|  | 		if len(m.waitList) != len(keys) { | ||||||
|  | 			t.Fatal(m.waitList, keys) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		for _, key := range keys { | ||||||
|  | 			if _, ok := m.waitList[key]; !ok { | ||||||
|  | 				t.Fatal(key) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	m := New[string]() | ||||||
|  | 	checkState(t, m) | ||||||
|  |  | ||||||
|  | 	m.Lock("a") | ||||||
|  | 	checkState(t, m, "a") | ||||||
|  | 	m.Lock("b") | ||||||
|  | 	checkState(t, m, "a", "b") | ||||||
|  | 	m.Lock("c") | ||||||
|  | 	checkState(t, m, "a", "b", "c") | ||||||
|  |  | ||||||
|  | 	if m.TryLock("a") { | ||||||
|  | 		t.Fatal("a") | ||||||
|  | 	} | ||||||
|  | 	if m.TryLock("b") { | ||||||
|  | 		t.Fatal("b") | ||||||
|  | 	} | ||||||
|  | 	if m.TryLock("c") { | ||||||
|  | 		t.Fatal("c") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if !m.TryLock("d") { | ||||||
|  | 		t.Fatal("d") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	checkState(t, m, "a", "b", "c", "d") | ||||||
|  |  | ||||||
|  | 	if !m.TryLock("e") { | ||||||
|  | 		t.Fatal("e") | ||||||
|  | 	} | ||||||
|  | 	checkState(t, m, "a", "b", "c", "d", "e") | ||||||
|  |  | ||||||
|  | 	m.Unlock("c") | ||||||
|  | 	checkState(t, m, "a", "b", "d", "e") | ||||||
|  | 	m.Unlock("a") | ||||||
|  | 	checkState(t, m, "b", "d", "e") | ||||||
|  | 	m.Unlock("e") | ||||||
|  | 	checkState(t, m, "b", "d") | ||||||
|  |  | ||||||
|  | 	wg := sync.WaitGroup{} | ||||||
|  | 	for i := 0; i < 8; i++ { | ||||||
|  | 		wg.Add(1) | ||||||
|  | 		go func() { | ||||||
|  | 			defer wg.Done() | ||||||
|  | 			m.Lock("b") | ||||||
|  | 			m.Unlock("b") | ||||||
|  | 		}() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	time.Sleep(100 * time.Millisecond) | ||||||
|  | 	m.Unlock("b") | ||||||
|  | 	wg.Wait() | ||||||
|  |  | ||||||
|  | 	checkState(t, m, "d") | ||||||
|  |  | ||||||
|  | 	m.Unlock("d") | ||||||
|  | 	checkState(t, m) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestKeyedMutex_unlockUnlocked(t *testing.T) { | ||||||
|  | 	defer func() { | ||||||
|  | 		if r := recover(); r == nil { | ||||||
|  | 			t.Fatal(r) | ||||||
|  | 		} | ||||||
|  | 	}() | ||||||
|  |  | ||||||
|  | 	m := New[string]() | ||||||
|  | 	m.Unlock("aldkfj") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BenchmarkUncontendedMutex(b *testing.B) { | ||||||
|  | 	m := New[string]() | ||||||
|  | 	key := "xyz" | ||||||
|  |  | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		m.Lock(key) | ||||||
|  | 		m.Unlock(key) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func BenchmarkContendedMutex(b *testing.B) { | ||||||
|  | 	m := New[string]() | ||||||
|  | 	key := "xyz" | ||||||
|  |  | ||||||
|  | 	m.Lock(key) | ||||||
|  |  | ||||||
|  | 	wg := sync.WaitGroup{} | ||||||
|  | 	for i := 0; i < b.N; i++ { | ||||||
|  | 		wg.Add(1) | ||||||
|  | 		go func() { | ||||||
|  | 			defer wg.Done() | ||||||
|  | 			m.Lock(key) | ||||||
|  | 			m.Unlock(key) | ||||||
|  | 		}() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	time.Sleep(time.Second) | ||||||
|  |  | ||||||
|  | 	b.ResetTimer() | ||||||
|  | 	m.Unlock(key) | ||||||
|  | 	wg.Wait() | ||||||
|  | } | ||||||
							
								
								
									
										2
									
								
								kvmemcache/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								kvmemcache/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | # kvmemcache | ||||||
|  |  | ||||||
							
								
								
									
										134
									
								
								kvmemcache/cache.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										134
									
								
								kvmemcache/cache.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,134 @@ | |||||||
|  | package kvmemcache | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"container/list" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"git.crumpington.com/lib/keyedmutex" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type Cache[K comparable, V any] struct { | ||||||
|  | 	updateLock keyedmutex.KeyedMutex[K] | ||||||
|  | 	src        func(K) (V, error) | ||||||
|  | 	ttl        time.Duration | ||||||
|  | 	maxSize    int | ||||||
|  |  | ||||||
|  | 	// Lock protects variables below. | ||||||
|  | 	lock  sync.Mutex | ||||||
|  | 	cache map[K]*list.Element | ||||||
|  | 	ll    *list.List | ||||||
|  | 	stats Stats | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type lruItem[K comparable, V any] struct { | ||||||
|  | 	key       K | ||||||
|  | 	createdAt time.Time | ||||||
|  | 	value     V | ||||||
|  | 	err       error | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Config[K comparable, V any] struct { | ||||||
|  | 	MaxSize int | ||||||
|  | 	TTL     time.Duration // Zero to ignore. | ||||||
|  | 	Src     func(K) (V, error) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func New[K comparable, V any](conf Config[K, V]) *Cache[K, V] { | ||||||
|  | 	return &Cache[K, V]{ | ||||||
|  | 		updateLock: keyedmutex.New[K](), | ||||||
|  | 		src:        conf.Src, | ||||||
|  | 		ttl:        conf.TTL, | ||||||
|  | 		maxSize:    conf.MaxSize, | ||||||
|  | 		lock:       sync.Mutex{}, | ||||||
|  | 		cache:      make(map[K]*list.Element, conf.MaxSize+1), | ||||||
|  | 		ll:         list.New(), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cache[K, V]) Get(key K) (V, error) { | ||||||
|  | 	ok, val, err := c.get(key) | ||||||
|  | 	if ok { | ||||||
|  | 		return val, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return c.load(key) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cache[K, V]) Evict(key K) { | ||||||
|  | 	c.lock.Lock() | ||||||
|  | 	defer c.lock.Unlock() | ||||||
|  | 	c.evict(key) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cache[K, V]) Stats() Stats { | ||||||
|  | 	c.lock.Lock() | ||||||
|  | 	defer c.lock.Unlock() | ||||||
|  | 	return c.stats | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cache[K, V]) put(key K, value V, err error) { | ||||||
|  | 	c.lock.Lock() | ||||||
|  | 	defer c.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	c.stats.Misses++ | ||||||
|  |  | ||||||
|  | 	c.cache[key] = c.ll.PushFront(lruItem[K, V]{ | ||||||
|  | 		key:       key, | ||||||
|  | 		createdAt: time.Now(), | ||||||
|  | 		value:     value, | ||||||
|  | 		err:       err, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	if c.maxSize != 0 && len(c.cache) > c.maxSize { | ||||||
|  | 		li := c.ll.Back() | ||||||
|  | 		c.ll.Remove(li) | ||||||
|  | 		delete(c.cache, li.Value.(lruItem[K, V]).key) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cache[K, V]) evict(key K) { | ||||||
|  | 	elem := c.cache[key] | ||||||
|  | 	if elem != nil { | ||||||
|  | 		delete(c.cache, key) | ||||||
|  | 		c.ll.Remove(elem) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cache[K, V]) get(key K) (ok bool, val V, err error) { | ||||||
|  | 	c.lock.Lock() | ||||||
|  | 	defer c.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	li := c.cache[key] | ||||||
|  | 	if li == nil { | ||||||
|  | 		return false, val, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	item := li.Value.(lruItem[K, V]) | ||||||
|  | 	// Maybe evict. | ||||||
|  | 	if c.ttl != 0 && time.Since(item.createdAt) > c.ttl { | ||||||
|  | 		c.evict(key) | ||||||
|  | 		return false, val, nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	c.stats.Hits++ | ||||||
|  |  | ||||||
|  | 	c.ll.MoveToFront(li) | ||||||
|  | 	return true, item.value, item.err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cache[K, V]) load(key K) (V, error) { | ||||||
|  | 	c.updateLock.Lock(key) | ||||||
|  | 	defer c.updateLock.Unlock(key) | ||||||
|  |  | ||||||
|  | 	// Check again in case we lost the update race. | ||||||
|  | 	ok, val, err := c.get(key) | ||||||
|  | 	if ok { | ||||||
|  | 		return val, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Won the update race. | ||||||
|  | 	val, err = c.src(key) | ||||||
|  | 	c.put(key, val, err) | ||||||
|  | 	return val, err | ||||||
|  | } | ||||||
							
								
								
									
										249
									
								
								kvmemcache/cache_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										249
									
								
								kvmemcache/cache_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,249 @@ | |||||||
|  | package kvmemcache | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"sync" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type State[K comparable] struct { | ||||||
|  | 	Keys  []K | ||||||
|  | 	Stats Stats | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (c *Cache[K,V]) assert(state State[K]) error { | ||||||
|  | 	c.lock.Lock() | ||||||
|  | 	defer c.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	if len(c.cache) != len(state.Keys) { | ||||||
|  | 		return fmt.Errorf( | ||||||
|  | 			"Expected %d keys but found %d.", | ||||||
|  | 			len(state.Keys), | ||||||
|  | 			len(c.cache)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, k := range state.Keys { | ||||||
|  | 		if _, ok := c.cache[k]; !ok { | ||||||
|  | 			return fmt.Errorf( | ||||||
|  | 				"Expected key %v not found.", | ||||||
|  | 				k) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if c.stats.Hits != state.Stats.Hits { | ||||||
|  | 		return fmt.Errorf( | ||||||
|  | 			"Expected %d hits, but found %d.", | ||||||
|  | 			state.Stats.Hits, | ||||||
|  | 			c.stats.Hits) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if c.stats.Misses != state.Stats.Misses { | ||||||
|  | 		return fmt.Errorf( | ||||||
|  | 			"Expected %d misses, but found %d.", | ||||||
|  | 			state.Stats.Misses, | ||||||
|  | 			c.stats.Misses) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | var ErrTest = errors.New("Hello") | ||||||
|  |  | ||||||
|  | func TestCache_basic(t *testing.T) { | ||||||
|  | 	c := New(Config[string, string]{ | ||||||
|  | 		MaxSize: 4, | ||||||
|  | 		TTL:     50 * time.Millisecond, | ||||||
|  | 		Src: func(key string) (string, error) { | ||||||
|  | 			if key == "err" { | ||||||
|  | 				return "", ErrTest | ||||||
|  | 			} | ||||||
|  | 			return key, nil | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	type testCase struct { | ||||||
|  | 		name  string | ||||||
|  | 		sleep time.Duration | ||||||
|  | 		key   string | ||||||
|  | 		evict bool | ||||||
|  | 		state State[string] | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cases := []testCase{ | ||||||
|  | 		{ | ||||||
|  | 			name: "get a", | ||||||
|  | 			key:  "a", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"a"}, | ||||||
|  | 				Stats: Stats{Hits: 0, Misses: 1}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "get a again", | ||||||
|  | 			key:  "a", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"a"}, | ||||||
|  | 				Stats: Stats{Hits: 1, Misses: 1}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name:  "sleep, then get a again", | ||||||
|  | 			sleep: 55 * time.Millisecond, | ||||||
|  | 			key:   "a", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"a"}, | ||||||
|  | 				Stats: Stats{Hits: 1, Misses: 2}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "get b", | ||||||
|  | 			key:  "b", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"a", "b"}, | ||||||
|  | 				Stats: Stats{Hits: 1, Misses: 3}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "get c", | ||||||
|  | 			key:  "c", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"a", "b", "c"}, | ||||||
|  | 				Stats: Stats{Hits: 1, Misses: 4}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "get d", | ||||||
|  | 			key:  "d", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"a", "b", "c", "d"}, | ||||||
|  | 				Stats: Stats{Hits: 1, Misses: 5}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "get e", | ||||||
|  | 			key:  "e", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"b", "c", "d", "e"}, | ||||||
|  | 				Stats: Stats{Hits: 1, Misses: 6}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "get c again", | ||||||
|  | 			key:  "c", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"b", "c", "d", "e"}, | ||||||
|  | 				Stats: Stats{Hits: 2, Misses: 6}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "get err", | ||||||
|  | 			key:  "err", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"c", "d", "e", "err"}, | ||||||
|  | 				Stats: Stats{Hits: 2, Misses: 7}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "get err again", | ||||||
|  | 			key:  "err", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"c", "d", "e", "err"}, | ||||||
|  | 				Stats: Stats{Hits: 3, Misses: 7}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name:  "evict c", | ||||||
|  | 			key:   "c", | ||||||
|  | 			evict: true, | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"d", "e", "err"}, | ||||||
|  | 				Stats: Stats{Hits: 3, Misses: 7}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "reload-all a", | ||||||
|  | 			key:  "a", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"a", "d", "e", "err"}, | ||||||
|  | 				Stats: Stats{Hits: 3, Misses: 8}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "reload-all b", | ||||||
|  | 			key:  "b", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"a", "b", "e", "err"}, | ||||||
|  | 				Stats: Stats{Hits: 3, Misses: 9}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "reload-all c", | ||||||
|  | 			key:  "c", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"a", "b", "c", "err"}, | ||||||
|  | 				Stats: Stats{Hits: 3, Misses: 10}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "reload-all d", | ||||||
|  | 			key:  "d", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"a", "b", "c", "d"}, | ||||||
|  | 				Stats: Stats{Hits: 3, Misses: 11}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "read a again", | ||||||
|  | 			key:  "a", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"b", "c", "d", "a"}, | ||||||
|  | 				Stats: Stats{Hits: 4, Misses: 11}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			name: "read e, evicting b", | ||||||
|  | 			key:  "e", | ||||||
|  | 			state: State[string]{ | ||||||
|  | 				Keys:  []string{"c", "d", "a", "e"}, | ||||||
|  | 				Stats: Stats{Hits: 4, Misses: 12}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range cases { | ||||||
|  | 		time.Sleep(tc.sleep) | ||||||
|  | 		if !tc.evict { | ||||||
|  | 			val, err := c.Get(tc.key) | ||||||
|  | 			if tc.key == "err" && err != ErrTest { | ||||||
|  | 				t.Fatal(tc.name, val) | ||||||
|  | 			} | ||||||
|  | 			if tc.key != "err" && val != tc.key { | ||||||
|  | 				t.Fatal(tc.name, tc.key, val) | ||||||
|  | 			} | ||||||
|  | 		} else { | ||||||
|  | 			c.Evict(tc.key) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if err := c.assert(tc.state); err != nil { | ||||||
|  | 			t.Fatal(err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestCache_thunderingHerd(t *testing.T) { | ||||||
|  | 	c := New(Config[string,string]{ | ||||||
|  | 		MaxSize: 4, | ||||||
|  | 		Src: func(key string) (string, error) { | ||||||
|  | 			time.Sleep(time.Second) | ||||||
|  | 			return key, nil | ||||||
|  | 		}, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	wg := sync.WaitGroup{} | ||||||
|  | 	for i := 0; i < 16384; i++ { | ||||||
|  | 		wg.Add(1) | ||||||
|  | 		go func() { | ||||||
|  | 			defer wg.Done() | ||||||
|  | 			val, err := c.Get("a") | ||||||
|  | 			if err != nil { | ||||||
|  | 				panic(err) | ||||||
|  | 			} | ||||||
|  | 			if val != "a" { | ||||||
|  | 				panic(err) | ||||||
|  | 			} | ||||||
|  | 		}() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	wg.Wait() | ||||||
|  |  | ||||||
|  | 	stats := c.Stats() | ||||||
|  | 	if stats.Hits != 16383 || stats.Misses != 1 { | ||||||
|  | 		t.Fatal(stats) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										5
									
								
								kvmemcache/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								kvmemcache/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | module git.crumpington.com/lib/kvmemcache | ||||||
|  |  | ||||||
|  | go 1.23.2 | ||||||
|  |  | ||||||
|  | require git.crumpington.com/lib/keyedmutex v1.0.1 | ||||||
							
								
								
									
										2
									
								
								kvmemcache/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								kvmemcache/go.sum
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | git.crumpington.com/lib/keyedmutex v1.0.1 h1:5ylwGXQzL9ojZIhlqkut6dpa4yt6Wz6bOWbf/tQBAMQ= | ||||||
|  | git.crumpington.com/lib/keyedmutex v1.0.1/go.mod h1:VxxJRU/XvvF61IuJZG7kUIv954Q8+Rh8bnVpEzGYrQ4= | ||||||
							
								
								
									
										6
									
								
								kvmemcache/stats.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								kvmemcache/stats.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | |||||||
|  | package kvmemcache | ||||||
|  |  | ||||||
|  | type Stats struct { | ||||||
|  | 	Hits   uint64 | ||||||
|  | 	Misses uint64 | ||||||
|  | } | ||||||
							
								
								
									
										2
									
								
								mmap/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								mmap/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | # mmap | ||||||
|  |  | ||||||
							
								
								
									
										100
									
								
								mmap/file.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								mmap/file.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,100 @@ | |||||||
|  | package mmap | ||||||
|  |  | ||||||
|  | import "os" | ||||||
|  |  | ||||||
|  | type File struct { | ||||||
|  | 	f   *os.File | ||||||
|  | 	Map []byte | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Create(path string, size int64) (*File, error) { | ||||||
|  | 	f, err := os.Create(path) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	if err := f.Truncate(size); err != nil { | ||||||
|  | 		f.Close() | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	m, err := Map(f, PROT_READ|PROT_WRITE) | ||||||
|  | 	if err != nil { | ||||||
|  | 		f.Close() | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return &File{f, m}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Opens a mapped file in read-only mode. | ||||||
|  | func Open(path string) (*File, error) { | ||||||
|  | 	f, err := os.Open(path) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	m, err := Map(f, PROT_READ) | ||||||
|  | 	if err != nil { | ||||||
|  | 		f.Close() | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return &File{f, m}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func OpenFile( | ||||||
|  | 	path string, | ||||||
|  | 	fileFlags int, | ||||||
|  | 	perm os.FileMode, | ||||||
|  | 	size int64, // -1 for file size. | ||||||
|  | ) (*File, error) { | ||||||
|  | 	f, err := os.OpenFile(path, fileFlags, perm) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	writable := fileFlags|os.O_RDWR != 0 || fileFlags|os.O_WRONLY != 0 | ||||||
|  |  | ||||||
|  | 	fi, err := f.Stat() | ||||||
|  | 	if err != nil { | ||||||
|  | 		f.Close() | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if writable && size > 0 && size != fi.Size() { | ||||||
|  | 		if err := f.Truncate(size); err != nil { | ||||||
|  | 			f.Close() | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	mapFlags := PROT_READ | ||||||
|  | 	if writable { | ||||||
|  | 		mapFlags |= PROT_WRITE | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	m, err := Map(f, mapFlags) | ||||||
|  | 	if err != nil { | ||||||
|  | 		f.Close() | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return &File{f, m}, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (f *File) Sync() error { | ||||||
|  | 	return Sync(f.Map) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (f *File) Close() error { | ||||||
|  | 	if f.Map != nil { | ||||||
|  | 		if err := Unmap(f.Map); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		f.Map = nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if f.f != nil { | ||||||
|  | 		if err := f.f.Close(); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		f.f = nil | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
							
								
								
									
										3
									
								
								mmap/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								mmap/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | module git.crumpington.com/lib/mmap | ||||||
|  |  | ||||||
|  | go 1.23.2 | ||||||
							
								
								
									
										0
									
								
								mmap/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								mmap/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										59
									
								
								mmap/mmap.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								mmap/mmap.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | |||||||
|  | package mmap | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"os" | ||||||
|  | 	"syscall" | ||||||
|  | 	"unsafe" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	PROT_READ  = syscall.PROT_READ | ||||||
|  | 	PROT_WRITE = syscall.PROT_WRITE | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Mmap creates a memory map of the given file. The flags argument should be a | ||||||
|  | // combination of PROT_READ and PROT_WRITE. The size of the map will be the | ||||||
|  | // file's size. | ||||||
|  | func Map(f *os.File, flags int) ([]byte, error) { | ||||||
|  | 	fi, err := f.Stat() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	size := fi.Size() | ||||||
|  |  | ||||||
|  | 	addr, _, errno := syscall.Syscall6( | ||||||
|  | 		syscall.SYS_MMAP, | ||||||
|  | 		0, // addr: 0 => allow kernel to choose | ||||||
|  | 		uintptr(size), | ||||||
|  | 		uintptr(flags), | ||||||
|  | 		uintptr(syscall.MAP_SHARED), | ||||||
|  | 		f.Fd(), | ||||||
|  | 		0) // offset: 0 => start of file | ||||||
|  | 	if errno != 0 { | ||||||
|  | 		return nil, syscall.Errno(errno) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return unsafe.Slice((*byte)(unsafe.Pointer(addr)), size), nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Munmap unmaps the data obtained by Map. | ||||||
|  | func Unmap(data []byte) error { | ||||||
|  | 	_, _, errno := syscall.Syscall( | ||||||
|  | 		syscall.SYS_MUNMAP, | ||||||
|  | 		uintptr(unsafe.Pointer(&data[:1][0])), | ||||||
|  | 		uintptr(cap(data)), | ||||||
|  | 		0) | ||||||
|  | 	if errno != 0 { | ||||||
|  | 		return syscall.Errno(errno) | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Sync(b []byte) (err error) { | ||||||
|  | 	_p0 := unsafe.Pointer(&b[0]) | ||||||
|  | 	_, _, errno := syscall.Syscall(syscall.SYS_MSYNC, uintptr(_p0), uintptr(len(b)), uintptr(syscall.MS_SYNC)) | ||||||
|  | 	if errno != 0 { | ||||||
|  | 		err = syscall.Errno(errno) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
							
								
								
									
										45
									
								
								pgutil/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								pgutil/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | |||||||
|  | # pgutil | ||||||
|  |  | ||||||
|  | ## Transactions | ||||||
|  |  | ||||||
|  | Simplify postgres transactions using `WithTx` for serializable transactions, | ||||||
|  | or `WithTxDefault` for the default isolation level. Use the `SerialTxRunner` | ||||||
|  | type to get automatic retries of serialization errors. | ||||||
|  |  | ||||||
|  | ## Migrations | ||||||
|  |  | ||||||
|  | Put your migrations into a directory, for example `migrations`, ordered by name | ||||||
|  | (YYYY-MM-DD prefix, for example). Embed the directory and pass it to the | ||||||
|  | `Migrate` function: | ||||||
|  |  | ||||||
|  | ```Go | ||||||
|  | //go:embed migrations | ||||||
|  | var migrations embed.FS | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  |     Migrate(db, migrations) // Check the error, of course. | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## Testing | ||||||
|  |  | ||||||
|  | In order to test this packge, we need to create a test user and database: | ||||||
|  |  | ||||||
|  | ``` | ||||||
|  | sudo su postgres | ||||||
|  | psql | ||||||
|  |  | ||||||
|  | CREATE DATABASE test; | ||||||
|  | CREATE USER test WITH ENCRYPTED PASSWORD 'test'; | ||||||
|  | GRANT ALL PRIVILEGES ON DATABASE test TO test; | ||||||
|  |  | ||||||
|  | use test | ||||||
|  |  | ||||||
|  | GRANT ALL ON SCHEMA public TO test; | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | Check that you can connect via the command line: | ||||||
|  |  | ||||||
|  | ``` | ||||||
|  | psql -h 127.0.0.1 -U test --password test | ||||||
|  | ``` | ||||||
							
								
								
									
										42
									
								
								pgutil/dropall.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								pgutil/dropall.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | |||||||
|  | package pgutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"log" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const dropTablesQueryQuery = ` | ||||||
|  | SELECT 'DROP TABLE IF EXISTS "' || tablename || '" CASCADE;' | ||||||
|  | FROM | ||||||
|  |    pg_tables | ||||||
|  | WHERE | ||||||
|  |   schemaname='public'` | ||||||
|  |  | ||||||
|  | // Deletes all tables in the database. Useful for testing. | ||||||
|  | func DropAllTables(db *sql.DB) error { | ||||||
|  | 	rows, err := db.Query(dropTablesQueryQuery) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	queries := []string{} | ||||||
|  | 	for rows.Next() { | ||||||
|  | 		var s string | ||||||
|  | 		if err := rows.Scan(&s); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		queries = append(queries, s) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if len(queries) > 0 { | ||||||
|  | 		log.Printf("DROPPING ALL (%d) TABLES", len(queries)) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, query := range queries { | ||||||
|  | 		if _, err := db.Exec(query); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
							
								
								
									
										31
									
								
								pgutil/errors.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								pgutil/errors.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | |||||||
|  | package pgutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  |  | ||||||
|  | 	"github.com/lib/pq" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func ErrIsDuplicateKey(err error) bool { | ||||||
|  | 	return ErrHasCode(err, "23505") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ErrIsForeignKey(err error) bool { | ||||||
|  | 	return ErrHasCode(err, "23503") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ErrIsSerializationFaiilure(err error) bool { | ||||||
|  | 	return ErrHasCode(err, "40001") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ErrHasCode(err error, code string) bool { | ||||||
|  | 	if err == nil { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var pErr *pq.Error | ||||||
|  | 	if errors.As(err, &pErr) { | ||||||
|  | 		return pErr.Code == pq.ErrorCode(code) | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
							
								
								
									
										36
									
								
								pgutil/errors_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								pgutil/errors_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | |||||||
|  | package pgutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestErrors(t *testing.T) { | ||||||
|  | 	db, err := sql.Open( | ||||||
|  | 		"postgres", | ||||||
|  | 		"host=127.0.0.1 dbname=test sslmode=disable user=test password=test") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := DropAllTables(db); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := Migrate(db, testMigrationFS); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (2, 'q@r.com')`) | ||||||
|  | 	if !ErrIsDuplicateKey(err) { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (3, 'c@d.com')`) | ||||||
|  | 	if !ErrIsDuplicateKey(err) { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	_, err = db.Exec(`INSERT INTO user_notes(UserID, NoteID, Note) VALUES (4, 1, 'hello')`) | ||||||
|  | 	if !ErrIsForeignKey(err) { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										5
									
								
								pgutil/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								pgutil/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | module git.crumpington.com/git/pgutil | ||||||
|  |  | ||||||
|  | go 1.23.2 | ||||||
|  |  | ||||||
|  | require github.com/lib/pq v1.10.9 | ||||||
							
								
								
									
										2
									
								
								pgutil/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								pgutil/go.sum
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= | ||||||
|  | github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= | ||||||
							
								
								
									
										82
									
								
								pgutil/migrate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								pgutil/migrate.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,82 @@ | |||||||
|  | package pgutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"embed" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"path/filepath" | ||||||
|  | 	"sort" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const initMigrationTableQuery = ` | ||||||
|  | CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);` | ||||||
|  |  | ||||||
|  | const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)` | ||||||
|  |  | ||||||
|  | const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)` | ||||||
|  |  | ||||||
|  | func Migrate(db *sql.DB, migrationFS embed.FS) error { | ||||||
|  | 	return WithTx(db, func(tx *sql.Tx) error { | ||||||
|  | 		if _, err := tx.Exec(initMigrationTableQuery); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		dirs, err := migrationFS.ReadDir(".") | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		if len(dirs) != 1 { | ||||||
|  | 			return errors.New("expected a single migrations directory") | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !dirs[0].IsDir() { | ||||||
|  | 			return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name()) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		dirName := dirs[0].Name() | ||||||
|  | 		files, err := migrationFS.ReadDir(dirName) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Sort sql files by name. | ||||||
|  | 		sort.Slice(files, func(i, j int) bool { | ||||||
|  | 			return files[i].Name() < files[j].Name() | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		for _, dirEnt := range files { | ||||||
|  | 			if !dirEnt.Type().IsRegular() { | ||||||
|  | 				return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name()) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			var ( | ||||||
|  | 				name   = dirEnt.Name() | ||||||
|  | 				exists bool | ||||||
|  | 			) | ||||||
|  |  | ||||||
|  | 			err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if exists { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			migration, err := migrationFS.ReadFile(filepath.Join(dirName, name)) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if _, err := tx.Exec(string(migration)); err != nil { | ||||||
|  | 				return fmt.Errorf("migration %s failed: %v", name, err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if _, err := tx.Exec(insertMigrationQuery, name); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	}) | ||||||
|  | } | ||||||
							
								
								
									
										50
									
								
								pgutil/migrate_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								pgutil/migrate_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,50 @@ | |||||||
|  | package pgutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"embed" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | //go:embed test-migrations | ||||||
|  | var testMigrationFS embed.FS | ||||||
|  |  | ||||||
|  | func TestMigrate(t *testing.T) { | ||||||
|  | 	db, err := sql.Open( | ||||||
|  | 		"postgres", | ||||||
|  | 		"host=127.0.0.1 dbname=test sslmode=disable user=test password=test") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := DropAllTables(db); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := Migrate(db, testMigrationFS); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Shouldn't have any effect. | ||||||
|  | 	if err := Migrate(db, testMigrationFS); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)` | ||||||
|  | 	var exists bool | ||||||
|  |  | ||||||
|  | 	if err = db.QueryRow(query, 1).Scan(&exists); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if exists { | ||||||
|  | 		t.Fatal("1 shouldn't exist") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err = db.QueryRow(query, 2).Scan(&exists); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if !exists { | ||||||
|  | 		t.Fatal("2 should exist") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | } | ||||||
							
								
								
									
										9
									
								
								pgutil/test-migrations/000.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								pgutil/test-migrations/000.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | |||||||
|  | CREATE TABLE users( | ||||||
|  |        UserID BIGINT NOT NULL PRIMARY KEY, | ||||||
|  |        Email  TEXT   NOT NULL UNIQUE); | ||||||
|  |  | ||||||
|  | CREATE TABLE user_notes( | ||||||
|  |        UserID BIGINT NOT NULL REFERENCES users(UserID), | ||||||
|  |        NoteID BIGINT NOT NULL, | ||||||
|  |        Note   Text   NOT NULL, | ||||||
|  |        PRIMARY KEY(UserID,NoteID)); | ||||||
							
								
								
									
										1
									
								
								pgutil/test-migrations/001.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								pgutil/test-migrations/001.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com'); | ||||||
							
								
								
									
										1
									
								
								pgutil/test-migrations/002.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								pgutil/test-migrations/002.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | DELETE FROM users WHERE UserID=1; | ||||||
							
								
								
									
										70
									
								
								pgutil/tx.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								pgutil/tx.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,70 @@ | |||||||
|  | package pgutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"math/rand" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Postgres doesn't use serializable transactions by default. This wrapper will | ||||||
|  | // run the enclosed function within a serializable. Note: this may return an | ||||||
|  | // retriable serialization error (see ErrIsSerializationFaiilure). | ||||||
|  | func WithTx(db *sql.DB, fn func(*sql.Tx) error) error { | ||||||
|  | 	return WithTxDefault(db, func(tx *sql.Tx) error { | ||||||
|  | 		if _, err := tx.Exec("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE"); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		return fn(tx) | ||||||
|  | 	}) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // This is a convenience function to provide a transaction wrapper with the | ||||||
|  | // default isolation level. | ||||||
|  | func WithTxDefault(db *sql.DB, fn func(*sql.Tx) error) error { | ||||||
|  | 	// Start a transaction. | ||||||
|  | 	tx, err := db.Begin() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = fn(tx) | ||||||
|  |  | ||||||
|  | 	if err == nil { | ||||||
|  | 		err = tx.Commit() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		_ = tx.Rollback() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // SerialTxRunner attempts serializable transactions in a loop. If a | ||||||
|  | // transaction fails due to a serialization error, then the runner will retry | ||||||
|  | // with exponential backoff, until the sleep time reaches MaxTimeout. | ||||||
|  | // | ||||||
|  | // For example, if MinTimeout is 100 ms, and MaxTimeout is 800 ms, it may sleep | ||||||
|  | // for ~100, 200, 400, and 800 ms between retries. | ||||||
|  | // | ||||||
|  | // 10% jitter is added to the sleep time. | ||||||
|  | type SerialTxRunner struct { | ||||||
|  | 	MinTimeout time.Duration | ||||||
|  | 	MaxTimeout time.Duration | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r SerialTxRunner) WithTx(db *sql.DB, fn func(*sql.Tx) error) error { | ||||||
|  | 	timeout := r.MinTimeout | ||||||
|  | 	for { | ||||||
|  | 		err := WithTx(db, fn) | ||||||
|  | 		if err == nil { | ||||||
|  | 			return nil | ||||||
|  | 		} | ||||||
|  | 		if timeout > r.MaxTimeout || !ErrIsSerializationFaiilure(err) { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		sleepTimeout := timeout + time.Duration(rand.Int63n(int64(timeout/10))) | ||||||
|  | 		time.Sleep(sleepTimeout) | ||||||
|  | 		timeout *= 2 | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										120
									
								
								pgutil/tx_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										120
									
								
								pgutil/tx_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,120 @@ | |||||||
|  | package pgutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"fmt" | ||||||
|  | 	"sync" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // TestExecuteTx verifies transaction retry using the classic | ||||||
|  | // example of write skew in bank account balance transfers. | ||||||
|  | func TestWithTx(t *testing.T) { | ||||||
|  | 	db, err := sql.Open( | ||||||
|  | 		"postgres", | ||||||
|  | 		"host=127.0.0.1 dbname=test sslmode=disable user=test password=test") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := DropAllTables(db); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	defer db.Close() | ||||||
|  |  | ||||||
|  | 	initStmt := ` | ||||||
|  | CREATE TABLE t (acct INT PRIMARY KEY, balance INT); | ||||||
|  | INSERT INTO t (acct, balance) VALUES (1, 100), (2, 100); | ||||||
|  | ` | ||||||
|  | 	if _, err := db.Exec(initStmt); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	type queryI interface { | ||||||
|  | 		Query(string, ...interface{}) (*sql.Rows, error) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	getBalances := func(q queryI) (bal1, bal2 int, err error) { | ||||||
|  | 		var rows *sql.Rows | ||||||
|  | 		rows, err = q.Query(`SELECT balance FROM t WHERE acct IN (1, 2);`) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		defer rows.Close() | ||||||
|  | 		balances := []*int{&bal1, &bal2} | ||||||
|  | 		i := 0 | ||||||
|  | 		for ; rows.Next(); i++ { | ||||||
|  | 			if err = rows.Scan(balances[i]); err != nil { | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		if i != 2 { | ||||||
|  | 			err = fmt.Errorf("expected two balances; got %d", i) | ||||||
|  | 			return | ||||||
|  | 		} | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	txRunner := SerialTxRunner{100 * time.Millisecond, 800 * time.Millisecond} | ||||||
|  |  | ||||||
|  | 	runTxn := func(wg *sync.WaitGroup, iter *int) <-chan error { | ||||||
|  | 		errCh := make(chan error, 1) | ||||||
|  | 		go func() { | ||||||
|  | 			*iter = 0 | ||||||
|  | 			errCh <- txRunner.WithTx(db, func(tx *sql.Tx) error { | ||||||
|  | 				*iter++ | ||||||
|  | 				bal1, bal2, err := getBalances(tx) | ||||||
|  | 				if err != nil { | ||||||
|  | 					return err | ||||||
|  | 				} | ||||||
|  | 				// If this is the first iteration, wait for the other tx to | ||||||
|  | 				// also read. | ||||||
|  | 				if *iter == 1 { | ||||||
|  | 					wg.Done() | ||||||
|  | 					wg.Wait() | ||||||
|  | 				} | ||||||
|  | 				// Now, subtract from one account and give to the other. | ||||||
|  | 				if bal1 > bal2 { | ||||||
|  | 					if _, err := tx.Exec(` | ||||||
|  | UPDATE t SET balance=balance-100 WHERE acct=1; | ||||||
|  | UPDATE t SET balance=balance+100 WHERE acct=2; | ||||||
|  | `); err != nil { | ||||||
|  | 						return err | ||||||
|  | 					} | ||||||
|  | 				} else { | ||||||
|  | 					if _, err := tx.Exec(` | ||||||
|  | UPDATE t SET balance=balance+100 WHERE acct=1; | ||||||
|  | UPDATE t SET balance=balance-100 WHERE acct=2; | ||||||
|  | `); err != nil { | ||||||
|  | 						return err | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  | 				return nil | ||||||
|  | 			}) | ||||||
|  | 		}() | ||||||
|  | 		return errCh | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var wg sync.WaitGroup | ||||||
|  | 	wg.Add(2) | ||||||
|  | 	var iters1, iters2 int | ||||||
|  | 	txn1Err := runTxn(&wg, &iters1) | ||||||
|  | 	txn2Err := runTxn(&wg, &iters2) | ||||||
|  | 	if err := <-txn1Err; err != nil { | ||||||
|  | 		t.Errorf("expected success in txn1; got %s", err) | ||||||
|  | 	} | ||||||
|  | 	if err := <-txn2Err; err != nil { | ||||||
|  | 		t.Errorf("expected success in txn2; got %s", err) | ||||||
|  | 	} | ||||||
|  | 	if iters1+iters2 <= 2 { | ||||||
|  | 		t.Errorf("expected retries between the competing transactions; "+ | ||||||
|  | 			"got txn1=%d, txn2=%d", iters1, iters2) | ||||||
|  | 	} | ||||||
|  | 	bal1, bal2, err := getBalances(db) | ||||||
|  | 	if err != nil || bal1 != 100 || bal2 != 100 { | ||||||
|  | 		t.Errorf("expected balances to be restored without error; "+ | ||||||
|  | 			"got acct1=%d, acct2=%d: %s", bal1, bal2, err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										3
									
								
								ratelimiter/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								ratelimiter/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | module git.crumpington.com/lib/ratelimiter | ||||||
|  |  | ||||||
|  | go 1.23.2 | ||||||
							
								
								
									
										86
									
								
								ratelimiter/ratelimiter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								ratelimiter/ratelimiter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,86 @@ | |||||||
|  | package ratelimiter | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"sync" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ErrBackoff = errors.New("Backoff") | ||||||
|  |  | ||||||
|  | type Config struct { | ||||||
|  | 	BurstLimit   int64         // Number of requests to allow to burst. | ||||||
|  | 	FillPeriod   time.Duration // Add one per period. | ||||||
|  | 	MaxWaitCount int64         // Max number of waiting requests. 0 disables. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Limiter struct { | ||||||
|  | 	lock sync.Mutex | ||||||
|  |  | ||||||
|  | 	fillPeriod  time.Duration | ||||||
|  | 	minWaitTime time.Duration | ||||||
|  | 	maxWaitTime time.Duration | ||||||
|  |  | ||||||
|  | 	waitTime    time.Duration // If waitTime < 0, no waiting occurs. | ||||||
|  | 	lastRequest time.Time | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func New(conf Config) *Limiter { | ||||||
|  | 	if conf.BurstLimit < 0 { | ||||||
|  | 		panic(conf.BurstLimit) | ||||||
|  | 	} | ||||||
|  | 	if conf.FillPeriod <= 0 { | ||||||
|  | 		panic(conf.FillPeriod) | ||||||
|  | 	} | ||||||
|  | 	if conf.MaxWaitCount < 0 { | ||||||
|  | 		panic(conf.MaxWaitCount) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	lim := &Limiter{ | ||||||
|  | 		lastRequest: time.Now(), | ||||||
|  | 		fillPeriod:  conf.FillPeriod, | ||||||
|  | 		waitTime:    -conf.FillPeriod * time.Duration(conf.BurstLimit), | ||||||
|  | 		minWaitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit), | ||||||
|  | 		maxWaitTime: conf.FillPeriod * time.Duration(conf.MaxWaitCount), | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	lim.waitTime = lim.minWaitTime | ||||||
|  |  | ||||||
|  | 	return lim | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (lim *Limiter) limit(count int64) (time.Duration, error) { | ||||||
|  | 	lim.lock.Lock() | ||||||
|  | 	defer lim.lock.Unlock() | ||||||
|  |  | ||||||
|  | 	dt := time.Since(lim.lastRequest) | ||||||
|  |  | ||||||
|  | 	waitTime := lim.waitTime - dt + time.Duration(count)*lim.fillPeriod | ||||||
|  |  | ||||||
|  | 	if waitTime < lim.minWaitTime { | ||||||
|  | 		waitTime = lim.minWaitTime | ||||||
|  | 	} else if waitTime > lim.maxWaitTime { | ||||||
|  | 		return 0, ErrBackoff | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	lim.waitTime = waitTime | ||||||
|  | 	lim.lastRequest = lim.lastRequest.Add(dt) | ||||||
|  |  | ||||||
|  | 	return lim.waitTime, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Apply the limiter to the calling thread. The function may sleep for up to | ||||||
|  | // maxWaitTime before returning. If the timeout would need to be more than | ||||||
|  | // maxWaitTime to enforce the rate limit, ErrBackoff is returned. | ||||||
|  | func (lim *Limiter) Limit() error { | ||||||
|  | 	dt, err := lim.limit(1) | ||||||
|  | 	time.Sleep(dt) // Will return immediately for dt <= 0. | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Apply the limiter for multiple items at once. | ||||||
|  | func (lim *Limiter) LimitMultiple(count int64) error { | ||||||
|  | 	dt, err := lim.limit(count) | ||||||
|  | 	time.Sleep(dt) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
							
								
								
									
										101
									
								
								ratelimiter/ratelimiter_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								ratelimiter/ratelimiter_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,101 @@ | |||||||
|  | package ratelimiter | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"sync" | ||||||
|  | 	"testing" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestRateLimiter_Limit_Errors(t *testing.T) { | ||||||
|  | 	type TestCase struct { | ||||||
|  | 		Name     string | ||||||
|  | 		Conf     Config | ||||||
|  | 		N        int | ||||||
|  | 		ErrCount int | ||||||
|  | 		DT       time.Duration | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cases := []TestCase{ | ||||||
|  | 		{ | ||||||
|  | 			Name: "no burst, no wait", | ||||||
|  | 			Conf: Config{ | ||||||
|  | 				BurstLimit:   0, | ||||||
|  | 				FillPeriod:   100 * time.Millisecond, | ||||||
|  | 				MaxWaitCount: 0, | ||||||
|  | 			}, | ||||||
|  | 			N:        32, | ||||||
|  | 			ErrCount: 32, | ||||||
|  | 			DT:       0, | ||||||
|  | 		}, { | ||||||
|  | 			Name: "no wait", | ||||||
|  | 			Conf: Config{ | ||||||
|  | 				BurstLimit:   10, | ||||||
|  | 				FillPeriod:   100 * time.Millisecond, | ||||||
|  | 				MaxWaitCount: 0, | ||||||
|  | 			}, | ||||||
|  | 			N:        32, | ||||||
|  | 			ErrCount: 22, | ||||||
|  | 			DT:       0, | ||||||
|  | 		}, { | ||||||
|  | 			Name: "no burst", | ||||||
|  | 			Conf: Config{ | ||||||
|  | 				BurstLimit:   0, | ||||||
|  | 				FillPeriod:   10 * time.Millisecond, | ||||||
|  | 				MaxWaitCount: 10, | ||||||
|  | 			}, | ||||||
|  | 			N:        32, | ||||||
|  | 			ErrCount: 22, | ||||||
|  | 			DT:       100 * time.Millisecond, | ||||||
|  | 		}, { | ||||||
|  | 			Name: "burst and wait", | ||||||
|  | 			Conf: Config{ | ||||||
|  | 				BurstLimit:   10, | ||||||
|  | 				FillPeriod:   10 * time.Millisecond, | ||||||
|  | 				MaxWaitCount: 10, | ||||||
|  | 			}, | ||||||
|  | 			N:        32, | ||||||
|  | 			ErrCount: 12, | ||||||
|  | 			DT:       100 * time.Millisecond, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range cases { | ||||||
|  | 		wg := sync.WaitGroup{} | ||||||
|  | 		l := New(tc.Conf) | ||||||
|  | 		errs := make([]error, tc.N) | ||||||
|  |  | ||||||
|  | 		t0 := time.Now() | ||||||
|  |  | ||||||
|  | 		for i := 0; i < tc.N; i++ { | ||||||
|  | 			wg.Add(1) | ||||||
|  | 			go func(i int) { | ||||||
|  | 				errs[i] = l.Limit() | ||||||
|  | 				wg.Done() | ||||||
|  | 			}(i) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		wg.Wait() | ||||||
|  |  | ||||||
|  | 		dt := time.Since(t0) | ||||||
|  |  | ||||||
|  | 		errCount := 0 | ||||||
|  | 		for _, err := range errs { | ||||||
|  | 			if err != nil { | ||||||
|  | 				errCount++ | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if errCount != tc.ErrCount { | ||||||
|  | 			t.Fatalf("%s: Expected %d errors but got %d.", | ||||||
|  | 				tc.Name, tc.ErrCount, errCount) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if dt < tc.DT { | ||||||
|  | 			t.Fatal(tc.Name, dt, tc.DT) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if dt > tc.DT+10*time.Millisecond { | ||||||
|  | 			t.Fatal(tc.Name, dt, tc.DT) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										22
									
								
								sqlgen/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								sqlgen/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | |||||||
|  | # sqlgen | ||||||
|  |  | ||||||
|  | ## Installing | ||||||
|  |  | ||||||
|  | ``` | ||||||
|  | go install git.crumpington.com/lib/sqlgen/cmd/sqlgen@latest | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## Usage | ||||||
|  |  | ||||||
|  | ``` | ||||||
|  | sqlgen [driver] [defs-path] [output-path] | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## File Format | ||||||
|  |  | ||||||
|  | ``` | ||||||
|  | TABLE [sql-name] OF [go-type] <NoInsert> <NoUpdate> <NoDelete> ( | ||||||
|  |     [sql-column] [go-type] <AS go-name> <PK> <NoInsert> <NoUpdate>, | ||||||
|  |     ... | ||||||
|  | ); | ||||||
|  | ``` | ||||||
							
								
								
									
										7
									
								
								sqlgen/cmd/sqlgen/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								sqlgen/cmd/sqlgen/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | |||||||
|  | package main | ||||||
|  |  | ||||||
|  | import "git.crumpington.com/lib/sqlgen" | ||||||
|  |  | ||||||
|  | func main() { | ||||||
|  | 	sqlgen.Main() | ||||||
|  | } | ||||||
							
								
								
									
										3
									
								
								sqlgen/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								sqlgen/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | module git.crumpington.com/lib/sqlgen | ||||||
|  |  | ||||||
|  | go 1.23.2 | ||||||
							
								
								
									
										43
									
								
								sqlgen/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								sqlgen/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,43 @@ | |||||||
|  | package sqlgen | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"os" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func Main() { | ||||||
|  | 	usage := func() { | ||||||
|  | 		fmt.Fprintf(os.Stderr, ` | ||||||
|  | %s DRIVER DEFS_PATH OUTPUT_PATH | ||||||
|  |  | ||||||
|  | Drivers are one of: sqlite, postgres | ||||||
|  | `, | ||||||
|  | 			os.Args[0]) | ||||||
|  | 		os.Exit(1) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if len(os.Args) != 4 { | ||||||
|  | 		usage() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var ( | ||||||
|  | 		template   string | ||||||
|  | 		driver     = os.Args[1] | ||||||
|  | 		defsPath   = os.Args[2] | ||||||
|  | 		outputPath = os.Args[3] | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	switch driver { | ||||||
|  | 	case "sqlite": | ||||||
|  | 		template = sqliteTemplate | ||||||
|  | 	default: | ||||||
|  | 		fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver) | ||||||
|  | 		usage() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err := render(template, defsPath, outputPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		fmt.Fprintf(os.Stderr, "Error: %v", err) | ||||||
|  | 		os.Exit(1) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										143
									
								
								sqlgen/parse.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								sqlgen/parse.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,143 @@ | |||||||
|  | package sqlgen | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"os" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func parsePath(filePath string) (*schema, error) { | ||||||
|  | 	fileBytes, err := os.ReadFile(filePath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil, err | ||||||
|  | 	} | ||||||
|  | 	return parseBytes(fileBytes) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseBytes(fileBytes []byte) (*schema, error) { | ||||||
|  | 	s := string(fileBytes) | ||||||
|  | 	for _, c := range []string{",", "(", ")", ";"} { | ||||||
|  | 		s = strings.ReplaceAll(s, c, " "+c+" ") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var ( | ||||||
|  | 		tokens = strings.Fields(s) | ||||||
|  | 		schema = &schema{} | ||||||
|  | 		err    error | ||||||
|  | 	) | ||||||
|  |  | ||||||
|  | 	for len(tokens) > 0 { | ||||||
|  | 		switch tokens[0] { | ||||||
|  | 		case "TABLE": | ||||||
|  | 			tokens, err = parseTable(schema, tokens) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return nil, err | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 		default: | ||||||
|  | 			return nil, errors.New("invalid token: " + tokens[0]) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return schema, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseTable(schema *schema, tokens []string) ([]string, error) { | ||||||
|  | 	tokens = tokens[1:] | ||||||
|  | 	if len(tokens) < 3 { | ||||||
|  | 		return tokens, errors.New("incomplete table definition") | ||||||
|  | 	} | ||||||
|  | 	if tokens[1] != "OF" { | ||||||
|  | 		return tokens, errors.New("expected OF in table definition") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	table := &table{ | ||||||
|  | 		Name: tokens[0], | ||||||
|  | 		Type: tokens[2], | ||||||
|  | 	} | ||||||
|  | 	schema.Tables = append(schema.Tables, table) | ||||||
|  |  | ||||||
|  | 	tokens = tokens[3:] | ||||||
|  |  | ||||||
|  | 	if len(tokens) == 0 { | ||||||
|  | 		return tokens, errors.New("missing table definition body") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for len(tokens) > 0 { | ||||||
|  | 		switch tokens[0] { | ||||||
|  | 		case "NoInsert": | ||||||
|  | 			table.NoInsert = true | ||||||
|  | 			tokens = tokens[1:] | ||||||
|  | 		case "NoUpdate": | ||||||
|  | 			table.NoUpdate = true | ||||||
|  | 			tokens = tokens[1:] | ||||||
|  | 		case "NoDelete": | ||||||
|  | 			table.NoDelete = true | ||||||
|  | 			tokens = tokens[1:] | ||||||
|  | 		case "(": | ||||||
|  | 			return parseTableBody(table, tokens[1:]) | ||||||
|  | 		default: | ||||||
|  | 			return tokens, errors.New("unexpected token in table definition: " + tokens[0]) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return tokens, errors.New("incomplete table definition") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseTableBody(table *table, tokens []string) ([]string, error) { | ||||||
|  | 	var err error | ||||||
|  |  | ||||||
|  | 	for len(tokens) > 0 && tokens[0] != ";" { | ||||||
|  | 		tokens, err = parseTableColumn(table, tokens) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return tokens, err | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if len(tokens) < 1 || tokens[0] != ";" { | ||||||
|  | 		return tokens, errors.New("incomplete table column definitions") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return tokens[1:], nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func parseTableColumn(table *table, tokens []string) ([]string, error) { | ||||||
|  | 	if len(tokens) < 2 { | ||||||
|  | 		return tokens, errors.New("incomplete column definition") | ||||||
|  | 	} | ||||||
|  | 	column := &column{ | ||||||
|  | 		Name:    tokens[0], | ||||||
|  | 		Type:    tokens[1], | ||||||
|  | 		SqlName: tokens[0], | ||||||
|  | 	} | ||||||
|  | 	table.Columns = append(table.Columns, column) | ||||||
|  |  | ||||||
|  | 	tokens = tokens[2:] | ||||||
|  | 	for len(tokens) > 0 && tokens[0] != "," && tokens[0] != ")" { | ||||||
|  | 		switch tokens[0] { | ||||||
|  | 		case "AS": | ||||||
|  | 			if len(tokens) < 2 { | ||||||
|  | 				return tokens, errors.New("incomplete AS clause in column definition") | ||||||
|  | 			} | ||||||
|  | 			column.Name = tokens[1] | ||||||
|  | 			tokens = tokens[2:] | ||||||
|  | 		case "PK": | ||||||
|  | 			column.PK = true | ||||||
|  | 			tokens = tokens[1:] | ||||||
|  | 		case "NoInsert": | ||||||
|  | 			column.NoInsert = true | ||||||
|  | 			tokens = tokens[1:] | ||||||
|  | 		case "NoUpdate": | ||||||
|  | 			column.NoUpdate = true | ||||||
|  | 			tokens = tokens[1:] | ||||||
|  | 		default: | ||||||
|  | 			return tokens, errors.New("unexpected token in column definition: " + tokens[0]) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if len(tokens) == 0 { | ||||||
|  | 		return tokens, errors.New("incomplete column definition") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return tokens[1:], nil | ||||||
|  | } | ||||||
							
								
								
									
										45
									
								
								sqlgen/parse_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								sqlgen/parse_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | |||||||
|  | package sqlgen | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"os" | ||||||
|  | 	"path/filepath" | ||||||
|  | 	"reflect" | ||||||
|  | 	"strings" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestParse(t *testing.T) { | ||||||
|  | 	toString := func(v any) string { | ||||||
|  | 		txt, _ := json.MarshalIndent(v, "", "  ") | ||||||
|  | 		return string(txt) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	paths, err := filepath.Glob("test-files/TestParse/*.def") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, defPath := range paths { | ||||||
|  | 		t.Run(filepath.Base(defPath), func(t *testing.T) { | ||||||
|  | 			parsed, err := parsePath(defPath) | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal(err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			b, err := os.ReadFile(strings.TrimSuffix(defPath, "def") + "json") | ||||||
|  | 			if err != nil { | ||||||
|  | 				t.Fatal(err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			expected := &schema{} | ||||||
|  | 			if err := json.Unmarshal(b, expected); err != nil { | ||||||
|  | 				t.Fatal(err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if !reflect.DeepEqual(parsed, expected) { | ||||||
|  | 				t.Fatalf("%s != %s", toString(parsed), toString(expected)) | ||||||
|  | 			} | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										263
									
								
								sqlgen/schema.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										263
									
								
								sqlgen/schema.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,263 @@ | |||||||
|  | package sqlgen | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type schema struct { | ||||||
|  | 	Tables []*table | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type table struct { | ||||||
|  | 	Name     string // Name in SQL | ||||||
|  | 	Type     string // Go type | ||||||
|  | 	NoInsert bool | ||||||
|  | 	NoUpdate bool | ||||||
|  | 	NoDelete bool | ||||||
|  | 	Columns  []*column | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type column struct { | ||||||
|  | 	Name     string | ||||||
|  | 	Type     string | ||||||
|  | 	SqlName  string // Defaults to Name | ||||||
|  | 	PK       bool   // PK won't be updated | ||||||
|  | 	NoInsert bool | ||||||
|  | 	NoUpdate bool // Don't update column in update function | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | func (t *table) colSQLNames() []string { | ||||||
|  | 	names := make([]string, len(t.Columns)) | ||||||
|  | 	for i := range names { | ||||||
|  | 		names[i] = t.Columns[i].SqlName | ||||||
|  | 	} | ||||||
|  | 	return names | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) SelectQuery() string { | ||||||
|  | 	return fmt.Sprintf(`SELECT %s FROM %s`, | ||||||
|  | 		strings.Join(t.colSQLNames(), ","), | ||||||
|  | 		t.Name) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) insertCols() (cols []*column) { | ||||||
|  | 	for _, c := range t.Columns { | ||||||
|  | 		if !c.NoInsert { | ||||||
|  | 			cols = append(cols, c) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return cols | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) InsertQuery() string { | ||||||
|  | 	cols := t.insertCols() | ||||||
|  | 	b := &strings.Builder{} | ||||||
|  | 	b.WriteString(`INSERT INTO `) | ||||||
|  | 	b.WriteString(t.Name) | ||||||
|  | 	b.WriteString(`(`) | ||||||
|  | 	for i, c := range cols { | ||||||
|  | 		if i != 0 { | ||||||
|  | 			b.WriteString(`,`) | ||||||
|  | 		} | ||||||
|  | 		b.WriteString(c.SqlName) | ||||||
|  | 	} | ||||||
|  | 	b.WriteString(`) VALUES(`) | ||||||
|  | 	for i := range cols { | ||||||
|  | 		if i != 0 { | ||||||
|  | 			b.WriteString(`,`) | ||||||
|  | 		} | ||||||
|  | 		b.WriteString(`?`) | ||||||
|  | 	} | ||||||
|  | 	b.WriteString(`)`) | ||||||
|  | 	return b.String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) InsertArgs() string { | ||||||
|  | 	args := []string{} | ||||||
|  | 	for i, col := range t.Columns { | ||||||
|  | 		if !col.NoInsert { | ||||||
|  | 			args = append(args, "row."+t.Columns[i].Name) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return strings.Join(args, ", ") | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) UpdateCols() (cols []*column) { | ||||||
|  | 	for _, c := range t.Columns { | ||||||
|  | 		if !(c.PK || c.NoUpdate) { | ||||||
|  | 			cols = append(cols, c) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return cols | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) UpdateQuery() string { | ||||||
|  | 	cols := t.UpdateCols() | ||||||
|  |  | ||||||
|  | 	b := &strings.Builder{} | ||||||
|  | 	b.WriteString(`UPDATE `) | ||||||
|  | 	b.WriteString(t.Name + ` SET `) | ||||||
|  | 	for i, col := range cols { | ||||||
|  | 		if i != 0 { | ||||||
|  | 			b.WriteByte(',') | ||||||
|  | 		} | ||||||
|  | 		b.WriteString(col.SqlName + `=?`) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	b.WriteString(` WHERE`) | ||||||
|  |  | ||||||
|  | 	for i, c := range t.pkCols() { | ||||||
|  | 		if i != 0 { | ||||||
|  | 			b.WriteString(` AND`) | ||||||
|  | 		} | ||||||
|  | 		b.WriteString(` ` + c.SqlName + `=?`) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return b.String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) UpdateArgs() string { | ||||||
|  | 	cols := t.UpdateCols() | ||||||
|  |  | ||||||
|  | 	b := &strings.Builder{} | ||||||
|  | 	for i, col := range cols { | ||||||
|  | 		if i != 0 { | ||||||
|  | 			b.WriteString(`, `) | ||||||
|  | 		} | ||||||
|  | 		b.WriteString("row." + col.Name) | ||||||
|  | 	} | ||||||
|  | 	for _, col := range t.pkCols() { | ||||||
|  | 		b.WriteString(", row." + col.Name) | ||||||
|  | 	} | ||||||
|  | 	return b.String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) UpdateFullCols() (cols []*column) { | ||||||
|  | 	for _, c := range t.Columns { | ||||||
|  | 		if !c.PK { | ||||||
|  | 			cols = append(cols, c) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return cols | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) UpdateFullQuery() string { | ||||||
|  | 	cols := t.UpdateFullCols() | ||||||
|  |  | ||||||
|  | 	b := &strings.Builder{} | ||||||
|  | 	b.WriteString(`UPDATE `) | ||||||
|  | 	b.WriteString(t.Name + ` SET `) | ||||||
|  | 	for i, col := range cols { | ||||||
|  | 		if i != 0 { | ||||||
|  | 			b.WriteByte(',') | ||||||
|  | 		} | ||||||
|  | 		b.WriteString(col.SqlName + `=?`) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	b.WriteString(` WHERE`) | ||||||
|  |  | ||||||
|  | 	for i, c := range t.pkCols() { | ||||||
|  | 		if i != 0 { | ||||||
|  | 			b.WriteString(` AND`) | ||||||
|  | 		} | ||||||
|  | 		b.WriteString(` ` + c.SqlName + `=?`) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return b.String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) UpdateFullArgs() string { | ||||||
|  | 	cols := t.UpdateFullCols() | ||||||
|  |  | ||||||
|  | 	b := &strings.Builder{} | ||||||
|  | 	for i, col := range cols { | ||||||
|  | 		if i != 0 { | ||||||
|  | 			b.WriteString(`, `) | ||||||
|  | 		} | ||||||
|  | 		b.WriteString("row." + col.Name) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, col := range t.pkCols() { | ||||||
|  | 		b.WriteString(", row." + col.Name) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return b.String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) pkCols() (cols []*column) { | ||||||
|  | 	for _, c := range t.Columns { | ||||||
|  | 		if c.PK { | ||||||
|  | 			cols = append(cols, c) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return cols | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) PKFunctionArgs() string { | ||||||
|  | 	b := &strings.Builder{} | ||||||
|  | 	for _, col := range t.pkCols() { | ||||||
|  | 		b.WriteString(col.Name) | ||||||
|  | 		b.WriteString(` `) | ||||||
|  | 		b.WriteString(col.Type) | ||||||
|  | 		b.WriteString(",\n") | ||||||
|  | 	} | ||||||
|  | 	return b.String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) DeleteQuery() string { | ||||||
|  | 	cols := t.pkCols() | ||||||
|  |  | ||||||
|  | 	b := &strings.Builder{} | ||||||
|  | 	b.WriteString(`DELETE FROM `) | ||||||
|  | 	b.WriteString(t.Name) | ||||||
|  | 	b.WriteString(` WHERE `) | ||||||
|  |  | ||||||
|  | 	for i, col := range cols { | ||||||
|  | 		if i != 0 { | ||||||
|  | 			b.WriteString(` AND `) | ||||||
|  | 		} | ||||||
|  | 		b.WriteString(col.SqlName) | ||||||
|  | 		b.WriteString(`=?`) | ||||||
|  | 	} | ||||||
|  | 	return b.String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) DeleteArgs() string { | ||||||
|  | 	cols := t.pkCols() | ||||||
|  | 	b := &strings.Builder{} | ||||||
|  |  | ||||||
|  | 	for i, col := range cols { | ||||||
|  | 		if i != 0 { | ||||||
|  | 			b.WriteString(`,`) | ||||||
|  | 		} | ||||||
|  | 		b.WriteString(col.Name) | ||||||
|  | 	} | ||||||
|  | 	return b.String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) GetQuery() string { | ||||||
|  | 	b := &strings.Builder{} | ||||||
|  | 	b.WriteString(t.SelectQuery()) | ||||||
|  | 	b.WriteString(` WHERE `) | ||||||
|  | 	for i, col := range t.pkCols() { | ||||||
|  | 		if i != 0 { | ||||||
|  | 			b.WriteString(` AND `) | ||||||
|  | 		} | ||||||
|  | 		b.WriteString(col.SqlName + `=?`) | ||||||
|  | 	} | ||||||
|  | 	return b.String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *table) ScanArgs() string { | ||||||
|  | 	b := &strings.Builder{} | ||||||
|  | 	for i, col := range t.Columns { | ||||||
|  | 		if i != 0 { | ||||||
|  | 			b.WriteString(`, `) | ||||||
|  | 		} | ||||||
|  | 		b.WriteString(`&row.` + col.Name) | ||||||
|  | 	} | ||||||
|  | 	return b.String() | ||||||
|  | } | ||||||
							
								
								
									
										206
									
								
								sqlgen/sqlite.go.tmpl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										206
									
								
								sqlgen/sqlite.go.tmpl
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,206 @@ | |||||||
|  | package {{.PackageName}} | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"iter" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type TX interface { | ||||||
|  | 	Exec(query string, args ...any) (sql.Result, error) | ||||||
|  | 	Query(query string, args ...any) (*sql.Rows, error) | ||||||
|  | 	QueryRow(query string, args ...any) *sql.Row | ||||||
|  | } | ||||||
|  |  | ||||||
|  | {{range .Schema.Tables}} | ||||||
|  |  | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  | // Table: {{.Name}} | ||||||
|  | // ---------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | type {{.Type}} struct { | ||||||
|  | 	{{- range .Columns}} | ||||||
|  | 	{{.Name}} {{.Type}}{{end}} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | const {{.Type}}_SelectQuery = "{{.SelectQuery}}" | ||||||
|  |  | ||||||
|  | {{if not .NoInsert -}} | ||||||
|  |  | ||||||
|  | func {{.Type}}_Insert( | ||||||
|  | 	tx TX, | ||||||
|  | 	row *{{.Type}}, | ||||||
|  | ) (err error) { | ||||||
|  |   {{.Type}}_Sanitize(row) | ||||||
|  |   if err = {{.Type}}_Validate(row); err != nil { | ||||||
|  |     return err | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | 	_, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}}) | ||||||
|  | 	return err | ||||||
|  | } | ||||||
|  |  | ||||||
|  | {{- end}} {{/* if not .NoInsert */}} | ||||||
|  |  | ||||||
|  | {{if not .NoUpdate -}} | ||||||
|  |  | ||||||
|  | {{if .UpdateCols -}} | ||||||
|  |  | ||||||
|  | func {{.Type}}_Update( | ||||||
|  | 	tx TX, | ||||||
|  | 	row *{{.Type}}, | ||||||
|  | ) (found bool, err error) { | ||||||
|  |   {{.Type}}_Sanitize(row) | ||||||
|  |   if err = {{.Type}}_Validate(row); err != nil { | ||||||
|  |     return false, err | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | 	result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	n, err := result.RowsAffected() | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if n > 1 { | ||||||
|  | 		panic("multiple rows updated") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return n != 0, nil | ||||||
|  | } | ||||||
|  | {{- end}} | ||||||
|  |  | ||||||
|  | {{if .UpdateFullCols -}} | ||||||
|  |  | ||||||
|  | func {{.Type}}_UpdateFull( | ||||||
|  | 	tx TX, | ||||||
|  | 	row *{{.Type}}, | ||||||
|  | ) (found bool, err error) { | ||||||
|  |   {{.Type}}_Sanitize(row) | ||||||
|  |   if err = {{.Type}}_Validate(row); err != nil { | ||||||
|  |     return false, err | ||||||
|  |   } | ||||||
|  |  | ||||||
|  | 	result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	n, err := result.RowsAffected() | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if n > 1 { | ||||||
|  | 		panic("multiple rows updated") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return n != 0, nil | ||||||
|  | } | ||||||
|  | {{- end}} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | {{- end}} {{/* if not .NoUpdate */}} | ||||||
|  |  | ||||||
|  | {{if not .NoDelete -}} | ||||||
|  |  | ||||||
|  | func {{.Type}}_Delete( | ||||||
|  | 	tx TX, | ||||||
|  | 	{{.PKFunctionArgs -}} | ||||||
|  | ) (found bool, err error) { | ||||||
|  | 	result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}}) | ||||||
|  | 		if err != nil { | ||||||
|  | 		return false, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	n, err := result.RowsAffected() | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if n > 1 { | ||||||
|  | 		panic("multiple rows deleted") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return n != 0, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | {{- end}} | ||||||
|  |  | ||||||
|  | func {{.Type}}_Get( | ||||||
|  | 	tx TX, | ||||||
|  | 	{{.PKFunctionArgs -}} | ||||||
|  | ) ( | ||||||
|  | 	row *{{.Type}}, | ||||||
|  | 	err error, | ||||||
|  | ) { | ||||||
|  |   row = &{{.Type}}{} | ||||||
|  | 	r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}}) | ||||||
|  | 	err = r.Scan({{.ScanArgs}}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | func {{.Type}}_GetWhere( | ||||||
|  | 	tx TX, | ||||||
|  | 	query string, | ||||||
|  | 	args ...any, | ||||||
|  | ) ( | ||||||
|  | 	row *{{.Type}}, | ||||||
|  | 	err error, | ||||||
|  | ) { | ||||||
|  |   row = &{{.Type}}{} | ||||||
|  | 	r := tx.QueryRow(query, args...) | ||||||
|  | 	err = r.Scan({{.ScanArgs}}) | ||||||
|  | 	return | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | func {{.Type}}_Iterate( | ||||||
|  | 	tx TX, | ||||||
|  | 	query string, | ||||||
|  | 	args ...any, | ||||||
|  | ) ( | ||||||
|  | 	iter.Seq2[*{{.Type}}, error], | ||||||
|  | ) { | ||||||
|  | 	rows, err := tx.Query(query, args...) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return func(yield func(*{{.Type}}, error) bool) { | ||||||
|  | 			yield(nil, err) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return func(yield func(*{{.Type}}, error) bool) { | ||||||
|  | 		defer rows.Close() | ||||||
|  | 		for rows.Next() { | ||||||
|  | 			row := &{{.Type}}{} | ||||||
|  | 			err := rows.Scan({{.ScanArgs}}) | ||||||
|  | 			if !yield(row, err) { | ||||||
|  | 				return | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func {{.Type}}_List( | ||||||
|  | 	tx TX, | ||||||
|  | 	query string, | ||||||
|  | 	args ...any, | ||||||
|  | ) ( | ||||||
|  | 	l []*{{.Type}}, | ||||||
|  | 	err error, | ||||||
|  | ) { | ||||||
|  | 	for row, err := range {{.Type}}_Iterate(tx, query, args...) { | ||||||
|  | 		if err != nil { | ||||||
|  | 			return nil, err | ||||||
|  | 		} | ||||||
|  | 		l = append(l, row) | ||||||
|  | 	} | ||||||
|  | 	return l, nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | {{end}} {{/* range .Schema.Tables */}} | ||||||
							
								
								
									
										36
									
								
								sqlgen/template.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								sqlgen/template.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | |||||||
|  | package sqlgen | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	_ "embed" | ||||||
|  | 	"os" | ||||||
|  | 	"os/exec" | ||||||
|  | 	"path/filepath" | ||||||
|  | 	"text/template" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | //go:embed sqlite.go.tmpl | ||||||
|  | var sqliteTemplate string | ||||||
|  |  | ||||||
|  | func render(templateStr, schemaPath, outputPath string) error { | ||||||
|  | 	sch, err := parsePath(schemaPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	tmpl := template.Must(template.New("").Parse(templateStr)) | ||||||
|  | 	fOut, err := os.Create(outputPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  | 	defer fOut.Close() | ||||||
|  |  | ||||||
|  | 	err = tmpl.Execute(fOut, struct { | ||||||
|  | 		PackageName string | ||||||
|  | 		Schema      *schema | ||||||
|  | 	}{filepath.Base(filepath.Dir(outputPath)), sch}) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return exec.Command("gofmt", "-w", outputPath).Run() | ||||||
|  | } | ||||||
							
								
								
									
										3
									
								
								sqlgen/test-files/TestParse/000.def
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								sqlgen/test-files/TestParse/000.def
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | TABLE users OF User NoDelete ( | ||||||
|  |   user_id string AS UserID PK | ||||||
|  | ); | ||||||
							
								
								
									
										17
									
								
								sqlgen/test-files/TestParse/000.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								sqlgen/test-files/TestParse/000.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | |||||||
|  | { | ||||||
|  |   "Tables": [ | ||||||
|  |     { | ||||||
|  |       "Name": "users", | ||||||
|  |       "Type": "User", | ||||||
|  |       "NoDelete": true, | ||||||
|  |       "Columns": [ | ||||||
|  |         { | ||||||
|  |           "Name": "UserID", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "user_id", | ||||||
|  |           "PK": true | ||||||
|  |         } | ||||||
|  |       ] | ||||||
|  |     } | ||||||
|  |   ] | ||||||
|  | } | ||||||
							
								
								
									
										4
									
								
								sqlgen/test-files/TestParse/001.def
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								sqlgen/test-files/TestParse/001.def
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,4 @@ | |||||||
|  | TABLE users OF User NoDelete ( | ||||||
|  |   user_id string AS UserID PK, | ||||||
|  |   email   string AS Email NoUpdate | ||||||
|  | ); | ||||||
							
								
								
									
										22
									
								
								sqlgen/test-files/TestParse/001.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								sqlgen/test-files/TestParse/001.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | |||||||
|  | { | ||||||
|  |   "Tables": [ | ||||||
|  |     { | ||||||
|  |       "Name": "users", | ||||||
|  |       "Type": "User", | ||||||
|  |       "NoDelete": true, | ||||||
|  |       "Columns": [ | ||||||
|  |         { | ||||||
|  |           "Name": "UserID", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "user_id", | ||||||
|  |           "PK": true | ||||||
|  |         }, { | ||||||
|  |           "Name": "Email", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "email", | ||||||
|  |           "NoUpdate": true | ||||||
|  |         } | ||||||
|  |       ] | ||||||
|  |     } | ||||||
|  |   ] | ||||||
|  | } | ||||||
							
								
								
									
										6
									
								
								sqlgen/test-files/TestParse/002.def
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								sqlgen/test-files/TestParse/002.def
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | |||||||
|  | TABLE users OF User NoDelete ( | ||||||
|  |   user_id string AS UserID PK, | ||||||
|  |   email   string AS Email NoUpdate, | ||||||
|  |   name    string AS Name NoInsert, | ||||||
|  |   admin   bool   AS Admin NoInsert NoUpdate | ||||||
|  | ); | ||||||
							
								
								
									
										33
									
								
								sqlgen/test-files/TestParse/002.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								sqlgen/test-files/TestParse/002.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,33 @@ | |||||||
|  | { | ||||||
|  |   "Tables": [ | ||||||
|  |     { | ||||||
|  |       "Name": "users", | ||||||
|  |       "Type": "User", | ||||||
|  |       "NoDelete": true, | ||||||
|  |       "Columns": [ | ||||||
|  |         { | ||||||
|  |           "Name": "UserID", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "user_id", | ||||||
|  |           "PK": true | ||||||
|  |         }, { | ||||||
|  |           "Name": "Email", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "email", | ||||||
|  |           "NoUpdate": true | ||||||
|  |         }, { | ||||||
|  |           "Name": "Name", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "name", | ||||||
|  |           "NoInsert": true | ||||||
|  |         }, { | ||||||
|  |           "Name": "Admin", | ||||||
|  |           "Type": "bool", | ||||||
|  |           "SqlName": "admin", | ||||||
|  |           "NoInsert": true, | ||||||
|  |           "NoUpdate": true | ||||||
|  |         } | ||||||
|  |       ] | ||||||
|  |     } | ||||||
|  |   ] | ||||||
|  | } | ||||||
							
								
								
									
										12
									
								
								sqlgen/test-files/TestParse/003.def
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								sqlgen/test-files/TestParse/003.def
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,12 @@ | |||||||
|  | TABLE users OF User NoDelete ( | ||||||
|  |   user_id string AS UserID PK, | ||||||
|  |   email   string AS Email NoUpdate, | ||||||
|  |   name    string AS Name NoInsert, | ||||||
|  |   admin   bool   AS Admin NoInsert NoUpdate | ||||||
|  | ); | ||||||
|  |  | ||||||
|  | TABLE users_view OF UserView NoInsert NoUpdate NoDelete ( | ||||||
|  |   user_id string AS UserID PK, | ||||||
|  |   email   string AS Email, | ||||||
|  |   name    string AS Name | ||||||
|  | ); | ||||||
							
								
								
									
										61
									
								
								sqlgen/test-files/TestParse/003.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								sqlgen/test-files/TestParse/003.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | |||||||
|  | { | ||||||
|  |   "Tables": [ | ||||||
|  |     { | ||||||
|  |       "Name": "users", | ||||||
|  |       "Type": "User", | ||||||
|  |       "NoDelete": true, | ||||||
|  |       "Columns": [ | ||||||
|  |         { | ||||||
|  |           "Name": "UserID", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "user_id", | ||||||
|  |           "PK": true | ||||||
|  |         }, | ||||||
|  |         { | ||||||
|  |           "Name": "Email", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "email", | ||||||
|  |           "NoUpdate": true | ||||||
|  |         }, | ||||||
|  |         { | ||||||
|  |           "Name": "Name", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "name", | ||||||
|  |           "NoInsert": true | ||||||
|  |         }, | ||||||
|  |         { | ||||||
|  |           "Name": "Admin", | ||||||
|  |           "Type": "bool", | ||||||
|  |           "SqlName": "admin", | ||||||
|  |           "NoInsert": true, | ||||||
|  |           "NoUpdate": true | ||||||
|  |         } | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "Name": "users_view", | ||||||
|  |       "Type": "UserView", | ||||||
|  |       "NoInsert": true, | ||||||
|  |       "NoUpdate": true, | ||||||
|  |       "NoDelete": true, | ||||||
|  |       "Columns": [ | ||||||
|  |         { | ||||||
|  |           "Name": "UserID", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "user_id", | ||||||
|  |           "PK": true | ||||||
|  |         }, | ||||||
|  |         { | ||||||
|  |           "Name": "Email", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "email" | ||||||
|  |         }, | ||||||
|  |         { | ||||||
|  |           "Name": "Name", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "name" | ||||||
|  |         } | ||||||
|  |       ] | ||||||
|  |     } | ||||||
|  |   ] | ||||||
|  | } | ||||||
							
								
								
									
										13
									
								
								sqlgen/test-files/TestParse/004.def
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								sqlgen/test-files/TestParse/004.def
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | |||||||
|  | TABLE users OF User NoDelete ( | ||||||
|  |   user_id string AS UserID PK, | ||||||
|  |   email   string AS Email NoUpdate, | ||||||
|  |   name    string AS Name NoInsert, | ||||||
|  |   admin   bool   AS Admin NoInsert NoUpdate, | ||||||
|  |   SSN     string NoUpdate | ||||||
|  | ); | ||||||
|  |  | ||||||
|  | TABLE users_view OF UserView NoInsert NoUpdate NoDelete ( | ||||||
|  |   user_id string AS UserID PK, | ||||||
|  |   email   string AS Email, | ||||||
|  |   name    string AS Name | ||||||
|  | ); | ||||||
							
								
								
									
										66
									
								
								sqlgen/test-files/TestParse/004.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								sqlgen/test-files/TestParse/004.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | |||||||
|  | { | ||||||
|  |   "Tables": [ | ||||||
|  |     { | ||||||
|  |       "Name": "users", | ||||||
|  |       "Type": "User", | ||||||
|  |       "NoDelete": true, | ||||||
|  |       "Columns": [ | ||||||
|  |         { | ||||||
|  |           "Name": "UserID", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "user_id", | ||||||
|  |           "PK": true | ||||||
|  |         }, | ||||||
|  |         { | ||||||
|  |           "Name": "Email", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "email", | ||||||
|  |           "NoUpdate": true | ||||||
|  |         }, | ||||||
|  |         { | ||||||
|  |           "Name": "Name", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "name", | ||||||
|  |           "NoInsert": true | ||||||
|  |         }, | ||||||
|  |         { | ||||||
|  |           "Name": "Admin", | ||||||
|  |           "Type": "bool", | ||||||
|  |           "SqlName": "admin", | ||||||
|  |           "NoInsert": true, | ||||||
|  |           "NoUpdate": true | ||||||
|  |         }, { | ||||||
|  |           "Name": "SSN", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "SSN", | ||||||
|  |           "NoUpdate": true | ||||||
|  |         } | ||||||
|  |       ] | ||||||
|  |     }, | ||||||
|  |     { | ||||||
|  |       "Name": "users_view", | ||||||
|  |       "Type": "UserView", | ||||||
|  |       "NoInsert": true, | ||||||
|  |       "NoUpdate": true, | ||||||
|  |       "NoDelete": true, | ||||||
|  |       "Columns": [ | ||||||
|  |         { | ||||||
|  |           "Name": "UserID", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "user_id", | ||||||
|  |           "PK": true | ||||||
|  |         }, | ||||||
|  |         { | ||||||
|  |           "Name": "Email", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "email" | ||||||
|  |         }, | ||||||
|  |         { | ||||||
|  |           "Name": "Name", | ||||||
|  |           "Type": "string", | ||||||
|  |           "SqlName": "name" | ||||||
|  |         } | ||||||
|  |       ] | ||||||
|  |     } | ||||||
|  |   ] | ||||||
|  | } | ||||||
							
								
								
									
										45
									
								
								sqliteutil/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								sqliteutil/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | |||||||
|  | # sqliteutil | ||||||
|  |  | ||||||
|  | ## Transactions | ||||||
|  |  | ||||||
|  | Simplify postgres transactions using `WithTx` for serializable transactions, | ||||||
|  | or `WithTxDefault` for the default isolation level. Use the `SerialTxRunner` | ||||||
|  | type to get automatic retries of serialization errors. | ||||||
|  |  | ||||||
|  | ## Migrations | ||||||
|  |  | ||||||
|  | Put your migrations into a directory, for example `migrations`, ordered by name | ||||||
|  | (YYYY-MM-DD prefix, for example). Embed the directory and pass it to the | ||||||
|  | `Migrate` function: | ||||||
|  |  | ||||||
|  | ```Go | ||||||
|  | //go:embed migrations | ||||||
|  | var migrations embed.FS | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  |     Migrate(db, migrations) // Check the error, of course. | ||||||
|  | } | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | ## Testing | ||||||
|  |  | ||||||
|  | In order to test this packge, we need to create a test user and database: | ||||||
|  |  | ||||||
|  | ``` | ||||||
|  | sudo su postgres | ||||||
|  | psql | ||||||
|  |  | ||||||
|  | CREATE DATABASE test; | ||||||
|  | CREATE USER test WITH ENCRYPTED PASSWORD 'test'; | ||||||
|  | GRANT ALL PRIVILEGES ON DATABASE test TO test; | ||||||
|  |  | ||||||
|  | use test | ||||||
|  |  | ||||||
|  | GRANT ALL ON SCHEMA public TO test; | ||||||
|  | ``` | ||||||
|  |  | ||||||
|  | Check that you can connect via the command line: | ||||||
|  |  | ||||||
|  | ``` | ||||||
|  | psql -h 127.0.0.1 -U test --password test | ||||||
|  | ``` | ||||||
							
								
								
									
										5
									
								
								sqliteutil/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								sqliteutil/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | module git.crumpington.com/lib/sqliteutil | ||||||
|  |  | ||||||
|  | go 1.23.2 | ||||||
|  |  | ||||||
|  | require github.com/mattn/go-sqlite3 v1.14.24 // indirect | ||||||
							
								
								
									
										2
									
								
								sqliteutil/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								sqliteutil/go.sum
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= | ||||||
|  | github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= | ||||||
							
								
								
									
										82
									
								
								sqliteutil/migrate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								sqliteutil/migrate.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,82 @@ | |||||||
|  | package sqliteutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"embed" | ||||||
|  | 	"errors" | ||||||
|  | 	"fmt" | ||||||
|  | 	"path/filepath" | ||||||
|  | 	"sort" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const initMigrationTableQuery = ` | ||||||
|  | CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);` | ||||||
|  |  | ||||||
|  | const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)` | ||||||
|  |  | ||||||
|  | const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)` | ||||||
|  |  | ||||||
|  | func Migrate(db *sql.DB, migrationFS embed.FS) error { | ||||||
|  | 	return WithTx(db, func(tx *sql.Tx) error { | ||||||
|  | 		if _, err := tx.Exec(initMigrationTableQuery); err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		dirs, err := migrationFS.ReadDir(".") | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  | 		if len(dirs) != 1 { | ||||||
|  | 			return errors.New("expected a single migrations directory") | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !dirs[0].IsDir() { | ||||||
|  | 			return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name()) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		dirName := dirs[0].Name() | ||||||
|  | 		files, err := migrationFS.ReadDir(dirName) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return err | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Sort sql files by name. | ||||||
|  | 		sort.Slice(files, func(i, j int) bool { | ||||||
|  | 			return files[i].Name() < files[j].Name() | ||||||
|  | 		}) | ||||||
|  |  | ||||||
|  | 		for _, dirEnt := range files { | ||||||
|  | 			if !dirEnt.Type().IsRegular() { | ||||||
|  | 				return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name()) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			var ( | ||||||
|  | 				name   = dirEnt.Name() | ||||||
|  | 				exists bool | ||||||
|  | 			) | ||||||
|  |  | ||||||
|  | 			err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if exists { | ||||||
|  | 				continue | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			migration, err := migrationFS.ReadFile(filepath.Join(dirName, name)) | ||||||
|  | 			if err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if _, err := tx.Exec(string(migration)); err != nil { | ||||||
|  | 				return fmt.Errorf("migration %s failed: %v", name, err) | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			if _, err := tx.Exec(insertMigrationQuery, name); err != nil { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 		return nil | ||||||
|  | 	}) | ||||||
|  | } | ||||||
							
								
								
									
										44
									
								
								sqliteutil/migrate_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								sqliteutil/migrate_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | |||||||
|  | package sqliteutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  | 	"embed" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | //go:embed test-migrations | ||||||
|  | var testMigrationFS embed.FS | ||||||
|  |  | ||||||
|  | func TestMigrate(t *testing.T) { | ||||||
|  | 	db, err := sql.Open("sqlite3", ":memory:") | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err := Migrate(db, testMigrationFS); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Shouldn't have any effect. | ||||||
|  | 	if err := Migrate(db, testMigrationFS); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)` | ||||||
|  | 	var exists bool | ||||||
|  |  | ||||||
|  | 	if err = db.QueryRow(query, 1).Scan(&exists); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if exists { | ||||||
|  | 		t.Fatal("1 shouldn't exist") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err = db.QueryRow(query, 2).Scan(&exists); err != nil { | ||||||
|  | 		t.Fatal(err) | ||||||
|  | 	} | ||||||
|  | 	if !exists { | ||||||
|  | 		t.Fatal("2 should exist") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | } | ||||||
							
								
								
									
										9
									
								
								sqliteutil/test-migrations/000.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								sqliteutil/test-migrations/000.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | |||||||
|  | CREATE TABLE users( | ||||||
|  |        UserID BIGINT NOT NULL PRIMARY KEY, | ||||||
|  |        Email  TEXT   NOT NULL UNIQUE); | ||||||
|  |  | ||||||
|  | CREATE TABLE user_notes( | ||||||
|  |        UserID BIGINT NOT NULL REFERENCES users(UserID), | ||||||
|  |        NoteID BIGINT NOT NULL, | ||||||
|  |        Note   Text   NOT NULL, | ||||||
|  |        PRIMARY KEY(UserID,NoteID)); | ||||||
							
								
								
									
										1
									
								
								sqliteutil/test-migrations/001.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								sqliteutil/test-migrations/001.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com'); | ||||||
							
								
								
									
										1
									
								
								sqliteutil/test-migrations/002.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								sqliteutil/test-migrations/002.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | DELETE FROM users WHERE UserID=1; | ||||||
							
								
								
									
										28
									
								
								sqliteutil/tx.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								sqliteutil/tx.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,28 @@ | |||||||
|  | package sqliteutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"database/sql" | ||||||
|  |  | ||||||
|  | 	_ "github.com/mattn/go-sqlite3" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // This is a convenience function to run a function within a transaction. | ||||||
|  | func WithTx(db *sql.DB, fn func(*sql.Tx) error) error { | ||||||
|  | 	// Start a transaction. | ||||||
|  | 	tx, err := db.Begin() | ||||||
|  | 	if err != nil { | ||||||
|  | 		return err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = fn(tx) | ||||||
|  |  | ||||||
|  | 	if err == nil { | ||||||
|  | 		err = tx.Commit() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		_ = tx.Rollback() | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return err | ||||||
|  | } | ||||||
							
								
								
									
										2
									
								
								tagengine/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								tagengine/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | |||||||
|  | # tagengine | ||||||
|  |  | ||||||
							
								
								
									
										3
									
								
								tagengine/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								tagengine/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | |||||||
|  | module git.crumpington.com/lib/tagengine | ||||||
|  |  | ||||||
|  | go 1.23.2 | ||||||
							
								
								
									
										0
									
								
								tagengine/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tagengine/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										30
									
								
								tagengine/ngram.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								tagengine/ngram.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | |||||||
|  | package tagengine | ||||||
|  |  | ||||||
|  | import "unicode" | ||||||
|  |  | ||||||
|  | func ngramLength(s string) int { | ||||||
|  | 	N := len(s) | ||||||
|  | 	i := 0 | ||||||
|  | 	count := 0 | ||||||
|  |  | ||||||
|  | 	for { | ||||||
|  | 		// Eat spaces. | ||||||
|  | 		for i < N && unicode.IsSpace(rune(s[i])) { | ||||||
|  | 			i++ | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Done? | ||||||
|  | 		if i == N { | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Non-space! | ||||||
|  | 		count++ | ||||||
|  |  | ||||||
|  | 		// Eat non-spaces. | ||||||
|  | 		for i < N && !unicode.IsSpace(rune(s[i])) { | ||||||
|  | 			i++ | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return count | ||||||
|  | } | ||||||
							
								
								
									
										31
									
								
								tagengine/ngram_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								tagengine/ngram_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | |||||||
|  | package tagengine | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"log" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestNGramLength(t *testing.T) { | ||||||
|  | 	type Case struct { | ||||||
|  | 		Input  string | ||||||
|  | 		Length int | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cases := []Case{ | ||||||
|  | 		{"a b c", 3}, | ||||||
|  | 		{"  xyz\nlkj  dflaj a", 4}, | ||||||
|  | 		{"a", 1}, | ||||||
|  | 		{" a", 1}, | ||||||
|  | 		{"a", 1}, | ||||||
|  | 		{" a\n", 1}, | ||||||
|  | 		{" a ", 1}, | ||||||
|  | 		{"\tx\ny\nz q ", 4}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range cases { | ||||||
|  | 		length := ngramLength(tc.Input) | ||||||
|  | 		if length != tc.Length { | ||||||
|  | 			log.Fatalf("%s: %d != %d", tc.Input, length, tc.Length) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										79
									
								
								tagengine/node.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								tagengine/node.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | |||||||
|  | package tagengine | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type node struct { | ||||||
|  | 	Token    string | ||||||
|  | 	Matches  []*Rule // If a list of tokens reaches this node, it matches these. | ||||||
|  | 	Children map[string]*node | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *node) AddRule(r *Rule) { | ||||||
|  | 	n.addRule(r, 0) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *node) addRule(r *Rule, idx int) { | ||||||
|  | 	if len(r.Includes) == idx { | ||||||
|  | 		n.Matches = append(n.Matches, r) | ||||||
|  | 		return | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	token := r.Includes[idx] | ||||||
|  |  | ||||||
|  | 	child, ok := n.Children[token] | ||||||
|  | 	if !ok { | ||||||
|  | 		child = &node{ | ||||||
|  | 			Token:    token, | ||||||
|  | 			Children: map[string]*node{}, | ||||||
|  | 		} | ||||||
|  | 		n.Children[token] = child | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	child.addRule(r, idx+1) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Note that tokens must be sorted. This is the case for tokens created from | ||||||
|  | // the tokenize function. | ||||||
|  | func (n *node) Match(tokens []string) (rules []*Rule) { | ||||||
|  | 	return n.match(tokens, rules) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *node) match(tokens []string, rules []*Rule) []*Rule { | ||||||
|  | 	// Check for a match. | ||||||
|  | 	if n.Matches != nil { | ||||||
|  | 		rules = append(rules, n.Matches...) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if len(tokens) == 0 { | ||||||
|  | 		return rules | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Attempt to match children. | ||||||
|  | 	for i := 0; i < len(tokens); i++ { | ||||||
|  | 		token := tokens[i] | ||||||
|  | 		if child, ok := n.Children[token]; ok { | ||||||
|  | 			rules = child.match(tokens[i+1:], rules) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return rules | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *node) Dump() { | ||||||
|  | 	n.dump(0) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (n *node) dump(depth int) { | ||||||
|  | 	indent := strings.Repeat(" ", 2*depth) | ||||||
|  | 	tag := "" | ||||||
|  | 	for _, m := range n.Matches { | ||||||
|  | 		tag += " " + m.Tag | ||||||
|  | 	} | ||||||
|  | 	fmt.Printf("%s%s%s\n", indent, n.Token, tag) | ||||||
|  | 	for _, child := range n.Children { | ||||||
|  | 		child.dump(depth + 1) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										159
									
								
								tagengine/rule.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										159
									
								
								tagengine/rule.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,159 @@ | |||||||
|  | package tagengine | ||||||
|  |  | ||||||
|  | type Rule struct { | ||||||
|  | 	// The purpose of a Rule is to attach it's Tag to matching text. | ||||||
|  | 	Tag string | ||||||
|  |  | ||||||
|  | 	// Includes is a list of strings that must be found in the input in order to | ||||||
|  | 	// match. | ||||||
|  | 	Includes []string | ||||||
|  |  | ||||||
|  | 	// Excludes is a list of strings that can exclude a match for this rule. | ||||||
|  | 	Excludes []string | ||||||
|  |  | ||||||
|  | 	// Blocks: If this rule is matched, then it will block matches of any tags | ||||||
|  | 	// listed here. | ||||||
|  | 	Blocks []string | ||||||
|  |  | ||||||
|  | 	// The Score encodes the complexity of the Rule. A higher score indicates a | ||||||
|  | 	// more specific match. A Rule more includes, or includes with multiple words | ||||||
|  | 	// should havee a higher Score than a Rule with fewer includes or less | ||||||
|  | 	// complex includes. | ||||||
|  | 	Score int | ||||||
|  |  | ||||||
|  | 	excludes map[string]struct{} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewRule(tag string) Rule { | ||||||
|  | 	return Rule{Tag: tag} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r Rule) Inc(l ...string) Rule { | ||||||
|  | 	return Rule{ | ||||||
|  | 		Tag:      r.Tag, | ||||||
|  | 		Includes: append(r.Includes, l...), | ||||||
|  | 		Excludes: r.Excludes, | ||||||
|  | 		Blocks:   r.Blocks, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r Rule) Exc(l ...string) Rule { | ||||||
|  | 	return Rule{ | ||||||
|  | 		Tag:      r.Tag, | ||||||
|  | 		Includes: r.Includes, | ||||||
|  | 		Excludes: append(r.Excludes, l...), | ||||||
|  | 		Blocks:   r.Blocks, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r Rule) Block(l ...string) Rule { | ||||||
|  | 	return Rule{ | ||||||
|  | 		Tag:      r.Tag, | ||||||
|  | 		Includes: r.Includes, | ||||||
|  | 		Excludes: r.Excludes, | ||||||
|  | 		Blocks:   append(r.Blocks, l...), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (rule *Rule) normalize(sanitize func(string) string) { | ||||||
|  | 	for i, token := range rule.Includes { | ||||||
|  | 		rule.Includes[i] = sanitize(token) | ||||||
|  | 	} | ||||||
|  | 	for i, token := range rule.Excludes { | ||||||
|  | 		rule.Excludes[i] = sanitize(token) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	sortTokens(rule.Includes) | ||||||
|  | 	sortTokens(rule.Excludes) | ||||||
|  |  | ||||||
|  | 	rule.excludes = map[string]struct{}{} | ||||||
|  | 	for _, s := range rule.Excludes { | ||||||
|  | 		rule.excludes[s] = struct{}{} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	rule.Score = rule.computeScore() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r Rule) maxNGram() int { | ||||||
|  | 	max := 0 | ||||||
|  | 	for _, s := range r.Includes { | ||||||
|  | 		n := ngramLength(s) | ||||||
|  | 		if n > max { | ||||||
|  | 			max = n | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	for _, s := range r.Excludes { | ||||||
|  | 		n := ngramLength(s) | ||||||
|  | 		if n > max { | ||||||
|  | 			max = n | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return max | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r Rule) isExcluded(tokens []string) bool { | ||||||
|  | 	// This is most often the case. | ||||||
|  | 	if len(r.excludes) == 0 { | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, s := range tokens { | ||||||
|  | 		if _, ok := r.excludes[s]; ok { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (r Rule) computeScore() (score int) { | ||||||
|  | 	for _, token := range r.Includes { | ||||||
|  | 		n := ngramLength(token) | ||||||
|  | 		score += n * (n + 1) / 2 | ||||||
|  | 	} | ||||||
|  | 	return score | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func ruleLess(lhs, rhs *Rule) bool { | ||||||
|  | 	// If scores differ, sort by score. | ||||||
|  | 	if lhs.Score != rhs.Score { | ||||||
|  | 		return lhs.Score < rhs.Score | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// If include depth differs, sort by depth. | ||||||
|  | 	lDepth := len(lhs.Includes) | ||||||
|  | 	rDepth := len(rhs.Includes) | ||||||
|  |  | ||||||
|  | 	if lDepth != rDepth { | ||||||
|  | 		return lDepth < rDepth | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// If exclude depth differs, sort by depth. | ||||||
|  | 	lDepth = len(lhs.Excludes) | ||||||
|  | 	rDepth = len(rhs.Excludes) | ||||||
|  |  | ||||||
|  | 	if lDepth != rDepth { | ||||||
|  | 		return lDepth < rDepth | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Sort alphabetically by includes. | ||||||
|  | 	for i := range lhs.Includes { | ||||||
|  | 		if lhs.Includes[i] != rhs.Includes[i] { | ||||||
|  | 			return lhs.Includes[i] < rhs.Includes[i] | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Sort by alphabetically by excludes. | ||||||
|  | 	for i := range lhs.Excludes { | ||||||
|  | 		if lhs.Excludes[i] != rhs.Excludes[i] { | ||||||
|  | 			return lhs.Excludes[i] < rhs.Excludes[i] | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Sort by tag. | ||||||
|  | 	if lhs.Tag != rhs.Tag { | ||||||
|  | 		return lhs.Tag < rhs.Tag | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return false | ||||||
|  | } | ||||||
							
								
								
									
										58
									
								
								tagengine/rulegroup.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								tagengine/rulegroup.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | |||||||
|  | package tagengine | ||||||
|  |  | ||||||
|  | // A RuleGroup can be converted into a list of rules. Each rule will point to | ||||||
|  | // the same tag, and have the same exclude set and blocks. | ||||||
|  | type RuleGroup struct { | ||||||
|  | 	Tag      string | ||||||
|  | 	Includes [][]string | ||||||
|  | 	Excludes []string | ||||||
|  | 	Blocks   []string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewRuleGroup(tag string) RuleGroup { | ||||||
|  | 	return RuleGroup{ | ||||||
|  | 		Tag:      tag, | ||||||
|  | 		Includes: [][]string{}, | ||||||
|  | 		Excludes: []string{}, | ||||||
|  | 		Blocks:   []string{}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (g RuleGroup) Inc(l ...string) RuleGroup { | ||||||
|  | 	return RuleGroup{ | ||||||
|  | 		Tag:      g.Tag, | ||||||
|  | 		Includes: append(g.Includes, l), | ||||||
|  | 		Excludes: g.Excludes, | ||||||
|  | 		Blocks:   g.Blocks, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (g RuleGroup) Exc(l ...string) RuleGroup { | ||||||
|  | 	return RuleGroup{ | ||||||
|  | 		Tag:      g.Tag, | ||||||
|  | 		Includes: g.Includes, | ||||||
|  | 		Excludes: append(g.Excludes, l...), | ||||||
|  | 		Blocks:   g.Blocks, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (g RuleGroup) Block(l ...string) RuleGroup { | ||||||
|  | 	return RuleGroup{ | ||||||
|  | 		Tag:      g.Tag, | ||||||
|  | 		Includes: g.Includes, | ||||||
|  | 		Excludes: g.Excludes, | ||||||
|  | 		Blocks:   append(g.Blocks, l...), | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (g RuleGroup) ToList() (l []Rule) { | ||||||
|  | 	for _, includes := range g.Includes { | ||||||
|  | 		l = append(l, Rule{ | ||||||
|  | 			Tag:      g.Tag, | ||||||
|  | 			Excludes: g.Excludes, | ||||||
|  | 			Includes: includes, | ||||||
|  | 			Blocks:   g.Blocks, | ||||||
|  | 		}) | ||||||
|  | 	} | ||||||
|  | 	return | ||||||
|  | } | ||||||
							
								
								
									
										162
									
								
								tagengine/ruleset.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										162
									
								
								tagengine/ruleset.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,162 @@ | |||||||
|  | package tagengine | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"sort" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type RuleSet struct { | ||||||
|  | 	root     *node | ||||||
|  | 	maxNgram int | ||||||
|  | 	sanitize func(string) string | ||||||
|  | 	rules    []*Rule | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewRuleSet() *RuleSet { | ||||||
|  | 	return &RuleSet{ | ||||||
|  | 		root: &node{ | ||||||
|  | 			Token:    "/", | ||||||
|  | 			Children: map[string]*node{}, | ||||||
|  | 		}, | ||||||
|  | 		sanitize: BasicSanitizer, | ||||||
|  | 		rules:    []*Rule{}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewRuleSetFromList(rules []Rule) *RuleSet { | ||||||
|  | 	rs := NewRuleSet() | ||||||
|  | 	rs.AddRule(rules...) | ||||||
|  | 	return rs | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *RuleSet) Add(ruleOrGroup ...interface{}) { | ||||||
|  | 	for _, ix := range ruleOrGroup { | ||||||
|  | 		switch x := ix.(type) { | ||||||
|  | 		case Rule: | ||||||
|  | 			t.AddRule(x) | ||||||
|  | 		case RuleGroup: | ||||||
|  | 			t.AddRuleGroup(x) | ||||||
|  | 		default: | ||||||
|  | 			panic("Add expects either Rule or RuleGroup objects.") | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *RuleSet) AddRule(rules ...Rule) { | ||||||
|  | 	for _, rule := range rules { | ||||||
|  | 		rule := rule | ||||||
|  |  | ||||||
|  | 		// Make sure rule is well-formed. | ||||||
|  | 		rule.normalize(t.sanitize) | ||||||
|  |  | ||||||
|  | 		// Update maxNgram. | ||||||
|  | 		N := rule.maxNGram() | ||||||
|  | 		if N > t.maxNgram { | ||||||
|  | 			t.maxNgram = N | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		t.rules = append(t.rules, &rule) | ||||||
|  | 		t.root.AddRule(&rule) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (t *RuleSet) AddRuleGroup(ruleGroups ...RuleGroup) { | ||||||
|  | 	for _, rg := range ruleGroups { | ||||||
|  | 		t.AddRule(rg.ToList()...) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // MatchRules will return a list of all matching rules. The rules are sorted by | ||||||
|  | // the match's Score. The best match will be first. | ||||||
|  | func (t *RuleSet) MatchRules(input string) (rules []*Rule) { | ||||||
|  | 	input = t.sanitize(input) | ||||||
|  | 	tokens := Tokenize(input, t.maxNgram) | ||||||
|  |  | ||||||
|  | 	rules = t.root.Match(tokens) | ||||||
|  | 	if len(rules) == 0 { | ||||||
|  | 		return rules | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Check excludes. | ||||||
|  | 	l := rules[:0] | ||||||
|  | 	for _, r := range rules { | ||||||
|  | 		if !r.isExcluded(tokens) { | ||||||
|  | 			l = append(l, r) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	rules = l | ||||||
|  |  | ||||||
|  | 	// Sort rules descending. | ||||||
|  | 	sort.Slice(rules, func(i, j int) bool { | ||||||
|  | 		return ruleLess(rules[j], rules[i]) | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	return rules | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type Match struct { | ||||||
|  | 	Tag string | ||||||
|  |  | ||||||
|  | 	// Confidence is used to sort all matches, and is normalized so the sum of | ||||||
|  | 	// Confidence values for all matches is 1. Confidence is relative to the | ||||||
|  | 	// number of matches and the size of matches in terms of number of tokens. | ||||||
|  | 	Confidence float64 // In the range (0,1]. | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Return a list of matches with confidence. This is useful if you'd like to | ||||||
|  | // find the best matching rule out of all the matched rules. | ||||||
|  | // | ||||||
|  | // If you just want to find all matching rules, then use MatchRules. | ||||||
|  | func (t *RuleSet) Match(input string) []Match { | ||||||
|  | 	rules := t.MatchRules(input) | ||||||
|  | 	if len(rules) == 0 { | ||||||
|  | 		return []Match{} | ||||||
|  | 	} | ||||||
|  | 	if len(rules) == 1 { | ||||||
|  | 		return []Match{{ | ||||||
|  | 			Tag:        rules[0].Tag, | ||||||
|  | 			Confidence: 1, | ||||||
|  | 		}} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Create list of blocked tags. | ||||||
|  | 	blocks := map[string]struct{}{} | ||||||
|  | 	for _, rule := range rules { | ||||||
|  | 		for _, tag := range rule.Blocks { | ||||||
|  | 			blocks[tag] = struct{}{} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Remove rules for blocked tags. | ||||||
|  | 	iOut := 0 | ||||||
|  | 	for _, rule := range rules { | ||||||
|  | 		if _, ok := blocks[rule.Tag]; ok { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  | 		rules[iOut] = rule | ||||||
|  | 		iOut++ | ||||||
|  | 	} | ||||||
|  | 	rules = rules[:iOut] | ||||||
|  |  | ||||||
|  | 	// Matches by index. | ||||||
|  | 	matches := map[string]int{} | ||||||
|  | 	out := []Match{} | ||||||
|  | 	sum := float64(0) | ||||||
|  |  | ||||||
|  | 	for _, rule := range rules { | ||||||
|  | 		idx, ok := matches[rule.Tag] | ||||||
|  | 		if !ok { | ||||||
|  | 			idx = len(matches) | ||||||
|  | 			matches[rule.Tag] = idx | ||||||
|  | 			out = append(out, Match{Tag: rule.Tag}) | ||||||
|  | 		} | ||||||
|  | 		out[idx].Confidence += float64(rule.Score) | ||||||
|  | 		sum += float64(rule.Score) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i := range out { | ||||||
|  | 		out[i].Confidence /= sum | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return out | ||||||
|  | } | ||||||
							
								
								
									
										84
									
								
								tagengine/ruleset_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								tagengine/ruleset_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,84 @@ | |||||||
|  | package tagengine | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestRulesSet(t *testing.T) { | ||||||
|  | 	rs := NewRuleSet() | ||||||
|  | 	rs.AddRule(Rule{ | ||||||
|  | 		Tag:      "cc/2", | ||||||
|  | 		Includes: []string{"cola", "coca"}, | ||||||
|  | 	}) | ||||||
|  | 	rs.AddRule(Rule{ | ||||||
|  | 		Tag:      "cc/0", | ||||||
|  | 		Includes: []string{"coca cola"}, | ||||||
|  | 	}) | ||||||
|  | 	rs.AddRule(Rule{ | ||||||
|  | 		Tag:      "cz/2", | ||||||
|  | 		Includes: []string{"coca", "zero"}, | ||||||
|  | 	}) | ||||||
|  | 	rs.AddRule(Rule{ | ||||||
|  | 		Tag:      "cc0/3", | ||||||
|  | 		Includes: []string{"zero", "coca", "cola"}, | ||||||
|  | 	}) | ||||||
|  | 	rs.AddRule(Rule{ | ||||||
|  | 		Tag:      "cc0/3.1", | ||||||
|  | 		Includes: []string{"coca", "cola", "zero"}, | ||||||
|  | 		Excludes: []string{"pepsi"}, | ||||||
|  | 	}) | ||||||
|  | 	rs.AddRule(Rule{ | ||||||
|  | 		Tag:      "spa", | ||||||
|  | 		Includes: []string{"spa"}, | ||||||
|  | 		Blocks:   []string{"cc/0", "cc0/3", "cc0/3.1"}, | ||||||
|  | 	}) | ||||||
|  |  | ||||||
|  | 	type TestCase struct { | ||||||
|  | 		Input   string | ||||||
|  | 		Matches []Match | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cases := []TestCase{ | ||||||
|  | 		{ | ||||||
|  | 			Input: "coca-cola zero", | ||||||
|  | 			Matches: []Match{ | ||||||
|  | 				{"cc0/3.1", 0.3}, | ||||||
|  | 				{"cc0/3", 0.3}, | ||||||
|  | 				{"cz/2", 0.2}, | ||||||
|  | 				{"cc/2", 0.2}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			Input: "coca cola", | ||||||
|  | 			Matches: []Match{ | ||||||
|  | 				{"cc/0", 0.6}, | ||||||
|  | 				{"cc/2", 0.4}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			Input: "coca cola zero pepsi", | ||||||
|  | 			Matches: []Match{ | ||||||
|  | 				{"cc0/3", 0.3}, | ||||||
|  | 				{"cc/0", 0.3}, | ||||||
|  | 				{"cz/2", 0.2}, | ||||||
|  | 				{"cc/2", 0.2}, | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			Input:   "fanta orange", | ||||||
|  | 			Matches: []Match{}, | ||||||
|  | 		}, { | ||||||
|  | 			Input: "coca-cola zero / fanta / spa", | ||||||
|  | 			Matches: []Match{ | ||||||
|  | 				{"cz/2", 0.4}, | ||||||
|  | 				{"cc/2", 0.4}, | ||||||
|  | 				{"spa", 0.2}, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range cases { | ||||||
|  | 		matches := rs.Match(tc.Input) | ||||||
|  | 		if !reflect.DeepEqual(matches, tc.Matches) { | ||||||
|  | 			t.Fatalf("%v != %v", matches, tc.Matches) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										20
									
								
								tagengine/sanitize.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								tagengine/sanitize.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | |||||||
|  | package tagengine | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
|  | 	"git.crumpington.com/lib/tagengine/sanitize" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // The basic sanitizer: | ||||||
|  | // * lower-case | ||||||
|  | // * put spaces around numbers | ||||||
|  | // * put slaces around punctuation | ||||||
|  | // * collapse multiple spaces | ||||||
|  | func BasicSanitizer(s string) string { | ||||||
|  | 	s = strings.ToLower(s) | ||||||
|  | 	s = sanitize.SpaceNumbers(s) | ||||||
|  | 	s = sanitize.SpacePunctuation(s) | ||||||
|  | 	s = sanitize.CollapseSpaces(s) | ||||||
|  | 	return s | ||||||
|  | } | ||||||
							
								
								
									
										91
									
								
								tagengine/sanitize/sanitize.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								tagengine/sanitize/sanitize.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,91 @@ | |||||||
|  | package sanitize | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"strings" | ||||||
|  | 	"unicode" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func SpaceNumbers(s string) string { | ||||||
|  | 	if len(s) == 0 { | ||||||
|  | 		return s | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	isDigit := func(b rune) bool { | ||||||
|  | 		switch b { | ||||||
|  | 		case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	b := strings.Builder{} | ||||||
|  |  | ||||||
|  | 	var first rune | ||||||
|  | 	for _, c := range s { | ||||||
|  | 		first = c | ||||||
|  | 		break | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	digit := isDigit(first) | ||||||
|  |  | ||||||
|  | 	// Range over runes. | ||||||
|  | 	for _, c := range s { | ||||||
|  | 		thisDigit := isDigit(c) | ||||||
|  | 		if thisDigit != digit { | ||||||
|  | 			b.WriteByte(' ') | ||||||
|  | 			digit = thisDigit | ||||||
|  | 		} | ||||||
|  | 		b.WriteRune(c) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return b.String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func SpacePunctuation(s string) string { | ||||||
|  | 	needsSpace := func(r rune) bool { | ||||||
|  | 		switch r { | ||||||
|  | 		case '`', '~', '!', '@', '#', '%', '^', '&', '*', '(', ')', | ||||||
|  | 			'-', '_', '+', '=', '[', '{', ']', '}', '\\', '|', | ||||||
|  | 			':', ';', '"', '\'', ',', '<', '.', '>', '?', '/': | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 		return false | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	b := strings.Builder{} | ||||||
|  |  | ||||||
|  | 	// Range over runes. | ||||||
|  | 	for _, r := range s { | ||||||
|  | 		if needsSpace(r) { | ||||||
|  | 			b.WriteRune(' ') | ||||||
|  | 			b.WriteRune(r) | ||||||
|  | 			b.WriteRune(' ') | ||||||
|  | 		} else { | ||||||
|  | 			b.WriteRune(r) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return b.String() | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func CollapseSpaces(s string) string { | ||||||
|  | 	// Trim leading and trailing spaces. | ||||||
|  | 	s = strings.TrimSpace(s) | ||||||
|  |  | ||||||
|  | 	b := strings.Builder{} | ||||||
|  | 	wasSpace := false | ||||||
|  |  | ||||||
|  | 	// Range over runes. | ||||||
|  | 	for _, c := range s { | ||||||
|  | 		if unicode.IsSpace(c) { | ||||||
|  | 			wasSpace = true | ||||||
|  | 			continue | ||||||
|  | 		} else if wasSpace { | ||||||
|  | 			wasSpace = false | ||||||
|  | 			b.WriteRune(' ') | ||||||
|  | 		} | ||||||
|  | 		b.WriteRune(c) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return b.String() | ||||||
|  | } | ||||||
							
								
								
									
										30
									
								
								tagengine/sanitize_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								tagengine/sanitize_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | |||||||
|  | package tagengine | ||||||
|  |  | ||||||
|  | import "testing" | ||||||
|  |  | ||||||
|  | func TestSanitize(t *testing.T) { | ||||||
|  | 	sanitize := BasicSanitizer | ||||||
|  |  | ||||||
|  | 	type Case struct { | ||||||
|  | 		In  string | ||||||
|  | 		Out string | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cases := []Case{ | ||||||
|  | 		{"", ""}, | ||||||
|  | 		{"123abc", "123 abc"}, | ||||||
|  | 		{"abc123", "abc 123"}, | ||||||
|  | 		{"abc123xyz", "abc 123 xyz"}, | ||||||
|  | 		{"1f2", "1 f 2"}, | ||||||
|  | 		{" abc", "abc"}, | ||||||
|  | 		{" ; KitKat/m&m's (bottle) @ ", "; kitkat / m & m ' s ( bottle ) @"}, | ||||||
|  | 		{"€", "€"}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range cases { | ||||||
|  | 		out := sanitize(tc.In) | ||||||
|  | 		if out != tc.Out { | ||||||
|  | 			t.Fatalf("%v != %v", out, tc.Out) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										63
									
								
								tagengine/tokenize.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								tagengine/tokenize.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | |||||||
|  | package tagengine | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"sort" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var ignoreTokens = map[string]struct{}{} | ||||||
|  |  | ||||||
|  | func init() { | ||||||
|  | 	// These on their own are ignored. | ||||||
|  | 	tokens := []string{ | ||||||
|  | 		"`", `~`, `!`, `@`, `#`, `%`, `^`, `&`, `*`, `(`, `)`, | ||||||
|  | 		`-`, `_`, `+`, `=`, `[`, `{`, `]`, `}`, `\`, `|`, | ||||||
|  | 		`:`, `;`, `"`, `'`, `,`, `<`, `.`, `>`, `?`, `/`, | ||||||
|  | 	} | ||||||
|  | 	for _, s := range tokens { | ||||||
|  | 		ignoreTokens[s] = struct{}{} | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func Tokenize( | ||||||
|  | 	input string, | ||||||
|  | 	maxNgram int, | ||||||
|  | ) ( | ||||||
|  | 	tokens []string, | ||||||
|  | ) { | ||||||
|  | 	// Avoid duplicate ngrams. | ||||||
|  | 	ignored := map[string]bool{} | ||||||
|  |  | ||||||
|  | 	fields := strings.Fields(input) | ||||||
|  |  | ||||||
|  | 	if len(fields) < maxNgram { | ||||||
|  | 		maxNgram = len(fields) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for i := 1; i < maxNgram+1; i++ { | ||||||
|  | 		jMax := len(fields) - i + 1 | ||||||
|  |  | ||||||
|  | 		for j := 0; j < jMax; j++ { | ||||||
|  | 			ngram := strings.Join(fields[j:i+j], " ") | ||||||
|  | 			if _, ok := ignoreTokens[ngram]; !ok { | ||||||
|  | 				if _, ok := ignored[ngram]; !ok { | ||||||
|  | 					tokens = append(tokens, ngram) | ||||||
|  | 					ignored[ngram] = true | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	sortTokens(tokens) | ||||||
|  |  | ||||||
|  | 	return tokens | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func sortTokens(tokens []string) { | ||||||
|  | 	sort.Slice(tokens, func(i, j int) bool { | ||||||
|  | 		if len(tokens[i]) != len(tokens[j]) { | ||||||
|  | 			return len(tokens[i]) < len(tokens[j]) | ||||||
|  | 		} | ||||||
|  | 		return tokens[i] < tokens[j] | ||||||
|  | 	}) | ||||||
|  | } | ||||||
							
								
								
									
										55
									
								
								tagengine/tokenize_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								tagengine/tokenize_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | |||||||
|  | package tagengine | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"reflect" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func TestTokenize(t *testing.T) { | ||||||
|  | 	type Case struct { | ||||||
|  | 		Input    string | ||||||
|  | 		MaxNgram int | ||||||
|  | 		Output   []string | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cases := []Case{ | ||||||
|  | 		{ | ||||||
|  | 			Input:    "a bb c d", | ||||||
|  | 			MaxNgram: 3, | ||||||
|  | 			Output: []string{ | ||||||
|  | 				"a", "c", "d", "bb", | ||||||
|  | 				"c d", "a bb", "bb c", | ||||||
|  | 				"a bb c", "bb c d", | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			Input:    "a b", | ||||||
|  | 			MaxNgram: 3, | ||||||
|  | 			Output: []string{ | ||||||
|  | 				"a", "b", "a b", | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			Input:    "- b c d", | ||||||
|  | 			MaxNgram: 3, | ||||||
|  | 			Output: []string{ | ||||||
|  | 				"b", "c", "d", | ||||||
|  | 				"- b", "b c", "c d", | ||||||
|  | 				"- b c", "b c d", | ||||||
|  | 			}, | ||||||
|  | 		}, { | ||||||
|  | 			Input:    "a a b c d c d", | ||||||
|  | 			MaxNgram: 3, | ||||||
|  | 			Output: []string{ | ||||||
|  | 				"a", "b", "c", "d", | ||||||
|  | 				"a a", "a b", "b c", "c d", "d c", | ||||||
|  | 				"a a b", "a b c", "b c d", "c d c", "d c d", | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range cases { | ||||||
|  | 		output := Tokenize(tc.Input, tc.MaxNgram) | ||||||
|  | 		if !reflect.DeepEqual(output, tc.Output) { | ||||||
|  | 			t.Fatalf("%s: %#v", tc.Input, output) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										5
									
								
								webutil/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								webutil/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | |||||||
|  | # webutil | ||||||
|  |  | ||||||
|  | ## Roadmap | ||||||
|  |  | ||||||
|  | * logging middleware | ||||||
							
								
								
									
										10
									
								
								webutil/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								webutil/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | |||||||
|  | module git.crumpington.com/lib/webutil | ||||||
|  |  | ||||||
|  | go 1.23.2 | ||||||
|  |  | ||||||
|  | require golang.org/x/crypto v0.28.0 | ||||||
|  |  | ||||||
|  | require ( | ||||||
|  | 	golang.org/x/net v0.21.0 // indirect | ||||||
|  | 	golang.org/x/text v0.19.0 // indirect | ||||||
|  | ) | ||||||
							
								
								
									
										6
									
								
								webutil/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								webutil/go.sum
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | |||||||
|  | golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= | ||||||
|  | golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= | ||||||
|  | golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= | ||||||
|  | golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= | ||||||
|  | golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= | ||||||
|  | golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= | ||||||
							
								
								
									
										24
									
								
								webutil/listenandserve.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								webutil/listenandserve.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | |||||||
|  | package webutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"errors" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/crypto/acme/autocert" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // Serve requests using the given http.Server. If srv.Addr has the format | ||||||
|  | // `hostname:https`, then use autocert to manage certificates for the domain. | ||||||
|  | // | ||||||
|  | // For http on port 80, you can use :http. | ||||||
|  | func ListenAndServe(srv *http.Server) error { | ||||||
|  | 	if strings.HasSuffix(srv.Addr, ":https") { | ||||||
|  | 		hostname := strings.TrimSuffix(srv.Addr, ":https") | ||||||
|  | 		if len(hostname) == 0 { | ||||||
|  | 			return errors.New("https requires a hostname") | ||||||
|  | 		} | ||||||
|  | 		return srv.Serve(autocert.NewListener(hostname)) | ||||||
|  | 	} | ||||||
|  | 	return srv.ListenAndServe() | ||||||
|  | } | ||||||
							
								
								
									
										47
									
								
								webutil/middleware-logging.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								webutil/middleware-logging.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | |||||||
|  | package webutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"log" | ||||||
|  | 	"net/http" | ||||||
|  | 	"os" | ||||||
|  | 	"time" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var _log = log.New(os.Stderr, "", 0) | ||||||
|  |  | ||||||
|  | type responseWriterWrapper struct { | ||||||
|  | 	http.ResponseWriter | ||||||
|  | 	httpStatus   int | ||||||
|  | 	responseSize int | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *responseWriterWrapper) WriteHeader(status int) { | ||||||
|  | 	w.httpStatus = status | ||||||
|  | 	w.ResponseWriter.WriteHeader(status) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (w *responseWriterWrapper) Write(b []byte) (int, error) { | ||||||
|  | 	if w.httpStatus == 0 { | ||||||
|  | 		w.httpStatus = 200 | ||||||
|  | 	} | ||||||
|  | 	w.responseSize += len(b) | ||||||
|  | 	return w.ResponseWriter.Write(b) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func WithLogging(inner http.HandlerFunc) http.HandlerFunc { | ||||||
|  | 	return func(w http.ResponseWriter, r *http.Request) { | ||||||
|  | 		t := time.Now() | ||||||
|  | 		wrapper := responseWriterWrapper{w, 0, 0} | ||||||
|  |  | ||||||
|  | 		inner(&wrapper, r) | ||||||
|  | 		_log.Printf("%s \"%s %s %s\" %d %d %v\n", | ||||||
|  | 			r.RemoteAddr, | ||||||
|  | 			r.Method, | ||||||
|  | 			r.URL.Path, | ||||||
|  | 			r.Proto, | ||||||
|  | 			wrapper.httpStatus, | ||||||
|  | 			wrapper.responseSize, | ||||||
|  | 			time.Since(t), | ||||||
|  | 		) | ||||||
|  | 	} | ||||||
|  | } | ||||||
							
								
								
									
										100
									
								
								webutil/template.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								webutil/template.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,100 @@ | |||||||
|  | package webutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"embed" | ||||||
|  | 	"html/template" | ||||||
|  | 	"io/fs" | ||||||
|  | 	"log" | ||||||
|  | 	"path" | ||||||
|  | 	"strings" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | // ParseTemplateSet parses sets of templates from an embed.FS. | ||||||
|  | // | ||||||
|  | // Each directory constitutes a set of templates that are parsed together. | ||||||
|  | // | ||||||
|  | // Structure (within a directory): | ||||||
|  | //   - share/* are always parsed. | ||||||
|  | //   - base.html will be parsed with each other file in same dir | ||||||
|  | // | ||||||
|  | // Call a template with m[path].Execute(w, data) (root dir name is excluded). | ||||||
|  | // | ||||||
|  | // For example, if you have | ||||||
|  | //   - /user/share/* | ||||||
|  | //   - /user/base.html | ||||||
|  | //   - /user/home.html | ||||||
|  | // | ||||||
|  | // Then you call m["/user/home.html"].Execute(w, data). | ||||||
|  | func ParseTemplateSet(funcs template.FuncMap, fs embed.FS) map[string]*template.Template { | ||||||
|  | 	m := map[string]*template.Template{} | ||||||
|  | 	rootDir := readDir(fs, ".")[0].Name() | ||||||
|  | 	loadTemplateDir(fs, funcs, m, rootDir, rootDir) | ||||||
|  | 	return m | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func loadTemplateDir( | ||||||
|  | 	fs embed.FS, | ||||||
|  | 	funcs template.FuncMap, | ||||||
|  | 	m map[string]*template.Template, | ||||||
|  | 	dirPath string, | ||||||
|  | 	rootDir string, | ||||||
|  | ) map[string]*template.Template { | ||||||
|  | 	t := template.New("") | ||||||
|  | 	if funcs != nil { | ||||||
|  | 		t = t.Funcs(funcs) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	shareDir := path.Join(dirPath, "share") | ||||||
|  | 	if _, err := fs.ReadDir(shareDir); err == nil { | ||||||
|  | 		log.Printf("Parsing %s...", path.Join(shareDir, "*")) | ||||||
|  | 		t = template.Must(t.ParseFS(fs, path.Join(shareDir, "*"))) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if data, _ := fs.ReadFile(path.Join(dirPath, "base.html")); data != nil { | ||||||
|  | 		log.Printf("Parsing %s...", path.Join(dirPath, "base.html")) | ||||||
|  | 		t = template.Must(t.Parse(string(data))) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, ent := range readDir(fs, dirPath) { | ||||||
|  | 		if ent.Type().IsDir() { | ||||||
|  | 			if ent.Name() != "share" { | ||||||
|  | 				m = loadTemplateDir(fs, funcs, m, path.Join(dirPath, ent.Name()), rootDir) | ||||||
|  | 			} | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if !ent.Type().IsRegular() { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if ent.Name() == "base.html" { | ||||||
|  | 			continue | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		filePath := path.Join(dirPath, ent.Name()) | ||||||
|  | 		log.Printf("Parsing %s...", filePath) | ||||||
|  |  | ||||||
|  | 		key := strings.TrimPrefix(path.Join(dirPath, ent.Name()), rootDir) | ||||||
|  | 		tt := template.Must(t.Clone()) | ||||||
|  | 		tt = template.Must(tt.Parse(readFile(fs, filePath))) | ||||||
|  | 		m[key] = tt | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return m | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func readDir(fs embed.FS, dirPath string) []fs.DirEntry { | ||||||
|  | 	ents, err := fs.ReadDir(dirPath) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  | 	return ents | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func readFile(fs embed.FS, path string) string { | ||||||
|  | 	data, err := fs.ReadFile(path) | ||||||
|  | 	if err != nil { | ||||||
|  | 		panic(err) | ||||||
|  | 	} | ||||||
|  | 	return string(data) | ||||||
|  | } | ||||||
							
								
								
									
										49
									
								
								webutil/template_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								webutil/template_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | |||||||
|  | package webutil | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"bytes" | ||||||
|  | 	"embed" | ||||||
|  | 	"html/template" | ||||||
|  | 	"strings" | ||||||
|  | 	"testing" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | //go:embed all:test-templates | ||||||
|  | var testFS embed.FS | ||||||
|  |  | ||||||
|  | func TestParseTemplateSet(t *testing.T) { | ||||||
|  | 	funcs := template.FuncMap{"join": strings.Join} | ||||||
|  | 	m := ParseTemplateSet(funcs, testFS) | ||||||
|  |  | ||||||
|  | 	type TestCase struct { | ||||||
|  | 		Key  string | ||||||
|  | 		Data any | ||||||
|  | 		Out  string | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cases := []TestCase{ | ||||||
|  | 		{ | ||||||
|  | 			Key:  "/home.html", | ||||||
|  | 			Data: "DATA", | ||||||
|  | 			Out:  "<p>HOME!</p>", | ||||||
|  | 		}, { | ||||||
|  | 			Key:  "/about.html", | ||||||
|  | 			Data: "DATA", | ||||||
|  | 			Out:  "<p><b>DATA</b></p>", | ||||||
|  | 		}, { | ||||||
|  | 			Key:  "/contact.html", | ||||||
|  | 			Data: []string{"a", "b", "c"}, | ||||||
|  | 			Out:  "<p>a,b,c</p>", | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, tc := range cases { | ||||||
|  | 		b := &bytes.Buffer{} | ||||||
|  | 		m[tc.Key].Execute(b, tc.Data) | ||||||
|  | 		out := strings.TrimSpace(b.String()) | ||||||
|  | 		if out != tc.Out { | ||||||
|  | 			t.Fatalf("%s != %s", out, tc.Out) | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | } | ||||||
							
								
								
									
										1
									
								
								webutil/test-templates/about.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								webutil/test-templates/about.html
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | {{define "body"}}{{template "bold" .}}{{end}} | ||||||
							
								
								
									
										1
									
								
								webutil/test-templates/base.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								webutil/test-templates/base.html
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | <p>{{block "body" .}}default{{end}}</p> | ||||||
							
								
								
									
										1
									
								
								webutil/test-templates/contact.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								webutil/test-templates/contact.html
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | {{define "body"}}{{join . ","}}{{end}} | ||||||
							
								
								
									
										1
									
								
								webutil/test-templates/home.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								webutil/test-templates/home.html
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | |||||||
|  | {{define "body"}}HOME!{{end}} | ||||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user