wip
This commit is contained in:
		
							
								
								
									
										2
									
								
								flock/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								flock/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| # flock | ||||
|  | ||||
							
								
								
									
										69
									
								
								flock/flock.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										69
									
								
								flock/flock.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,69 @@ | ||||
| // The flock package provides a file-system mediated locking mechanism on linux | ||||
| // using the `flock` system call. | ||||
|  | ||||
| package flock | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"os" | ||||
| 	"syscall" | ||||
| ) | ||||
|  | ||||
| // Lock gets an exclusive lock on the file at the given path. If the file | ||||
| // doesn't exist, it's created. | ||||
| func Lock(path string) (*os.File, error) { | ||||
| 	return lock(path, syscall.LOCK_EX) | ||||
| } | ||||
|  | ||||
| // TryLock will return a nil file if the file is already locked. | ||||
| func TryLock(path string) (*os.File, error) { | ||||
| 	return lock(path, syscall.LOCK_EX|syscall.LOCK_NB) | ||||
| } | ||||
|  | ||||
| func LockFile(f *os.File) error { | ||||
| 	_, err := lockFile(f, syscall.LOCK_EX) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // Returns true if the lock was successfully acquired. | ||||
| func TryLockFile(f *os.File) (bool, error) { | ||||
| 	return lockFile(f, syscall.LOCK_EX|syscall.LOCK_NB) | ||||
| } | ||||
|  | ||||
| func lockFile(f *os.File, flags int) (bool, error) { | ||||
|  | ||||
| 	if err := flock(int(f.Fd()), flags); err != nil { | ||||
| 		if flags&syscall.LOCK_NB != 0 && errors.Is(err, syscall.EAGAIN) { | ||||
| 			return false, nil | ||||
| 		} | ||||
| 		return false, err | ||||
| 	} | ||||
| 	return true, nil | ||||
| } | ||||
|  | ||||
| func flock(fd int, how int) error { | ||||
| 	_, _, e1 := syscall.Syscall(syscall.SYS_FLOCK, uintptr(fd), uintptr(how), 0) | ||||
| 	if e1 != 0 { | ||||
| 		return syscall.Errno(e1) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func lock(path string, flags int) (*os.File, error) { | ||||
| 	perm := os.O_CREATE | os.O_RDWR | ||||
| 	f, err := os.OpenFile(path, perm, 0600) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	ok, err := lockFile(f, flags) | ||||
| 	if err != nil || !ok { | ||||
| 		f.Close() | ||||
| 		f = nil | ||||
| 	} | ||||
| 	return f, err | ||||
| } | ||||
|  | ||||
| // Unlock releases the lock acquired via the Lock function. | ||||
| func Unlock(f *os.File) error { | ||||
| 	return f.Close() | ||||
| } | ||||
							
								
								
									
										66
									
								
								flock/flock_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								flock/flock_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | ||||
| package flock | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func Test_Lock_basic(t *testing.T) { | ||||
| 	ch := make(chan int, 1) | ||||
| 	f, err := Lock("/tmp/fsutil-test-lock") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	go func() { | ||||
| 		time.Sleep(time.Second) | ||||
| 		ch <- 10 | ||||
| 		Unlock(f) | ||||
| 	}() | ||||
|  | ||||
| 	select { | ||||
| 	case x := <-ch: | ||||
| 		t.Fatal(x) | ||||
| 	default: | ||||
|  | ||||
| 	} | ||||
|  | ||||
| 	f2, _ := Lock("/tmp/fsutil-test-lock") | ||||
| 	defer Unlock(f2) | ||||
| 	select { | ||||
| 	case i := <-ch: | ||||
| 		if i != 10 { | ||||
| 			t.Fatal(i) | ||||
| 		} | ||||
| 	default: | ||||
| 		t.Fatal("No value available.") | ||||
| 	} | ||||
|  | ||||
| } | ||||
|  | ||||
| func Test_Lock_badPath(t *testing.T) { | ||||
| 	_, err := Lock("./dne/file.lock") | ||||
| 	if err == nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestTryLock(t *testing.T) { | ||||
| 	lockPath := "/tmp/fsutil-test-lock" | ||||
| 	f, err := TryLock(lockPath) | ||||
| 	if err != nil { | ||||
| 		t.Fatalf("%#v", err) | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	f2, err := TryLock(lockPath) | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if f2 != nil { | ||||
| 		t.Fatal(f2) | ||||
| 	} | ||||
|  | ||||
| 	if err := Unlock(f); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										3
									
								
								flock/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								flock/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| module git.crumpington.com/lib/flock | ||||
|  | ||||
| go 1.23.0 | ||||
							
								
								
									
										0
									
								
								flock/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								flock/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										2
									
								
								httpconn/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								httpconn/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| # httpconn | ||||
|  | ||||
							
								
								
									
										104
									
								
								httpconn/client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										104
									
								
								httpconn/client.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,104 @@ | ||||
| package httpconn | ||||
|  | ||||
| import ( | ||||
| 	"bufio" | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/url" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var ( | ||||
| 	ErrUnknownScheme = errors.New("uknown scheme") | ||||
| ) | ||||
|  | ||||
| type Dialer struct { | ||||
| 	timeout time.Duration | ||||
| } | ||||
|  | ||||
| func NewDialer() *Dialer { | ||||
| 	return &Dialer{timeout: 10 * time.Second} | ||||
| } | ||||
|  | ||||
| func (d *Dialer) SetTimeout(timeout time.Duration) { | ||||
| 	d.timeout = timeout | ||||
| } | ||||
|  | ||||
| func (d *Dialer) Dial(rawURL string) (net.Conn, error) { | ||||
| 	u, err := url.Parse(rawURL) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	switch u.Scheme { | ||||
| 	case "https": | ||||
| 		return d.DialHTTPS(u.Host+":443", u.Path) | ||||
| 	case "http": | ||||
| 		return d.DialHTTP(u.Host, u.Path) | ||||
| 	default: | ||||
| 		return nil, ErrUnknownScheme | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (d *Dialer) DialHTTPS(host, path string) (net.Conn, error) { | ||||
| 	ctx, cancel := context.WithTimeout(context.Background(), d.timeout) | ||||
| 	dd := tls.Dialer{} | ||||
| 	conn, err := dd.DialContext(ctx, "tcp", host) | ||||
| 	cancel() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return d.finishDialing(conn, host, path) | ||||
|  | ||||
| } | ||||
|  | ||||
| func (d *Dialer) DialHTTP(host, path string) (net.Conn, error) { | ||||
| 	conn, err := net.DialTimeout("tcp", host, d.timeout) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return d.finishDialing(conn, host, path) | ||||
| } | ||||
|  | ||||
| func (d *Dialer) finishDialing(conn net.Conn, host, path string) (net.Conn, error) { | ||||
| 	conn.SetDeadline(time.Now().Add(d.timeout)) | ||||
|  | ||||
| 	if _, err := io.WriteString(conn, "CONNECT "+path+" HTTP/1.0\n"); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if _, err := io.WriteString(conn, "Host: "+host+"\n\n"); err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	// Require successful HTTP response before using the conn. | ||||
| 	resp, err := http.ReadResponse(bufio.NewReader(conn), &http.Request{Method: "CONNECT"}) | ||||
| 	if err != nil { | ||||
| 		conn.Close() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if resp.Status != "200 OK" { | ||||
| 		conn.Close() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	conn.SetDeadline(time.Time{}) | ||||
|  | ||||
| 	return conn, nil | ||||
| } | ||||
|  | ||||
| func Dial(rawURL string) (net.Conn, error) { | ||||
| 	return NewDialer().Dial(rawURL) | ||||
| } | ||||
|  | ||||
| func DialHTTPS(host, path string) (net.Conn, error) { | ||||
| 	return NewDialer().DialHTTPS(host, path) | ||||
| } | ||||
|  | ||||
| func DialHTTP(host, path string) (net.Conn, error) { | ||||
| 	return NewDialer().DialHTTP(host, path) | ||||
| } | ||||
							
								
								
									
										42
									
								
								httpconn/conn_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								httpconn/conn_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | ||||
| package httpconn | ||||
|  | ||||
| import ( | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"net/http/httptest" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
|  | ||||
| 	"golang.org/x/net/nettest" | ||||
| ) | ||||
|  | ||||
| func TestNetTest_TestConn(t *testing.T) { | ||||
| 	nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) { | ||||
|  | ||||
| 		connCh := make(chan net.Conn, 1) | ||||
| 		doneCh := make(chan bool) | ||||
|  | ||||
| 		mux := http.NewServeMux() | ||||
| 		mux.HandleFunc("/connect", func(w http.ResponseWriter, r *http.Request) { | ||||
| 			conn, err := Accept(w, r) | ||||
| 			if err != nil { | ||||
| 				panic(err) | ||||
| 			} | ||||
| 			connCh <- conn | ||||
| 			<-doneCh | ||||
| 		}) | ||||
|  | ||||
| 		srv := httptest.NewServer(mux) | ||||
|  | ||||
| 		c1, err = DialHTTP(strings.TrimPrefix(srv.URL, "http://"), "/connect") | ||||
| 		if err != nil { | ||||
| 			panic(err) | ||||
| 		} | ||||
| 		c2 = <-connCh | ||||
|  | ||||
| 		return c1, c2, func() { | ||||
| 			doneCh <- true | ||||
| 			srv.Close() | ||||
| 		}, nil | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										5
									
								
								httpconn/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								httpconn/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| module git.crumpington.com/lib/httpconn | ||||
|  | ||||
| go 1.23.2 | ||||
|  | ||||
| require golang.org/x/net v0.30.0 | ||||
							
								
								
									
										2
									
								
								httpconn/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								httpconn/go.sum
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| golang.org/x/net v0.30.0 h1:AcW1SDZMkb8IpzCdQUaIq2sP4sZ4zw+55h6ynffypl4= | ||||
| golang.org/x/net v0.30.0/go.mod h1:2wGyMJ5iFasEhkwi13ChkO/t1ECNC4X4eBKkVFyYFlU= | ||||
							
								
								
									
										32
									
								
								httpconn/server.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										32
									
								
								httpconn/server.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,32 @@ | ||||
| package httpconn | ||||
|  | ||||
| import ( | ||||
| 	"io" | ||||
| 	"net" | ||||
| 	"net/http" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func Accept(w http.ResponseWriter, r *http.Request) (net.Conn, error) { | ||||
| 	if r.Method != "CONNECT" { | ||||
| 		w.Header().Set("Content-Type", "text/plain; charset=utf-8") | ||||
| 		w.WriteHeader(http.StatusMethodNotAllowed) | ||||
| 		io.WriteString(w, "405 must CONNECT\n") | ||||
| 		return nil, http.ErrNotSupported | ||||
| 	} | ||||
|  | ||||
| 	hj, ok := w.(http.Hijacker) | ||||
| 	if !ok { | ||||
| 		return nil, http.ErrNotSupported | ||||
| 	} | ||||
|  | ||||
| 	conn, _, err := hj.Hijack() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	_, _ = io.WriteString(conn, "HTTP/1.0 200 OK\n\n") | ||||
| 	conn.SetDeadline(time.Time{}) | ||||
|  | ||||
| 	return conn, nil | ||||
| } | ||||
							
								
								
									
										2
									
								
								idgen/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								idgen/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| # idgen | ||||
|  | ||||
							
								
								
									
										3
									
								
								idgen/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								idgen/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| module git.crumpington.com/lib/idgen | ||||
|  | ||||
| go 1.23.2 | ||||
							
								
								
									
										53
									
								
								idgen/idgen.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										53
									
								
								idgen/idgen.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,53 @@ | ||||
| package idgen | ||||
|  | ||||
| import ( | ||||
| 	"crypto/rand" | ||||
| 	"encoding/base32" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // Creates a new, random token. | ||||
| func NewToken() string { | ||||
| 	buf := make([]byte, 20) | ||||
| 	_, err := rand.Read(buf) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	return base32.StdEncoding.EncodeToString(buf) | ||||
| } | ||||
|  | ||||
| var ( | ||||
| 	lock       sync.Mutex | ||||
| 	ts         int64 = time.Now().Unix() | ||||
| 	counter    int64 = 1 | ||||
| 	counterMax int64 = 1 << 20 | ||||
| ) | ||||
|  | ||||
| // NextID can generate ~1M ints per second for a given nodeID. | ||||
| // | ||||
| // nodeID must be less than 64. | ||||
| func NextID(nodeID int64) int64 { | ||||
| 	lock.Lock() | ||||
| 	defer lock.Unlock() | ||||
|  | ||||
| 	tt := time.Now().Unix() | ||||
| 	if tt > ts { | ||||
| 		ts = tt | ||||
| 		counter = 1 | ||||
| 	} else { | ||||
| 		counter++ | ||||
| 		if counter == counterMax { | ||||
| 			panic("Too many IDs.") | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return (ts << 26) + (nodeID << 20) + counter | ||||
| } | ||||
|  | ||||
| func SplitID(id int64) (unixTime, nodeID, counter int64) { | ||||
| 	counter = id & (0x00000000000FFFFF) | ||||
| 	nodeID = (id >> 20) & (0x000000000000003F) | ||||
| 	unixTime = id >> 26 | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										19
									
								
								idgen/idgen_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								idgen/idgen_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,19 @@ | ||||
| package idgen | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func BenchmarkNext(b *testing.B) { | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		NextID(0) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestNextID(t *testing.T) { | ||||
| 	id := NextID(32) | ||||
|  | ||||
| 	a, b, c := SplitID(id) | ||||
| 	log.Print(a, b, c) | ||||
| } | ||||
							
								
								
									
										2
									
								
								keyedmutex/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								keyedmutex/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| # keyedmutex | ||||
|  | ||||
							
								
								
									
										3
									
								
								keyedmutex/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								keyedmutex/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| module git.crumpington.com/lib/keyedmutex | ||||
|  | ||||
| go 1.23.2 | ||||
							
								
								
									
										67
									
								
								keyedmutex/keyedmutex.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										67
									
								
								keyedmutex/keyedmutex.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,67 @@ | ||||
| package keyedmutex | ||||
|  | ||||
| import ( | ||||
| 	"container/list" | ||||
| 	"sync" | ||||
| ) | ||||
|  | ||||
| type KeyedMutex[K comparable] struct { | ||||
| 	mu       *sync.Mutex | ||||
| 	waitList map[K]*list.List | ||||
| } | ||||
|  | ||||
| func New[K comparable]() KeyedMutex[K] { | ||||
| 	return KeyedMutex[K]{ | ||||
| 		mu:       new(sync.Mutex), | ||||
| 		waitList: map[K]*list.List{}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (m KeyedMutex[K]) Lock(key K) { | ||||
| 	if ch := m.lock(key); ch != nil { | ||||
| 		<-ch | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (m KeyedMutex[K]) lock(key K) chan struct{} { | ||||
| 	m.mu.Lock() | ||||
| 	defer m.mu.Unlock() | ||||
|  | ||||
| 	if waitList, ok := m.waitList[key]; ok { | ||||
| 		ch := make(chan struct{}) | ||||
| 		waitList.PushBack(ch) | ||||
| 		return ch | ||||
| 	} | ||||
|  | ||||
| 	m.waitList[key] = list.New() | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (m KeyedMutex[K]) TryLock(key K) bool { | ||||
| 	m.mu.Lock() | ||||
| 	defer m.mu.Unlock() | ||||
|  | ||||
| 	if _, ok := m.waitList[key]; ok { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	m.waitList[key] = list.New() | ||||
| 	return true | ||||
| } | ||||
|  | ||||
| func (m KeyedMutex[K]) Unlock(key K) { | ||||
| 	m.mu.Lock() | ||||
| 	defer m.mu.Unlock() | ||||
|  | ||||
| 	waitList, ok := m.waitList[key] | ||||
| 	if !ok { | ||||
| 		panic("unlock of unlocked mutex") | ||||
| 	} | ||||
|  | ||||
| 	if waitList.Len() == 0 { | ||||
| 		delete(m.waitList, key) | ||||
| 	} else { | ||||
| 		ch := waitList.Remove(waitList.Front()).(chan struct{}) | ||||
| 		ch <- struct{}{} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										123
									
								
								keyedmutex/keyedmutex_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										123
									
								
								keyedmutex/keyedmutex_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,123 @@ | ||||
| package keyedmutex | ||||
|  | ||||
|  | ||||
| import ( | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func TestKeyedMutex(t *testing.T) { | ||||
| 	checkState := func(t *testing.T, m KeyedMutex[string], keys ...string) { | ||||
| 		if len(m.waitList) != len(keys) { | ||||
| 			t.Fatal(m.waitList, keys) | ||||
| 		} | ||||
|  | ||||
| 		for _, key := range keys { | ||||
| 			if _, ok := m.waitList[key]; !ok { | ||||
| 				t.Fatal(key) | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	m := New[string]() | ||||
| 	checkState(t, m) | ||||
|  | ||||
| 	m.Lock("a") | ||||
| 	checkState(t, m, "a") | ||||
| 	m.Lock("b") | ||||
| 	checkState(t, m, "a", "b") | ||||
| 	m.Lock("c") | ||||
| 	checkState(t, m, "a", "b", "c") | ||||
|  | ||||
| 	if m.TryLock("a") { | ||||
| 		t.Fatal("a") | ||||
| 	} | ||||
| 	if m.TryLock("b") { | ||||
| 		t.Fatal("b") | ||||
| 	} | ||||
| 	if m.TryLock("c") { | ||||
| 		t.Fatal("c") | ||||
| 	} | ||||
|  | ||||
| 	if !m.TryLock("d") { | ||||
| 		t.Fatal("d") | ||||
| 	} | ||||
|  | ||||
| 	checkState(t, m, "a", "b", "c", "d") | ||||
|  | ||||
| 	if !m.TryLock("e") { | ||||
| 		t.Fatal("e") | ||||
| 	} | ||||
| 	checkState(t, m, "a", "b", "c", "d", "e") | ||||
|  | ||||
| 	m.Unlock("c") | ||||
| 	checkState(t, m, "a", "b", "d", "e") | ||||
| 	m.Unlock("a") | ||||
| 	checkState(t, m, "b", "d", "e") | ||||
| 	m.Unlock("e") | ||||
| 	checkState(t, m, "b", "d") | ||||
|  | ||||
| 	wg := sync.WaitGroup{} | ||||
| 	for i := 0; i < 8; i++ { | ||||
| 		wg.Add(1) | ||||
| 		go func() { | ||||
| 			defer wg.Done() | ||||
| 			m.Lock("b") | ||||
| 			m.Unlock("b") | ||||
| 		}() | ||||
| 	} | ||||
|  | ||||
| 	time.Sleep(100 * time.Millisecond) | ||||
| 	m.Unlock("b") | ||||
| 	wg.Wait() | ||||
|  | ||||
| 	checkState(t, m, "d") | ||||
|  | ||||
| 	m.Unlock("d") | ||||
| 	checkState(t, m) | ||||
| } | ||||
|  | ||||
| func TestKeyedMutex_unlockUnlocked(t *testing.T) { | ||||
| 	defer func() { | ||||
| 		if r := recover(); r == nil { | ||||
| 			t.Fatal(r) | ||||
| 		} | ||||
| 	}() | ||||
|  | ||||
| 	m := New[string]() | ||||
| 	m.Unlock("aldkfj") | ||||
| } | ||||
|  | ||||
| func BenchmarkUncontendedMutex(b *testing.B) { | ||||
| 	m := New[string]() | ||||
| 	key := "xyz" | ||||
|  | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		m.Lock(key) | ||||
| 		m.Unlock(key) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func BenchmarkContendedMutex(b *testing.B) { | ||||
| 	m := New[string]() | ||||
| 	key := "xyz" | ||||
|  | ||||
| 	m.Lock(key) | ||||
|  | ||||
| 	wg := sync.WaitGroup{} | ||||
| 	for i := 0; i < b.N; i++ { | ||||
| 		wg.Add(1) | ||||
| 		go func() { | ||||
| 			defer wg.Done() | ||||
| 			m.Lock(key) | ||||
| 			m.Unlock(key) | ||||
| 		}() | ||||
| 	} | ||||
|  | ||||
| 	time.Sleep(time.Second) | ||||
|  | ||||
| 	b.ResetTimer() | ||||
| 	m.Unlock(key) | ||||
| 	wg.Wait() | ||||
| } | ||||
							
								
								
									
										2
									
								
								kvmemcache/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								kvmemcache/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| # kvmemcache | ||||
|  | ||||
							
								
								
									
										134
									
								
								kvmemcache/cache.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										134
									
								
								kvmemcache/cache.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,134 @@ | ||||
| package kvmemcache | ||||
|  | ||||
| import ( | ||||
| 	"container/list" | ||||
| 	"sync" | ||||
| 	"time" | ||||
|  | ||||
| 	"git.crumpington.com/lib/keyedmutex" | ||||
| ) | ||||
|  | ||||
| type Cache[K comparable, V any] struct { | ||||
| 	updateLock keyedmutex.KeyedMutex[K] | ||||
| 	src        func(K) (V, error) | ||||
| 	ttl        time.Duration | ||||
| 	maxSize    int | ||||
|  | ||||
| 	// Lock protects variables below. | ||||
| 	lock  sync.Mutex | ||||
| 	cache map[K]*list.Element | ||||
| 	ll    *list.List | ||||
| 	stats Stats | ||||
| } | ||||
|  | ||||
| type lruItem[K comparable, V any] struct { | ||||
| 	key       K | ||||
| 	createdAt time.Time | ||||
| 	value     V | ||||
| 	err       error | ||||
| } | ||||
|  | ||||
| type Config[K comparable, V any] struct { | ||||
| 	MaxSize int | ||||
| 	TTL     time.Duration // Zero to ignore. | ||||
| 	Src     func(K) (V, error) | ||||
| } | ||||
|  | ||||
| func New[K comparable, V any](conf Config[K, V]) *Cache[K, V] { | ||||
| 	return &Cache[K, V]{ | ||||
| 		updateLock: keyedmutex.New[K](), | ||||
| 		src:        conf.Src, | ||||
| 		ttl:        conf.TTL, | ||||
| 		maxSize:    conf.MaxSize, | ||||
| 		lock:       sync.Mutex{}, | ||||
| 		cache:      make(map[K]*list.Element, conf.MaxSize+1), | ||||
| 		ll:         list.New(), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (c *Cache[K, V]) Get(key K) (V, error) { | ||||
| 	ok, val, err := c.get(key) | ||||
| 	if ok { | ||||
| 		return val, err | ||||
| 	} | ||||
|  | ||||
| 	return c.load(key) | ||||
| } | ||||
|  | ||||
| func (c *Cache[K, V]) Evict(key K) { | ||||
| 	c.lock.Lock() | ||||
| 	defer c.lock.Unlock() | ||||
| 	c.evict(key) | ||||
| } | ||||
|  | ||||
| func (c *Cache[K, V]) Stats() Stats { | ||||
| 	c.lock.Lock() | ||||
| 	defer c.lock.Unlock() | ||||
| 	return c.stats | ||||
| } | ||||
|  | ||||
| func (c *Cache[K, V]) put(key K, value V, err error) { | ||||
| 	c.lock.Lock() | ||||
| 	defer c.lock.Unlock() | ||||
|  | ||||
| 	c.stats.Misses++ | ||||
|  | ||||
| 	c.cache[key] = c.ll.PushFront(lruItem[K, V]{ | ||||
| 		key:       key, | ||||
| 		createdAt: time.Now(), | ||||
| 		value:     value, | ||||
| 		err:       err, | ||||
| 	}) | ||||
|  | ||||
| 	if c.maxSize != 0 && len(c.cache) > c.maxSize { | ||||
| 		li := c.ll.Back() | ||||
| 		c.ll.Remove(li) | ||||
| 		delete(c.cache, li.Value.(lruItem[K, V]).key) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (c *Cache[K, V]) evict(key K) { | ||||
| 	elem := c.cache[key] | ||||
| 	if elem != nil { | ||||
| 		delete(c.cache, key) | ||||
| 		c.ll.Remove(elem) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (c *Cache[K, V]) get(key K) (ok bool, val V, err error) { | ||||
| 	c.lock.Lock() | ||||
| 	defer c.lock.Unlock() | ||||
|  | ||||
| 	li := c.cache[key] | ||||
| 	if li == nil { | ||||
| 		return false, val, nil | ||||
| 	} | ||||
|  | ||||
| 	item := li.Value.(lruItem[K, V]) | ||||
| 	// Maybe evict. | ||||
| 	if c.ttl != 0 && time.Since(item.createdAt) > c.ttl { | ||||
| 		c.evict(key) | ||||
| 		return false, val, nil | ||||
| 	} | ||||
|  | ||||
| 	c.stats.Hits++ | ||||
|  | ||||
| 	c.ll.MoveToFront(li) | ||||
| 	return true, item.value, item.err | ||||
| } | ||||
|  | ||||
| func (c *Cache[K, V]) load(key K) (V, error) { | ||||
| 	c.updateLock.Lock(key) | ||||
| 	defer c.updateLock.Unlock(key) | ||||
|  | ||||
| 	// Check again in case we lost the update race. | ||||
| 	ok, val, err := c.get(key) | ||||
| 	if ok { | ||||
| 		return val, err | ||||
| 	} | ||||
|  | ||||
| 	// Won the update race. | ||||
| 	val, err = c.src(key) | ||||
| 	c.put(key, val, err) | ||||
| 	return val, err | ||||
| } | ||||
							
								
								
									
										249
									
								
								kvmemcache/cache_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										249
									
								
								kvmemcache/cache_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,249 @@ | ||||
| package kvmemcache | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| type State[K comparable] struct { | ||||
| 	Keys  []K | ||||
| 	Stats Stats | ||||
| } | ||||
|  | ||||
| func (c *Cache[K,V]) assert(state State[K]) error { | ||||
| 	c.lock.Lock() | ||||
| 	defer c.lock.Unlock() | ||||
|  | ||||
| 	if len(c.cache) != len(state.Keys) { | ||||
| 		return fmt.Errorf( | ||||
| 			"Expected %d keys but found %d.", | ||||
| 			len(state.Keys), | ||||
| 			len(c.cache)) | ||||
| 	} | ||||
|  | ||||
| 	for _, k := range state.Keys { | ||||
| 		if _, ok := c.cache[k]; !ok { | ||||
| 			return fmt.Errorf( | ||||
| 				"Expected key %v not found.", | ||||
| 				k) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if c.stats.Hits != state.Stats.Hits { | ||||
| 		return fmt.Errorf( | ||||
| 			"Expected %d hits, but found %d.", | ||||
| 			state.Stats.Hits, | ||||
| 			c.stats.Hits) | ||||
| 	} | ||||
|  | ||||
| 	if c.stats.Misses != state.Stats.Misses { | ||||
| 		return fmt.Errorf( | ||||
| 			"Expected %d misses, but found %d.", | ||||
| 			state.Stats.Misses, | ||||
| 			c.stats.Misses) | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| var ErrTest = errors.New("Hello") | ||||
|  | ||||
| func TestCache_basic(t *testing.T) { | ||||
| 	c := New(Config[string, string]{ | ||||
| 		MaxSize: 4, | ||||
| 		TTL:     50 * time.Millisecond, | ||||
| 		Src: func(key string) (string, error) { | ||||
| 			if key == "err" { | ||||
| 				return "", ErrTest | ||||
| 			} | ||||
| 			return key, nil | ||||
| 		}, | ||||
| 	}) | ||||
|  | ||||
| 	type testCase struct { | ||||
| 		name  string | ||||
| 		sleep time.Duration | ||||
| 		key   string | ||||
| 		evict bool | ||||
| 		state State[string] | ||||
| 	} | ||||
|  | ||||
| 	cases := []testCase{ | ||||
| 		{ | ||||
| 			name: "get a", | ||||
| 			key:  "a", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"a"}, | ||||
| 				Stats: Stats{Hits: 0, Misses: 1}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "get a again", | ||||
| 			key:  "a", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"a"}, | ||||
| 				Stats: Stats{Hits: 1, Misses: 1}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name:  "sleep, then get a again", | ||||
| 			sleep: 55 * time.Millisecond, | ||||
| 			key:   "a", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"a"}, | ||||
| 				Stats: Stats{Hits: 1, Misses: 2}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "get b", | ||||
| 			key:  "b", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"a", "b"}, | ||||
| 				Stats: Stats{Hits: 1, Misses: 3}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "get c", | ||||
| 			key:  "c", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"a", "b", "c"}, | ||||
| 				Stats: Stats{Hits: 1, Misses: 4}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "get d", | ||||
| 			key:  "d", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"a", "b", "c", "d"}, | ||||
| 				Stats: Stats{Hits: 1, Misses: 5}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "get e", | ||||
| 			key:  "e", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"b", "c", "d", "e"}, | ||||
| 				Stats: Stats{Hits: 1, Misses: 6}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "get c again", | ||||
| 			key:  "c", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"b", "c", "d", "e"}, | ||||
| 				Stats: Stats{Hits: 2, Misses: 6}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "get err", | ||||
| 			key:  "err", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"c", "d", "e", "err"}, | ||||
| 				Stats: Stats{Hits: 2, Misses: 7}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "get err again", | ||||
| 			key:  "err", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"c", "d", "e", "err"}, | ||||
| 				Stats: Stats{Hits: 3, Misses: 7}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name:  "evict c", | ||||
| 			key:   "c", | ||||
| 			evict: true, | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"d", "e", "err"}, | ||||
| 				Stats: Stats{Hits: 3, Misses: 7}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "reload-all a", | ||||
| 			key:  "a", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"a", "d", "e", "err"}, | ||||
| 				Stats: Stats{Hits: 3, Misses: 8}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "reload-all b", | ||||
| 			key:  "b", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"a", "b", "e", "err"}, | ||||
| 				Stats: Stats{Hits: 3, Misses: 9}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "reload-all c", | ||||
| 			key:  "c", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"a", "b", "c", "err"}, | ||||
| 				Stats: Stats{Hits: 3, Misses: 10}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "reload-all d", | ||||
| 			key:  "d", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"a", "b", "c", "d"}, | ||||
| 				Stats: Stats{Hits: 3, Misses: 11}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "read a again", | ||||
| 			key:  "a", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"b", "c", "d", "a"}, | ||||
| 				Stats: Stats{Hits: 4, Misses: 11}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			name: "read e, evicting b", | ||||
| 			key:  "e", | ||||
| 			state: State[string]{ | ||||
| 				Keys:  []string{"c", "d", "a", "e"}, | ||||
| 				Stats: Stats{Hits: 4, Misses: 12}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		time.Sleep(tc.sleep) | ||||
| 		if !tc.evict { | ||||
| 			val, err := c.Get(tc.key) | ||||
| 			if tc.key == "err" && err != ErrTest { | ||||
| 				t.Fatal(tc.name, val) | ||||
| 			} | ||||
| 			if tc.key != "err" && val != tc.key { | ||||
| 				t.Fatal(tc.name, tc.key, val) | ||||
| 			} | ||||
| 		} else { | ||||
| 			c.Evict(tc.key) | ||||
| 		} | ||||
|  | ||||
| 		if err := c.assert(tc.state); err != nil { | ||||
| 			t.Fatal(err) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestCache_thunderingHerd(t *testing.T) { | ||||
| 	c := New(Config[string,string]{ | ||||
| 		MaxSize: 4, | ||||
| 		Src: func(key string) (string, error) { | ||||
| 			time.Sleep(time.Second) | ||||
| 			return key, nil | ||||
| 		}, | ||||
| 	}) | ||||
|  | ||||
| 	wg := sync.WaitGroup{} | ||||
| 	for i := 0; i < 16384; i++ { | ||||
| 		wg.Add(1) | ||||
| 		go func() { | ||||
| 			defer wg.Done() | ||||
| 			val, err := c.Get("a") | ||||
| 			if err != nil { | ||||
| 				panic(err) | ||||
| 			} | ||||
| 			if val != "a" { | ||||
| 				panic(err) | ||||
| 			} | ||||
| 		}() | ||||
| 	} | ||||
|  | ||||
| 	wg.Wait() | ||||
|  | ||||
| 	stats := c.Stats() | ||||
| 	if stats.Hits != 16383 || stats.Misses != 1 { | ||||
| 		t.Fatal(stats) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										5
									
								
								kvmemcache/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								kvmemcache/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| module git.crumpington.com/lib/kvmemcache | ||||
|  | ||||
| go 1.23.2 | ||||
|  | ||||
| require git.crumpington.com/lib/keyedmutex v1.0.1 | ||||
							
								
								
									
										2
									
								
								kvmemcache/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								kvmemcache/go.sum
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| git.crumpington.com/lib/keyedmutex v1.0.1 h1:5ylwGXQzL9ojZIhlqkut6dpa4yt6Wz6bOWbf/tQBAMQ= | ||||
| git.crumpington.com/lib/keyedmutex v1.0.1/go.mod h1:VxxJRU/XvvF61IuJZG7kUIv954Q8+Rh8bnVpEzGYrQ4= | ||||
							
								
								
									
										6
									
								
								kvmemcache/stats.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								kvmemcache/stats.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| package kvmemcache | ||||
|  | ||||
| type Stats struct { | ||||
| 	Hits   uint64 | ||||
| 	Misses uint64 | ||||
| } | ||||
							
								
								
									
										2
									
								
								mmap/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								mmap/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| # mmap | ||||
|  | ||||
							
								
								
									
										100
									
								
								mmap/file.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								mmap/file.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,100 @@ | ||||
| package mmap | ||||
|  | ||||
| import "os" | ||||
|  | ||||
| type File struct { | ||||
| 	f   *os.File | ||||
| 	Map []byte | ||||
| } | ||||
|  | ||||
| func Create(path string, size int64) (*File, error) { | ||||
| 	f, err := os.Create(path) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	if err := f.Truncate(size); err != nil { | ||||
| 		f.Close() | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	m, err := Map(f, PROT_READ|PROT_WRITE) | ||||
| 	if err != nil { | ||||
| 		f.Close() | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &File{f, m}, nil | ||||
| } | ||||
|  | ||||
| // Opens a mapped file in read-only mode. | ||||
| func Open(path string) (*File, error) { | ||||
| 	f, err := os.Open(path) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	m, err := Map(f, PROT_READ) | ||||
| 	if err != nil { | ||||
| 		f.Close() | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &File{f, m}, nil | ||||
| } | ||||
|  | ||||
| func OpenFile( | ||||
| 	path string, | ||||
| 	fileFlags int, | ||||
| 	perm os.FileMode, | ||||
| 	size int64, // -1 for file size. | ||||
| ) (*File, error) { | ||||
| 	f, err := os.OpenFile(path, fileFlags, perm) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	writable := fileFlags|os.O_RDWR != 0 || fileFlags|os.O_WRONLY != 0 | ||||
|  | ||||
| 	fi, err := f.Stat() | ||||
| 	if err != nil { | ||||
| 		f.Close() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	if writable && size > 0 && size != fi.Size() { | ||||
| 		if err := f.Truncate(size); err != nil { | ||||
| 			f.Close() | ||||
| 			return nil, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	mapFlags := PROT_READ | ||||
| 	if writable { | ||||
| 		mapFlags |= PROT_WRITE | ||||
| 	} | ||||
|  | ||||
| 	m, err := Map(f, mapFlags) | ||||
| 	if err != nil { | ||||
| 		f.Close() | ||||
| 		return nil, err | ||||
| 	} | ||||
|  | ||||
| 	return &File{f, m}, nil | ||||
| } | ||||
|  | ||||
| func (f *File) Sync() error { | ||||
| 	return Sync(f.Map) | ||||
| } | ||||
|  | ||||
| func (f *File) Close() error { | ||||
| 	if f.Map != nil { | ||||
| 		if err := Unmap(f.Map); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		f.Map = nil | ||||
| 	} | ||||
|  | ||||
| 	if f.f != nil { | ||||
| 		if err := f.f.Close(); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		f.f = nil | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										3
									
								
								mmap/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								mmap/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| module git.crumpington.com/lib/mmap | ||||
|  | ||||
| go 1.23.2 | ||||
							
								
								
									
										0
									
								
								mmap/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								mmap/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										59
									
								
								mmap/mmap.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										59
									
								
								mmap/mmap.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,59 @@ | ||||
| package mmap | ||||
|  | ||||
| import ( | ||||
| 	"os" | ||||
| 	"syscall" | ||||
| 	"unsafe" | ||||
| ) | ||||
|  | ||||
| const ( | ||||
| 	PROT_READ  = syscall.PROT_READ | ||||
| 	PROT_WRITE = syscall.PROT_WRITE | ||||
| ) | ||||
|  | ||||
| // Mmap creates a memory map of the given file. The flags argument should be a | ||||
| // combination of PROT_READ and PROT_WRITE. The size of the map will be the | ||||
| // file's size. | ||||
| func Map(f *os.File, flags int) ([]byte, error) { | ||||
| 	fi, err := f.Stat() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	size := fi.Size() | ||||
|  | ||||
| 	addr, _, errno := syscall.Syscall6( | ||||
| 		syscall.SYS_MMAP, | ||||
| 		0, // addr: 0 => allow kernel to choose | ||||
| 		uintptr(size), | ||||
| 		uintptr(flags), | ||||
| 		uintptr(syscall.MAP_SHARED), | ||||
| 		f.Fd(), | ||||
| 		0) // offset: 0 => start of file | ||||
| 	if errno != 0 { | ||||
| 		return nil, syscall.Errno(errno) | ||||
| 	} | ||||
|  | ||||
| 	return unsafe.Slice((*byte)(unsafe.Pointer(addr)), size), nil | ||||
| } | ||||
|  | ||||
| // Munmap unmaps the data obtained by Map. | ||||
| func Unmap(data []byte) error { | ||||
| 	_, _, errno := syscall.Syscall( | ||||
| 		syscall.SYS_MUNMAP, | ||||
| 		uintptr(unsafe.Pointer(&data[:1][0])), | ||||
| 		uintptr(cap(data)), | ||||
| 		0) | ||||
| 	if errno != 0 { | ||||
| 		return syscall.Errno(errno) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func Sync(b []byte) (err error) { | ||||
| 	_p0 := unsafe.Pointer(&b[0]) | ||||
| 	_, _, errno := syscall.Syscall(syscall.SYS_MSYNC, uintptr(_p0), uintptr(len(b)), uintptr(syscall.MS_SYNC)) | ||||
| 	if errno != 0 { | ||||
| 		err = syscall.Errno(errno) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										45
									
								
								pgutil/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								pgutil/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | ||||
| # pgutil | ||||
|  | ||||
| ## Transactions | ||||
|  | ||||
| Simplify postgres transactions using `WithTx` for serializable transactions, | ||||
| or `WithTxDefault` for the default isolation level. Use the `SerialTxRunner` | ||||
| type to get automatic retries of serialization errors. | ||||
|  | ||||
| ## Migrations | ||||
|  | ||||
| Put your migrations into a directory, for example `migrations`, ordered by name | ||||
| (YYYY-MM-DD prefix, for example). Embed the directory and pass it to the | ||||
| `Migrate` function: | ||||
|  | ||||
| ```Go | ||||
| //go:embed migrations | ||||
| var migrations embed.FS | ||||
|  | ||||
| func init() { | ||||
|     Migrate(db, migrations) // Check the error, of course. | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## Testing | ||||
|  | ||||
| In order to test this packge, we need to create a test user and database: | ||||
|  | ||||
| ``` | ||||
| sudo su postgres | ||||
| psql | ||||
|  | ||||
| CREATE DATABASE test; | ||||
| CREATE USER test WITH ENCRYPTED PASSWORD 'test'; | ||||
| GRANT ALL PRIVILEGES ON DATABASE test TO test; | ||||
|  | ||||
| use test | ||||
|  | ||||
| GRANT ALL ON SCHEMA public TO test; | ||||
| ``` | ||||
|  | ||||
| Check that you can connect via the command line: | ||||
|  | ||||
| ``` | ||||
| psql -h 127.0.0.1 -U test --password test | ||||
| ``` | ||||
							
								
								
									
										42
									
								
								pgutil/dropall.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								pgutil/dropall.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,42 @@ | ||||
| package pgutil | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"log" | ||||
| ) | ||||
|  | ||||
| const dropTablesQueryQuery = ` | ||||
| SELECT 'DROP TABLE IF EXISTS "' || tablename || '" CASCADE;' | ||||
| FROM | ||||
|    pg_tables | ||||
| WHERE | ||||
|   schemaname='public'` | ||||
|  | ||||
| // Deletes all tables in the database. Useful for testing. | ||||
| func DropAllTables(db *sql.DB) error { | ||||
| 	rows, err := db.Query(dropTablesQueryQuery) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	queries := []string{} | ||||
| 	for rows.Next() { | ||||
| 		var s string | ||||
| 		if err := rows.Scan(&s); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		queries = append(queries, s) | ||||
| 	} | ||||
|  | ||||
| 	if len(queries) > 0 { | ||||
| 		log.Printf("DROPPING ALL (%d) TABLES", len(queries)) | ||||
| 	} | ||||
|  | ||||
| 	for _, query := range queries { | ||||
| 		if _, err := db.Exec(query); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return nil | ||||
| } | ||||
							
								
								
									
										31
									
								
								pgutil/errors.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								pgutil/errors.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| package pgutil | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
|  | ||||
| 	"github.com/lib/pq" | ||||
| ) | ||||
|  | ||||
| func ErrIsDuplicateKey(err error) bool { | ||||
| 	return ErrHasCode(err, "23505") | ||||
| } | ||||
|  | ||||
| func ErrIsForeignKey(err error) bool { | ||||
| 	return ErrHasCode(err, "23503") | ||||
| } | ||||
|  | ||||
| func ErrIsSerializationFaiilure(err error) bool { | ||||
| 	return ErrHasCode(err, "40001") | ||||
| } | ||||
|  | ||||
| func ErrHasCode(err error, code string) bool { | ||||
| 	if err == nil { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	var pErr *pq.Error | ||||
| 	if errors.As(err, &pErr) { | ||||
| 		return pErr.Code == pq.ErrorCode(code) | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
							
								
								
									
										36
									
								
								pgutil/errors_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								pgutil/errors_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | ||||
| package pgutil | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestErrors(t *testing.T) { | ||||
| 	db, err := sql.Open( | ||||
| 		"postgres", | ||||
| 		"host=127.0.0.1 dbname=test sslmode=disable user=test password=test") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if err := DropAllTables(db); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if err := Migrate(db, testMigrationFS); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (2, 'q@r.com')`) | ||||
| 	if !ErrIsDuplicateKey(err) { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (3, 'c@d.com')`) | ||||
| 	if !ErrIsDuplicateKey(err) { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	_, err = db.Exec(`INSERT INTO user_notes(UserID, NoteID, Note) VALUES (4, 1, 'hello')`) | ||||
| 	if !ErrIsForeignKey(err) { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										5
									
								
								pgutil/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								pgutil/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| module git.crumpington.com/git/pgutil | ||||
|  | ||||
| go 1.23.2 | ||||
|  | ||||
| require github.com/lib/pq v1.10.9 | ||||
							
								
								
									
										2
									
								
								pgutil/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								pgutil/go.sum
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= | ||||
| github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= | ||||
							
								
								
									
										82
									
								
								pgutil/migrate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								pgutil/migrate.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,82 @@ | ||||
| package pgutil | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"embed" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"path/filepath" | ||||
| 	"sort" | ||||
| ) | ||||
|  | ||||
| const initMigrationTableQuery = ` | ||||
| CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);` | ||||
|  | ||||
| const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)` | ||||
|  | ||||
| const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)` | ||||
|  | ||||
| func Migrate(db *sql.DB, migrationFS embed.FS) error { | ||||
| 	return WithTx(db, func(tx *sql.Tx) error { | ||||
| 		if _, err := tx.Exec(initMigrationTableQuery); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		dirs, err := migrationFS.ReadDir(".") | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if len(dirs) != 1 { | ||||
| 			return errors.New("expected a single migrations directory") | ||||
| 		} | ||||
|  | ||||
| 		if !dirs[0].IsDir() { | ||||
| 			return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name()) | ||||
| 		} | ||||
|  | ||||
| 		dirName := dirs[0].Name() | ||||
| 		files, err := migrationFS.ReadDir(dirName) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		// Sort sql files by name. | ||||
| 		sort.Slice(files, func(i, j int) bool { | ||||
| 			return files[i].Name() < files[j].Name() | ||||
| 		}) | ||||
|  | ||||
| 		for _, dirEnt := range files { | ||||
| 			if !dirEnt.Type().IsRegular() { | ||||
| 				return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name()) | ||||
| 			} | ||||
|  | ||||
| 			var ( | ||||
| 				name   = dirEnt.Name() | ||||
| 				exists bool | ||||
| 			) | ||||
|  | ||||
| 			err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
|  | ||||
| 			if exists { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			migration, err := migrationFS.ReadFile(filepath.Join(dirName, name)) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
|  | ||||
| 			if _, err := tx.Exec(string(migration)); err != nil { | ||||
| 				return fmt.Errorf("migration %s failed: %v", name, err) | ||||
| 			} | ||||
|  | ||||
| 			if _, err := tx.Exec(insertMigrationQuery, name); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 		return nil | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										50
									
								
								pgutil/migrate_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										50
									
								
								pgutil/migrate_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,50 @@ | ||||
| package pgutil | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"embed" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| //go:embed test-migrations | ||||
| var testMigrationFS embed.FS | ||||
|  | ||||
| func TestMigrate(t *testing.T) { | ||||
| 	db, err := sql.Open( | ||||
| 		"postgres", | ||||
| 		"host=127.0.0.1 dbname=test sslmode=disable user=test password=test") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if err := DropAllTables(db); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if err := Migrate(db, testMigrationFS); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Shouldn't have any effect. | ||||
| 	if err := Migrate(db, testMigrationFS); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)` | ||||
| 	var exists bool | ||||
|  | ||||
| 	if err = db.QueryRow(query, 1).Scan(&exists); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if exists { | ||||
| 		t.Fatal("1 shouldn't exist") | ||||
| 	} | ||||
|  | ||||
| 	if err = db.QueryRow(query, 2).Scan(&exists); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if !exists { | ||||
| 		t.Fatal("2 should exist") | ||||
| 	} | ||||
|  | ||||
| } | ||||
							
								
								
									
										9
									
								
								pgutil/test-migrations/000.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								pgutil/test-migrations/000.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | ||||
| CREATE TABLE users( | ||||
|        UserID BIGINT NOT NULL PRIMARY KEY, | ||||
|        Email  TEXT   NOT NULL UNIQUE); | ||||
|  | ||||
| CREATE TABLE user_notes( | ||||
|        UserID BIGINT NOT NULL REFERENCES users(UserID), | ||||
|        NoteID BIGINT NOT NULL, | ||||
|        Note   Text   NOT NULL, | ||||
|        PRIMARY KEY(UserID,NoteID)); | ||||
							
								
								
									
										1
									
								
								pgutil/test-migrations/001.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								pgutil/test-migrations/001.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com'); | ||||
							
								
								
									
										1
									
								
								pgutil/test-migrations/002.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								pgutil/test-migrations/002.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| DELETE FROM users WHERE UserID=1; | ||||
							
								
								
									
										70
									
								
								pgutil/tx.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										70
									
								
								pgutil/tx.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,70 @@ | ||||
| package pgutil | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"math/rand" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // Postgres doesn't use serializable transactions by default. This wrapper will | ||||
| // run the enclosed function within a serializable. Note: this may return an | ||||
| // retriable serialization error (see ErrIsSerializationFaiilure). | ||||
| func WithTx(db *sql.DB, fn func(*sql.Tx) error) error { | ||||
| 	return WithTxDefault(db, func(tx *sql.Tx) error { | ||||
| 		if _, err := tx.Exec("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE"); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		return fn(tx) | ||||
| 	}) | ||||
| } | ||||
|  | ||||
| // This is a convenience function to provide a transaction wrapper with the | ||||
| // default isolation level. | ||||
| func WithTxDefault(db *sql.DB, fn func(*sql.Tx) error) error { | ||||
| 	// Start a transaction. | ||||
| 	tx, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = fn(tx) | ||||
|  | ||||
| 	if err == nil { | ||||
| 		err = tx.Commit() | ||||
| 	} | ||||
|  | ||||
| 	if err != nil { | ||||
| 		_ = tx.Rollback() | ||||
| 	} | ||||
|  | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // SerialTxRunner attempts serializable transactions in a loop. If a | ||||
| // transaction fails due to a serialization error, then the runner will retry | ||||
| // with exponential backoff, until the sleep time reaches MaxTimeout. | ||||
| // | ||||
| // For example, if MinTimeout is 100 ms, and MaxTimeout is 800 ms, it may sleep | ||||
| // for ~100, 200, 400, and 800 ms between retries. | ||||
| // | ||||
| // 10% jitter is added to the sleep time. | ||||
| type SerialTxRunner struct { | ||||
| 	MinTimeout time.Duration | ||||
| 	MaxTimeout time.Duration | ||||
| } | ||||
|  | ||||
| func (r SerialTxRunner) WithTx(db *sql.DB, fn func(*sql.Tx) error) error { | ||||
| 	timeout := r.MinTimeout | ||||
| 	for { | ||||
| 		err := WithTx(db, fn) | ||||
| 		if err == nil { | ||||
| 			return nil | ||||
| 		} | ||||
| 		if timeout > r.MaxTimeout || !ErrIsSerializationFaiilure(err) { | ||||
| 			return err | ||||
| 		} | ||||
| 		sleepTimeout := timeout + time.Duration(rand.Int63n(int64(timeout/10))) | ||||
| 		time.Sleep(sleepTimeout) | ||||
| 		timeout *= 2 | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										120
									
								
								pgutil/tx_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										120
									
								
								pgutil/tx_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,120 @@ | ||||
| package pgutil | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| // TestExecuteTx verifies transaction retry using the classic | ||||
| // example of write skew in bank account balance transfers. | ||||
| func TestWithTx(t *testing.T) { | ||||
| 	db, err := sql.Open( | ||||
| 		"postgres", | ||||
| 		"host=127.0.0.1 dbname=test sslmode=disable user=test password=test") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if err := DropAllTables(db); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	defer db.Close() | ||||
|  | ||||
| 	initStmt := ` | ||||
| CREATE TABLE t (acct INT PRIMARY KEY, balance INT); | ||||
| INSERT INTO t (acct, balance) VALUES (1, 100), (2, 100); | ||||
| ` | ||||
| 	if _, err := db.Exec(initStmt); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	type queryI interface { | ||||
| 		Query(string, ...interface{}) (*sql.Rows, error) | ||||
| 	} | ||||
|  | ||||
| 	getBalances := func(q queryI) (bal1, bal2 int, err error) { | ||||
| 		var rows *sql.Rows | ||||
| 		rows, err = q.Query(`SELECT balance FROM t WHERE acct IN (1, 2);`) | ||||
| 		if err != nil { | ||||
| 			return | ||||
| 		} | ||||
| 		defer rows.Close() | ||||
| 		balances := []*int{&bal1, &bal2} | ||||
| 		i := 0 | ||||
| 		for ; rows.Next(); i++ { | ||||
| 			if err = rows.Scan(balances[i]); err != nil { | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 		if i != 2 { | ||||
| 			err = fmt.Errorf("expected two balances; got %d", i) | ||||
| 			return | ||||
| 		} | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	txRunner := SerialTxRunner{100 * time.Millisecond, 800 * time.Millisecond} | ||||
|  | ||||
| 	runTxn := func(wg *sync.WaitGroup, iter *int) <-chan error { | ||||
| 		errCh := make(chan error, 1) | ||||
| 		go func() { | ||||
| 			*iter = 0 | ||||
| 			errCh <- txRunner.WithTx(db, func(tx *sql.Tx) error { | ||||
| 				*iter++ | ||||
| 				bal1, bal2, err := getBalances(tx) | ||||
| 				if err != nil { | ||||
| 					return err | ||||
| 				} | ||||
| 				// If this is the first iteration, wait for the other tx to | ||||
| 				// also read. | ||||
| 				if *iter == 1 { | ||||
| 					wg.Done() | ||||
| 					wg.Wait() | ||||
| 				} | ||||
| 				// Now, subtract from one account and give to the other. | ||||
| 				if bal1 > bal2 { | ||||
| 					if _, err := tx.Exec(` | ||||
| UPDATE t SET balance=balance-100 WHERE acct=1; | ||||
| UPDATE t SET balance=balance+100 WHERE acct=2; | ||||
| `); err != nil { | ||||
| 						return err | ||||
| 					} | ||||
| 				} else { | ||||
| 					if _, err := tx.Exec(` | ||||
| UPDATE t SET balance=balance+100 WHERE acct=1; | ||||
| UPDATE t SET balance=balance-100 WHERE acct=2; | ||||
| `); err != nil { | ||||
| 						return err | ||||
| 					} | ||||
| 				} | ||||
| 				return nil | ||||
| 			}) | ||||
| 		}() | ||||
| 		return errCh | ||||
| 	} | ||||
|  | ||||
| 	var wg sync.WaitGroup | ||||
| 	wg.Add(2) | ||||
| 	var iters1, iters2 int | ||||
| 	txn1Err := runTxn(&wg, &iters1) | ||||
| 	txn2Err := runTxn(&wg, &iters2) | ||||
| 	if err := <-txn1Err; err != nil { | ||||
| 		t.Errorf("expected success in txn1; got %s", err) | ||||
| 	} | ||||
| 	if err := <-txn2Err; err != nil { | ||||
| 		t.Errorf("expected success in txn2; got %s", err) | ||||
| 	} | ||||
| 	if iters1+iters2 <= 2 { | ||||
| 		t.Errorf("expected retries between the competing transactions; "+ | ||||
| 			"got txn1=%d, txn2=%d", iters1, iters2) | ||||
| 	} | ||||
| 	bal1, bal2, err := getBalances(db) | ||||
| 	if err != nil || bal1 != 100 || bal2 != 100 { | ||||
| 		t.Errorf("expected balances to be restored without error; "+ | ||||
| 			"got acct1=%d, acct2=%d: %s", bal1, bal2, err) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										3
									
								
								ratelimiter/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								ratelimiter/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| module git.crumpington.com/lib/ratelimiter | ||||
|  | ||||
| go 1.23.2 | ||||
							
								
								
									
										86
									
								
								ratelimiter/ratelimiter.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										86
									
								
								ratelimiter/ratelimiter.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,86 @@ | ||||
| package ratelimiter | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var ErrBackoff = errors.New("Backoff") | ||||
|  | ||||
| type Config struct { | ||||
| 	BurstLimit   int64         // Number of requests to allow to burst. | ||||
| 	FillPeriod   time.Duration // Add one per period. | ||||
| 	MaxWaitCount int64         // Max number of waiting requests. 0 disables. | ||||
| } | ||||
|  | ||||
| type Limiter struct { | ||||
| 	lock sync.Mutex | ||||
|  | ||||
| 	fillPeriod  time.Duration | ||||
| 	minWaitTime time.Duration | ||||
| 	maxWaitTime time.Duration | ||||
|  | ||||
| 	waitTime    time.Duration // If waitTime < 0, no waiting occurs. | ||||
| 	lastRequest time.Time | ||||
| } | ||||
|  | ||||
| func New(conf Config) *Limiter { | ||||
| 	if conf.BurstLimit < 0 { | ||||
| 		panic(conf.BurstLimit) | ||||
| 	} | ||||
| 	if conf.FillPeriod <= 0 { | ||||
| 		panic(conf.FillPeriod) | ||||
| 	} | ||||
| 	if conf.MaxWaitCount < 0 { | ||||
| 		panic(conf.MaxWaitCount) | ||||
| 	} | ||||
|  | ||||
| 	lim := &Limiter{ | ||||
| 		lastRequest: time.Now(), | ||||
| 		fillPeriod:  conf.FillPeriod, | ||||
| 		waitTime:    -conf.FillPeriod * time.Duration(conf.BurstLimit), | ||||
| 		minWaitTime: -conf.FillPeriod * time.Duration(conf.BurstLimit), | ||||
| 		maxWaitTime: conf.FillPeriod * time.Duration(conf.MaxWaitCount), | ||||
| 	} | ||||
|  | ||||
| 	lim.waitTime = lim.minWaitTime | ||||
|  | ||||
| 	return lim | ||||
| } | ||||
|  | ||||
| func (lim *Limiter) limit(count int64) (time.Duration, error) { | ||||
| 	lim.lock.Lock() | ||||
| 	defer lim.lock.Unlock() | ||||
|  | ||||
| 	dt := time.Since(lim.lastRequest) | ||||
|  | ||||
| 	waitTime := lim.waitTime - dt + time.Duration(count)*lim.fillPeriod | ||||
|  | ||||
| 	if waitTime < lim.minWaitTime { | ||||
| 		waitTime = lim.minWaitTime | ||||
| 	} else if waitTime > lim.maxWaitTime { | ||||
| 		return 0, ErrBackoff | ||||
| 	} | ||||
|  | ||||
| 	lim.waitTime = waitTime | ||||
| 	lim.lastRequest = lim.lastRequest.Add(dt) | ||||
|  | ||||
| 	return lim.waitTime, nil | ||||
| } | ||||
|  | ||||
| // Apply the limiter to the calling thread. The function may sleep for up to | ||||
| // maxWaitTime before returning. If the timeout would need to be more than | ||||
| // maxWaitTime to enforce the rate limit, ErrBackoff is returned. | ||||
| func (lim *Limiter) Limit() error { | ||||
| 	dt, err := lim.limit(1) | ||||
| 	time.Sleep(dt) // Will return immediately for dt <= 0. | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| // Apply the limiter for multiple items at once. | ||||
| func (lim *Limiter) LimitMultiple(count int64) error { | ||||
| 	dt, err := lim.limit(count) | ||||
| 	time.Sleep(dt) | ||||
| 	return err | ||||
| } | ||||
							
								
								
									
										101
									
								
								ratelimiter/ratelimiter_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										101
									
								
								ratelimiter/ratelimiter_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,101 @@ | ||||
| package ratelimiter | ||||
|  | ||||
| import ( | ||||
| 	"sync" | ||||
| 	"testing" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| func TestRateLimiter_Limit_Errors(t *testing.T) { | ||||
| 	type TestCase struct { | ||||
| 		Name     string | ||||
| 		Conf     Config | ||||
| 		N        int | ||||
| 		ErrCount int | ||||
| 		DT       time.Duration | ||||
| 	} | ||||
|  | ||||
| 	cases := []TestCase{ | ||||
| 		{ | ||||
| 			Name: "no burst, no wait", | ||||
| 			Conf: Config{ | ||||
| 				BurstLimit:   0, | ||||
| 				FillPeriod:   100 * time.Millisecond, | ||||
| 				MaxWaitCount: 0, | ||||
| 			}, | ||||
| 			N:        32, | ||||
| 			ErrCount: 32, | ||||
| 			DT:       0, | ||||
| 		}, { | ||||
| 			Name: "no wait", | ||||
| 			Conf: Config{ | ||||
| 				BurstLimit:   10, | ||||
| 				FillPeriod:   100 * time.Millisecond, | ||||
| 				MaxWaitCount: 0, | ||||
| 			}, | ||||
| 			N:        32, | ||||
| 			ErrCount: 22, | ||||
| 			DT:       0, | ||||
| 		}, { | ||||
| 			Name: "no burst", | ||||
| 			Conf: Config{ | ||||
| 				BurstLimit:   0, | ||||
| 				FillPeriod:   10 * time.Millisecond, | ||||
| 				MaxWaitCount: 10, | ||||
| 			}, | ||||
| 			N:        32, | ||||
| 			ErrCount: 22, | ||||
| 			DT:       100 * time.Millisecond, | ||||
| 		}, { | ||||
| 			Name: "burst and wait", | ||||
| 			Conf: Config{ | ||||
| 				BurstLimit:   10, | ||||
| 				FillPeriod:   10 * time.Millisecond, | ||||
| 				MaxWaitCount: 10, | ||||
| 			}, | ||||
| 			N:        32, | ||||
| 			ErrCount: 12, | ||||
| 			DT:       100 * time.Millisecond, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		wg := sync.WaitGroup{} | ||||
| 		l := New(tc.Conf) | ||||
| 		errs := make([]error, tc.N) | ||||
|  | ||||
| 		t0 := time.Now() | ||||
|  | ||||
| 		for i := 0; i < tc.N; i++ { | ||||
| 			wg.Add(1) | ||||
| 			go func(i int) { | ||||
| 				errs[i] = l.Limit() | ||||
| 				wg.Done() | ||||
| 			}(i) | ||||
| 		} | ||||
|  | ||||
| 		wg.Wait() | ||||
|  | ||||
| 		dt := time.Since(t0) | ||||
|  | ||||
| 		errCount := 0 | ||||
| 		for _, err := range errs { | ||||
| 			if err != nil { | ||||
| 				errCount++ | ||||
| 			} | ||||
| 		} | ||||
|  | ||||
| 		if errCount != tc.ErrCount { | ||||
| 			t.Fatalf("%s: Expected %d errors but got %d.", | ||||
| 				tc.Name, tc.ErrCount, errCount) | ||||
| 		} | ||||
|  | ||||
| 		if dt < tc.DT { | ||||
| 			t.Fatal(tc.Name, dt, tc.DT) | ||||
| 		} | ||||
|  | ||||
| 		if dt > tc.DT+10*time.Millisecond { | ||||
| 			t.Fatal(tc.Name, dt, tc.DT) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										22
									
								
								sqlgen/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								sqlgen/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| # sqlgen | ||||
|  | ||||
| ## Installing | ||||
|  | ||||
| ``` | ||||
| go install git.crumpington.com/lib/sqlgen/cmd/sqlgen@latest | ||||
| ``` | ||||
|  | ||||
| ## Usage | ||||
|  | ||||
| ``` | ||||
| sqlgen [driver] [defs-path] [output-path] | ||||
| ``` | ||||
|  | ||||
| ## File Format | ||||
|  | ||||
| ``` | ||||
| TABLE [sql-name] OF [go-type] <NoInsert> <NoUpdate> <NoDelete> ( | ||||
|     [sql-column] [go-type] <AS go-name> <PK> <NoInsert> <NoUpdate>, | ||||
|     ... | ||||
| ); | ||||
| ``` | ||||
							
								
								
									
										7
									
								
								sqlgen/cmd/sqlgen/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								sqlgen/cmd/sqlgen/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| package main | ||||
|  | ||||
| import "git.crumpington.com/lib/sqlgen" | ||||
|  | ||||
| func main() { | ||||
| 	sqlgen.Main() | ||||
| } | ||||
							
								
								
									
										3
									
								
								sqlgen/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								sqlgen/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| module git.crumpington.com/lib/sqlgen | ||||
|  | ||||
| go 1.23.2 | ||||
							
								
								
									
										43
									
								
								sqlgen/main.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								sqlgen/main.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,43 @@ | ||||
| package sqlgen | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"os" | ||||
| ) | ||||
|  | ||||
| func Main() { | ||||
| 	usage := func() { | ||||
| 		fmt.Fprintf(os.Stderr, ` | ||||
| %s DRIVER DEFS_PATH OUTPUT_PATH | ||||
|  | ||||
| Drivers are one of: sqlite, postgres | ||||
| `, | ||||
| 			os.Args[0]) | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
|  | ||||
| 	if len(os.Args) != 4 { | ||||
| 		usage() | ||||
| 	} | ||||
|  | ||||
| 	var ( | ||||
| 		template   string | ||||
| 		driver     = os.Args[1] | ||||
| 		defsPath   = os.Args[2] | ||||
| 		outputPath = os.Args[3] | ||||
| 	) | ||||
|  | ||||
| 	switch driver { | ||||
| 	case "sqlite": | ||||
| 		template = sqliteTemplate | ||||
| 	default: | ||||
| 		fmt.Fprintf(os.Stderr, "Unknown driver: %s", driver) | ||||
| 		usage() | ||||
| 	} | ||||
|  | ||||
| 	err := render(template, defsPath, outputPath) | ||||
| 	if err != nil { | ||||
| 		fmt.Fprintf(os.Stderr, "Error: %v", err) | ||||
| 		os.Exit(1) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										143
									
								
								sqlgen/parse.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										143
									
								
								sqlgen/parse.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,143 @@ | ||||
| package sqlgen | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"os" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| func parsePath(filePath string) (*schema, error) { | ||||
| 	fileBytes, err := os.ReadFile(filePath) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return parseBytes(fileBytes) | ||||
| } | ||||
|  | ||||
| func parseBytes(fileBytes []byte) (*schema, error) { | ||||
| 	s := string(fileBytes) | ||||
| 	for _, c := range []string{",", "(", ")", ";"} { | ||||
| 		s = strings.ReplaceAll(s, c, " "+c+" ") | ||||
| 	} | ||||
|  | ||||
| 	var ( | ||||
| 		tokens = strings.Fields(s) | ||||
| 		schema = &schema{} | ||||
| 		err    error | ||||
| 	) | ||||
|  | ||||
| 	for len(tokens) > 0 { | ||||
| 		switch tokens[0] { | ||||
| 		case "TABLE": | ||||
| 			tokens, err = parseTable(schema, tokens) | ||||
| 			if err != nil { | ||||
| 				return nil, err | ||||
| 			} | ||||
|  | ||||
| 		default: | ||||
| 			return nil, errors.New("invalid token: " + tokens[0]) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return schema, nil | ||||
| } | ||||
|  | ||||
| func parseTable(schema *schema, tokens []string) ([]string, error) { | ||||
| 	tokens = tokens[1:] | ||||
| 	if len(tokens) < 3 { | ||||
| 		return tokens, errors.New("incomplete table definition") | ||||
| 	} | ||||
| 	if tokens[1] != "OF" { | ||||
| 		return tokens, errors.New("expected OF in table definition") | ||||
| 	} | ||||
|  | ||||
| 	table := &table{ | ||||
| 		Name: tokens[0], | ||||
| 		Type: tokens[2], | ||||
| 	} | ||||
| 	schema.Tables = append(schema.Tables, table) | ||||
|  | ||||
| 	tokens = tokens[3:] | ||||
|  | ||||
| 	if len(tokens) == 0 { | ||||
| 		return tokens, errors.New("missing table definition body") | ||||
| 	} | ||||
|  | ||||
| 	for len(tokens) > 0 { | ||||
| 		switch tokens[0] { | ||||
| 		case "NoInsert": | ||||
| 			table.NoInsert = true | ||||
| 			tokens = tokens[1:] | ||||
| 		case "NoUpdate": | ||||
| 			table.NoUpdate = true | ||||
| 			tokens = tokens[1:] | ||||
| 		case "NoDelete": | ||||
| 			table.NoDelete = true | ||||
| 			tokens = tokens[1:] | ||||
| 		case "(": | ||||
| 			return parseTableBody(table, tokens[1:]) | ||||
| 		default: | ||||
| 			return tokens, errors.New("unexpected token in table definition: " + tokens[0]) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return tokens, errors.New("incomplete table definition") | ||||
| } | ||||
|  | ||||
| func parseTableBody(table *table, tokens []string) ([]string, error) { | ||||
| 	var err error | ||||
|  | ||||
| 	for len(tokens) > 0 && tokens[0] != ";" { | ||||
| 		tokens, err = parseTableColumn(table, tokens) | ||||
| 		if err != nil { | ||||
| 			return tokens, err | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if len(tokens) < 1 || tokens[0] != ";" { | ||||
| 		return tokens, errors.New("incomplete table column definitions") | ||||
| 	} | ||||
|  | ||||
| 	return tokens[1:], nil | ||||
| } | ||||
|  | ||||
| func parseTableColumn(table *table, tokens []string) ([]string, error) { | ||||
| 	if len(tokens) < 2 { | ||||
| 		return tokens, errors.New("incomplete column definition") | ||||
| 	} | ||||
| 	column := &column{ | ||||
| 		Name:    tokens[0], | ||||
| 		Type:    tokens[1], | ||||
| 		SqlName: tokens[0], | ||||
| 	} | ||||
| 	table.Columns = append(table.Columns, column) | ||||
|  | ||||
| 	tokens = tokens[2:] | ||||
| 	for len(tokens) > 0 && tokens[0] != "," && tokens[0] != ")" { | ||||
| 		switch tokens[0] { | ||||
| 		case "AS": | ||||
| 			if len(tokens) < 2 { | ||||
| 				return tokens, errors.New("incomplete AS clause in column definition") | ||||
| 			} | ||||
| 			column.Name = tokens[1] | ||||
| 			tokens = tokens[2:] | ||||
| 		case "PK": | ||||
| 			column.PK = true | ||||
| 			tokens = tokens[1:] | ||||
| 		case "NoInsert": | ||||
| 			column.NoInsert = true | ||||
| 			tokens = tokens[1:] | ||||
| 		case "NoUpdate": | ||||
| 			column.NoUpdate = true | ||||
| 			tokens = tokens[1:] | ||||
| 		default: | ||||
| 			return tokens, errors.New("unexpected token in column definition: " + tokens[0]) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if len(tokens) == 0 { | ||||
| 		return tokens, errors.New("incomplete column definition") | ||||
| 	} | ||||
|  | ||||
| 	return tokens[1:], nil | ||||
| } | ||||
							
								
								
									
										45
									
								
								sqlgen/parse_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								sqlgen/parse_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | ||||
| package sqlgen | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"os" | ||||
| 	"path/filepath" | ||||
| 	"reflect" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestParse(t *testing.T) { | ||||
| 	toString := func(v any) string { | ||||
| 		txt, _ := json.MarshalIndent(v, "", "  ") | ||||
| 		return string(txt) | ||||
| 	} | ||||
|  | ||||
| 	paths, err := filepath.Glob("test-files/TestParse/*.def") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	for _, defPath := range paths { | ||||
| 		t.Run(filepath.Base(defPath), func(t *testing.T) { | ||||
| 			parsed, err := parsePath(defPath) | ||||
| 			if err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
|  | ||||
| 			b, err := os.ReadFile(strings.TrimSuffix(defPath, "def") + "json") | ||||
| 			if err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
|  | ||||
| 			expected := &schema{} | ||||
| 			if err := json.Unmarshal(b, expected); err != nil { | ||||
| 				t.Fatal(err) | ||||
| 			} | ||||
|  | ||||
| 			if !reflect.DeepEqual(parsed, expected) { | ||||
| 				t.Fatalf("%s != %s", toString(parsed), toString(expected)) | ||||
| 			} | ||||
| 		}) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										263
									
								
								sqlgen/schema.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										263
									
								
								sqlgen/schema.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,263 @@ | ||||
| package sqlgen | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type schema struct { | ||||
| 	Tables []*table | ||||
| } | ||||
|  | ||||
| type table struct { | ||||
| 	Name     string // Name in SQL | ||||
| 	Type     string // Go type | ||||
| 	NoInsert bool | ||||
| 	NoUpdate bool | ||||
| 	NoDelete bool | ||||
| 	Columns  []*column | ||||
| } | ||||
|  | ||||
| type column struct { | ||||
| 	Name     string | ||||
| 	Type     string | ||||
| 	SqlName  string // Defaults to Name | ||||
| 	PK       bool   // PK won't be updated | ||||
| 	NoInsert bool | ||||
| 	NoUpdate bool // Don't update column in update function | ||||
| } | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| func (t *table) colSQLNames() []string { | ||||
| 	names := make([]string, len(t.Columns)) | ||||
| 	for i := range names { | ||||
| 		names[i] = t.Columns[i].SqlName | ||||
| 	} | ||||
| 	return names | ||||
| } | ||||
|  | ||||
| func (t *table) SelectQuery() string { | ||||
| 	return fmt.Sprintf(`SELECT %s FROM %s`, | ||||
| 		strings.Join(t.colSQLNames(), ","), | ||||
| 		t.Name) | ||||
| } | ||||
|  | ||||
| func (t *table) insertCols() (cols []*column) { | ||||
| 	for _, c := range t.Columns { | ||||
| 		if !c.NoInsert { | ||||
| 			cols = append(cols, c) | ||||
| 		} | ||||
| 	} | ||||
| 	return cols | ||||
| } | ||||
|  | ||||
| func (t *table) InsertQuery() string { | ||||
| 	cols := t.insertCols() | ||||
| 	b := &strings.Builder{} | ||||
| 	b.WriteString(`INSERT INTO `) | ||||
| 	b.WriteString(t.Name) | ||||
| 	b.WriteString(`(`) | ||||
| 	for i, c := range cols { | ||||
| 		if i != 0 { | ||||
| 			b.WriteString(`,`) | ||||
| 		} | ||||
| 		b.WriteString(c.SqlName) | ||||
| 	} | ||||
| 	b.WriteString(`) VALUES(`) | ||||
| 	for i := range cols { | ||||
| 		if i != 0 { | ||||
| 			b.WriteString(`,`) | ||||
| 		} | ||||
| 		b.WriteString(`?`) | ||||
| 	} | ||||
| 	b.WriteString(`)`) | ||||
| 	return b.String() | ||||
| } | ||||
|  | ||||
| func (t *table) InsertArgs() string { | ||||
| 	args := []string{} | ||||
| 	for i, col := range t.Columns { | ||||
| 		if !col.NoInsert { | ||||
| 			args = append(args, "row."+t.Columns[i].Name) | ||||
| 		} | ||||
| 	} | ||||
| 	return strings.Join(args, ", ") | ||||
| } | ||||
|  | ||||
| func (t *table) UpdateCols() (cols []*column) { | ||||
| 	for _, c := range t.Columns { | ||||
| 		if !(c.PK || c.NoUpdate) { | ||||
| 			cols = append(cols, c) | ||||
| 		} | ||||
| 	} | ||||
| 	return cols | ||||
| } | ||||
|  | ||||
| func (t *table) UpdateQuery() string { | ||||
| 	cols := t.UpdateCols() | ||||
|  | ||||
| 	b := &strings.Builder{} | ||||
| 	b.WriteString(`UPDATE `) | ||||
| 	b.WriteString(t.Name + ` SET `) | ||||
| 	for i, col := range cols { | ||||
| 		if i != 0 { | ||||
| 			b.WriteByte(',') | ||||
| 		} | ||||
| 		b.WriteString(col.SqlName + `=?`) | ||||
| 	} | ||||
|  | ||||
| 	b.WriteString(` WHERE`) | ||||
|  | ||||
| 	for i, c := range t.pkCols() { | ||||
| 		if i != 0 { | ||||
| 			b.WriteString(` AND`) | ||||
| 		} | ||||
| 		b.WriteString(` ` + c.SqlName + `=?`) | ||||
| 	} | ||||
|  | ||||
| 	return b.String() | ||||
| } | ||||
|  | ||||
| func (t *table) UpdateArgs() string { | ||||
| 	cols := t.UpdateCols() | ||||
|  | ||||
| 	b := &strings.Builder{} | ||||
| 	for i, col := range cols { | ||||
| 		if i != 0 { | ||||
| 			b.WriteString(`, `) | ||||
| 		} | ||||
| 		b.WriteString("row." + col.Name) | ||||
| 	} | ||||
| 	for _, col := range t.pkCols() { | ||||
| 		b.WriteString(", row." + col.Name) | ||||
| 	} | ||||
| 	return b.String() | ||||
| } | ||||
|  | ||||
| func (t *table) UpdateFullCols() (cols []*column) { | ||||
| 	for _, c := range t.Columns { | ||||
| 		if !c.PK { | ||||
| 			cols = append(cols, c) | ||||
| 		} | ||||
| 	} | ||||
| 	return cols | ||||
| } | ||||
|  | ||||
| func (t *table) UpdateFullQuery() string { | ||||
| 	cols := t.UpdateFullCols() | ||||
|  | ||||
| 	b := &strings.Builder{} | ||||
| 	b.WriteString(`UPDATE `) | ||||
| 	b.WriteString(t.Name + ` SET `) | ||||
| 	for i, col := range cols { | ||||
| 		if i != 0 { | ||||
| 			b.WriteByte(',') | ||||
| 		} | ||||
| 		b.WriteString(col.SqlName + `=?`) | ||||
| 	} | ||||
|  | ||||
| 	b.WriteString(` WHERE`) | ||||
|  | ||||
| 	for i, c := range t.pkCols() { | ||||
| 		if i != 0 { | ||||
| 			b.WriteString(` AND`) | ||||
| 		} | ||||
| 		b.WriteString(` ` + c.SqlName + `=?`) | ||||
| 	} | ||||
|  | ||||
| 	return b.String() | ||||
| } | ||||
|  | ||||
| func (t *table) UpdateFullArgs() string { | ||||
| 	cols := t.UpdateFullCols() | ||||
|  | ||||
| 	b := &strings.Builder{} | ||||
| 	for i, col := range cols { | ||||
| 		if i != 0 { | ||||
| 			b.WriteString(`, `) | ||||
| 		} | ||||
| 		b.WriteString("row." + col.Name) | ||||
| 	} | ||||
|  | ||||
| 	for _, col := range t.pkCols() { | ||||
| 		b.WriteString(", row." + col.Name) | ||||
| 	} | ||||
|  | ||||
| 	return b.String() | ||||
| } | ||||
|  | ||||
| func (t *table) pkCols() (cols []*column) { | ||||
| 	for _, c := range t.Columns { | ||||
| 		if c.PK { | ||||
| 			cols = append(cols, c) | ||||
| 		} | ||||
| 	} | ||||
| 	return cols | ||||
| } | ||||
|  | ||||
| func (t *table) PKFunctionArgs() string { | ||||
| 	b := &strings.Builder{} | ||||
| 	for _, col := range t.pkCols() { | ||||
| 		b.WriteString(col.Name) | ||||
| 		b.WriteString(` `) | ||||
| 		b.WriteString(col.Type) | ||||
| 		b.WriteString(",\n") | ||||
| 	} | ||||
| 	return b.String() | ||||
| } | ||||
|  | ||||
| func (t *table) DeleteQuery() string { | ||||
| 	cols := t.pkCols() | ||||
|  | ||||
| 	b := &strings.Builder{} | ||||
| 	b.WriteString(`DELETE FROM `) | ||||
| 	b.WriteString(t.Name) | ||||
| 	b.WriteString(` WHERE `) | ||||
|  | ||||
| 	for i, col := range cols { | ||||
| 		if i != 0 { | ||||
| 			b.WriteString(` AND `) | ||||
| 		} | ||||
| 		b.WriteString(col.SqlName) | ||||
| 		b.WriteString(`=?`) | ||||
| 	} | ||||
| 	return b.String() | ||||
| } | ||||
|  | ||||
| func (t *table) DeleteArgs() string { | ||||
| 	cols := t.pkCols() | ||||
| 	b := &strings.Builder{} | ||||
|  | ||||
| 	for i, col := range cols { | ||||
| 		if i != 0 { | ||||
| 			b.WriteString(`,`) | ||||
| 		} | ||||
| 		b.WriteString(col.Name) | ||||
| 	} | ||||
| 	return b.String() | ||||
| } | ||||
|  | ||||
| func (t *table) GetQuery() string { | ||||
| 	b := &strings.Builder{} | ||||
| 	b.WriteString(t.SelectQuery()) | ||||
| 	b.WriteString(` WHERE `) | ||||
| 	for i, col := range t.pkCols() { | ||||
| 		if i != 0 { | ||||
| 			b.WriteString(` AND `) | ||||
| 		} | ||||
| 		b.WriteString(col.SqlName + `=?`) | ||||
| 	} | ||||
| 	return b.String() | ||||
| } | ||||
|  | ||||
| func (t *table) ScanArgs() string { | ||||
| 	b := &strings.Builder{} | ||||
| 	for i, col := range t.Columns { | ||||
| 		if i != 0 { | ||||
| 			b.WriteString(`, `) | ||||
| 		} | ||||
| 		b.WriteString(`&row.` + col.Name) | ||||
| 	} | ||||
| 	return b.String() | ||||
| } | ||||
							
								
								
									
										206
									
								
								sqlgen/sqlite.go.tmpl
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										206
									
								
								sqlgen/sqlite.go.tmpl
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,206 @@ | ||||
| package {{.PackageName}} | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"iter" | ||||
| ) | ||||
|  | ||||
| type TX interface { | ||||
| 	Exec(query string, args ...any) (sql.Result, error) | ||||
| 	Query(query string, args ...any) (*sql.Rows, error) | ||||
| 	QueryRow(query string, args ...any) *sql.Row | ||||
| } | ||||
|  | ||||
| {{range .Schema.Tables}} | ||||
|  | ||||
| // ---------------------------------------------------------------------------- | ||||
| // Table: {{.Name}} | ||||
| // ---------------------------------------------------------------------------- | ||||
|  | ||||
| type {{.Type}} struct { | ||||
| 	{{- range .Columns}} | ||||
| 	{{.Name}} {{.Type}}{{end}} | ||||
| } | ||||
|  | ||||
| const {{.Type}}_SelectQuery = "{{.SelectQuery}}" | ||||
|  | ||||
| {{if not .NoInsert -}} | ||||
|  | ||||
| func {{.Type}}_Insert( | ||||
| 	tx TX, | ||||
| 	row *{{.Type}}, | ||||
| ) (err error) { | ||||
|   {{.Type}}_Sanitize(row) | ||||
|   if err = {{.Type}}_Validate(row); err != nil { | ||||
|     return err | ||||
|   } | ||||
|  | ||||
| 	_, err = tx.Exec("{{.InsertQuery}}", {{.InsertArgs}}) | ||||
| 	return err | ||||
| } | ||||
|  | ||||
| {{- end}} {{/* if not .NoInsert */}} | ||||
|  | ||||
| {{if not .NoUpdate -}} | ||||
|  | ||||
| {{if .UpdateCols -}} | ||||
|  | ||||
| func {{.Type}}_Update( | ||||
| 	tx TX, | ||||
| 	row *{{.Type}}, | ||||
| ) (found bool, err error) { | ||||
|   {{.Type}}_Sanitize(row) | ||||
|   if err = {{.Type}}_Validate(row); err != nil { | ||||
|     return false, err | ||||
|   } | ||||
|  | ||||
| 	result, err := tx.Exec("{{.UpdateQuery}}", {{.UpdateArgs}}) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
|  | ||||
| 	n, err := result.RowsAffected() | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
|  | ||||
| 	if n > 1 { | ||||
| 		panic("multiple rows updated") | ||||
| 	} | ||||
|  | ||||
| 	return n != 0, nil | ||||
| } | ||||
| {{- end}} | ||||
|  | ||||
| {{if .UpdateFullCols -}} | ||||
|  | ||||
| func {{.Type}}_UpdateFull( | ||||
| 	tx TX, | ||||
| 	row *{{.Type}}, | ||||
| ) (found bool, err error) { | ||||
|   {{.Type}}_Sanitize(row) | ||||
|   if err = {{.Type}}_Validate(row); err != nil { | ||||
|     return false, err | ||||
|   } | ||||
|  | ||||
| 	result, err := tx.Exec("{{.UpdateFullQuery}}", {{.UpdateFullArgs}}) | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
|  | ||||
| 	n, err := result.RowsAffected() | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
|  | ||||
| 	if n > 1 { | ||||
| 		panic("multiple rows updated") | ||||
| 	} | ||||
|  | ||||
| 	return n != 0, nil | ||||
| } | ||||
| {{- end}} | ||||
|  | ||||
|  | ||||
| {{- end}} {{/* if not .NoUpdate */}} | ||||
|  | ||||
| {{if not .NoDelete -}} | ||||
|  | ||||
| func {{.Type}}_Delete( | ||||
| 	tx TX, | ||||
| 	{{.PKFunctionArgs -}} | ||||
| ) (found bool, err error) { | ||||
| 	result, err := tx.Exec("{{.DeleteQuery}}", {{.DeleteArgs}}) | ||||
| 		if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
|  | ||||
| 	n, err := result.RowsAffected() | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
|  | ||||
| 	if n > 1 { | ||||
| 		panic("multiple rows deleted") | ||||
| 	} | ||||
|  | ||||
| 	return n != 0, nil | ||||
| } | ||||
|  | ||||
| {{- end}} | ||||
|  | ||||
| func {{.Type}}_Get( | ||||
| 	tx TX, | ||||
| 	{{.PKFunctionArgs -}} | ||||
| ) ( | ||||
| 	row *{{.Type}}, | ||||
| 	err error, | ||||
| ) { | ||||
|   row = &{{.Type}}{} | ||||
| 	r := tx.QueryRow("{{.GetQuery}}", {{.DeleteArgs}}) | ||||
| 	err = r.Scan({{.ScanArgs}}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
|  | ||||
| func {{.Type}}_GetWhere( | ||||
| 	tx TX, | ||||
| 	query string, | ||||
| 	args ...any, | ||||
| ) ( | ||||
| 	row *{{.Type}}, | ||||
| 	err error, | ||||
| ) { | ||||
|   row = &{{.Type}}{} | ||||
| 	r := tx.QueryRow(query, args...) | ||||
| 	err = r.Scan({{.ScanArgs}}) | ||||
| 	return | ||||
| } | ||||
|  | ||||
|  | ||||
|  | ||||
| func {{.Type}}_Iterate( | ||||
| 	tx TX, | ||||
| 	query string, | ||||
| 	args ...any, | ||||
| ) ( | ||||
| 	iter.Seq2[*{{.Type}}, error], | ||||
| ) { | ||||
| 	rows, err := tx.Query(query, args...) | ||||
| 	if err != nil { | ||||
| 		return func(yield func(*{{.Type}}, error) bool) { | ||||
| 			yield(nil, err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return func(yield func(*{{.Type}}, error) bool) { | ||||
| 		defer rows.Close() | ||||
| 		for rows.Next() { | ||||
| 			row := &{{.Type}}{} | ||||
| 			err := rows.Scan({{.ScanArgs}}) | ||||
| 			if !yield(row, err) { | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func {{.Type}}_List( | ||||
| 	tx TX, | ||||
| 	query string, | ||||
| 	args ...any, | ||||
| ) ( | ||||
| 	l []*{{.Type}}, | ||||
| 	err error, | ||||
| ) { | ||||
| 	for row, err := range {{.Type}}_Iterate(tx, query, args...) { | ||||
| 		if err != nil { | ||||
| 			return nil, err | ||||
| 		} | ||||
| 		l = append(l, row) | ||||
| 	} | ||||
| 	return l, nil | ||||
| } | ||||
|  | ||||
|  | ||||
| {{end}} {{/* range .Schema.Tables */}} | ||||
							
								
								
									
										36
									
								
								sqlgen/template.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								sqlgen/template.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,36 @@ | ||||
| package sqlgen | ||||
|  | ||||
| import ( | ||||
| 	_ "embed" | ||||
| 	"os" | ||||
| 	"os/exec" | ||||
| 	"path/filepath" | ||||
| 	"text/template" | ||||
| ) | ||||
|  | ||||
| //go:embed sqlite.go.tmpl | ||||
| var sqliteTemplate string | ||||
|  | ||||
| func render(templateStr, schemaPath, outputPath string) error { | ||||
| 	sch, err := parsePath(schemaPath) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	tmpl := template.Must(template.New("").Parse(templateStr)) | ||||
| 	fOut, err := os.Create(outputPath) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	defer fOut.Close() | ||||
|  | ||||
| 	err = tmpl.Execute(fOut, struct { | ||||
| 		PackageName string | ||||
| 		Schema      *schema | ||||
| 	}{filepath.Base(filepath.Dir(outputPath)), sch}) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	return exec.Command("gofmt", "-w", outputPath).Run() | ||||
| } | ||||
							
								
								
									
										3
									
								
								sqlgen/test-files/TestParse/000.def
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								sqlgen/test-files/TestParse/000.def
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| TABLE users OF User NoDelete ( | ||||
|   user_id string AS UserID PK | ||||
| ); | ||||
							
								
								
									
										17
									
								
								sqlgen/test-files/TestParse/000.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								sqlgen/test-files/TestParse/000.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,17 @@ | ||||
| { | ||||
|   "Tables": [ | ||||
|     { | ||||
|       "Name": "users", | ||||
|       "Type": "User", | ||||
|       "NoDelete": true, | ||||
|       "Columns": [ | ||||
|         { | ||||
|           "Name": "UserID", | ||||
|           "Type": "string", | ||||
|           "SqlName": "user_id", | ||||
|           "PK": true | ||||
|         } | ||||
|       ] | ||||
|     } | ||||
|   ] | ||||
| } | ||||
							
								
								
									
										4
									
								
								sqlgen/test-files/TestParse/001.def
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								sqlgen/test-files/TestParse/001.def
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,4 @@ | ||||
| TABLE users OF User NoDelete ( | ||||
|   user_id string AS UserID PK, | ||||
|   email   string AS Email NoUpdate | ||||
| ); | ||||
							
								
								
									
										22
									
								
								sqlgen/test-files/TestParse/001.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								sqlgen/test-files/TestParse/001.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| { | ||||
|   "Tables": [ | ||||
|     { | ||||
|       "Name": "users", | ||||
|       "Type": "User", | ||||
|       "NoDelete": true, | ||||
|       "Columns": [ | ||||
|         { | ||||
|           "Name": "UserID", | ||||
|           "Type": "string", | ||||
|           "SqlName": "user_id", | ||||
|           "PK": true | ||||
|         }, { | ||||
|           "Name": "Email", | ||||
|           "Type": "string", | ||||
|           "SqlName": "email", | ||||
|           "NoUpdate": true | ||||
|         } | ||||
|       ] | ||||
|     } | ||||
|   ] | ||||
| } | ||||
							
								
								
									
										6
									
								
								sqlgen/test-files/TestParse/002.def
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								sqlgen/test-files/TestParse/002.def
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| TABLE users OF User NoDelete ( | ||||
|   user_id string AS UserID PK, | ||||
|   email   string AS Email NoUpdate, | ||||
|   name    string AS Name NoInsert, | ||||
|   admin   bool   AS Admin NoInsert NoUpdate | ||||
| ); | ||||
							
								
								
									
										33
									
								
								sqlgen/test-files/TestParse/002.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								sqlgen/test-files/TestParse/002.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,33 @@ | ||||
| { | ||||
|   "Tables": [ | ||||
|     { | ||||
|       "Name": "users", | ||||
|       "Type": "User", | ||||
|       "NoDelete": true, | ||||
|       "Columns": [ | ||||
|         { | ||||
|           "Name": "UserID", | ||||
|           "Type": "string", | ||||
|           "SqlName": "user_id", | ||||
|           "PK": true | ||||
|         }, { | ||||
|           "Name": "Email", | ||||
|           "Type": "string", | ||||
|           "SqlName": "email", | ||||
|           "NoUpdate": true | ||||
|         }, { | ||||
|           "Name": "Name", | ||||
|           "Type": "string", | ||||
|           "SqlName": "name", | ||||
|           "NoInsert": true | ||||
|         }, { | ||||
|           "Name": "Admin", | ||||
|           "Type": "bool", | ||||
|           "SqlName": "admin", | ||||
|           "NoInsert": true, | ||||
|           "NoUpdate": true | ||||
|         } | ||||
|       ] | ||||
|     } | ||||
|   ] | ||||
| } | ||||
							
								
								
									
										12
									
								
								sqlgen/test-files/TestParse/003.def
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										12
									
								
								sqlgen/test-files/TestParse/003.def
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,12 @@ | ||||
| TABLE users OF User NoDelete ( | ||||
|   user_id string AS UserID PK, | ||||
|   email   string AS Email NoUpdate, | ||||
|   name    string AS Name NoInsert, | ||||
|   admin   bool   AS Admin NoInsert NoUpdate | ||||
| ); | ||||
|  | ||||
| TABLE users_view OF UserView NoInsert NoUpdate NoDelete ( | ||||
|   user_id string AS UserID PK, | ||||
|   email   string AS Email, | ||||
|   name    string AS Name | ||||
| ); | ||||
							
								
								
									
										61
									
								
								sqlgen/test-files/TestParse/003.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								sqlgen/test-files/TestParse/003.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,61 @@ | ||||
| { | ||||
|   "Tables": [ | ||||
|     { | ||||
|       "Name": "users", | ||||
|       "Type": "User", | ||||
|       "NoDelete": true, | ||||
|       "Columns": [ | ||||
|         { | ||||
|           "Name": "UserID", | ||||
|           "Type": "string", | ||||
|           "SqlName": "user_id", | ||||
|           "PK": true | ||||
|         }, | ||||
|         { | ||||
|           "Name": "Email", | ||||
|           "Type": "string", | ||||
|           "SqlName": "email", | ||||
|           "NoUpdate": true | ||||
|         }, | ||||
|         { | ||||
|           "Name": "Name", | ||||
|           "Type": "string", | ||||
|           "SqlName": "name", | ||||
|           "NoInsert": true | ||||
|         }, | ||||
|         { | ||||
|           "Name": "Admin", | ||||
|           "Type": "bool", | ||||
|           "SqlName": "admin", | ||||
|           "NoInsert": true, | ||||
|           "NoUpdate": true | ||||
|         } | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "Name": "users_view", | ||||
|       "Type": "UserView", | ||||
|       "NoInsert": true, | ||||
|       "NoUpdate": true, | ||||
|       "NoDelete": true, | ||||
|       "Columns": [ | ||||
|         { | ||||
|           "Name": "UserID", | ||||
|           "Type": "string", | ||||
|           "SqlName": "user_id", | ||||
|           "PK": true | ||||
|         }, | ||||
|         { | ||||
|           "Name": "Email", | ||||
|           "Type": "string", | ||||
|           "SqlName": "email" | ||||
|         }, | ||||
|         { | ||||
|           "Name": "Name", | ||||
|           "Type": "string", | ||||
|           "SqlName": "name" | ||||
|         } | ||||
|       ] | ||||
|     } | ||||
|   ] | ||||
| } | ||||
							
								
								
									
										13
									
								
								sqlgen/test-files/TestParse/004.def
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										13
									
								
								sqlgen/test-files/TestParse/004.def
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,13 @@ | ||||
| TABLE users OF User NoDelete ( | ||||
|   user_id string AS UserID PK, | ||||
|   email   string AS Email NoUpdate, | ||||
|   name    string AS Name NoInsert, | ||||
|   admin   bool   AS Admin NoInsert NoUpdate, | ||||
|   SSN     string NoUpdate | ||||
| ); | ||||
|  | ||||
| TABLE users_view OF UserView NoInsert NoUpdate NoDelete ( | ||||
|   user_id string AS UserID PK, | ||||
|   email   string AS Email, | ||||
|   name    string AS Name | ||||
| ); | ||||
							
								
								
									
										66
									
								
								sqlgen/test-files/TestParse/004.json
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										66
									
								
								sqlgen/test-files/TestParse/004.json
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,66 @@ | ||||
| { | ||||
|   "Tables": [ | ||||
|     { | ||||
|       "Name": "users", | ||||
|       "Type": "User", | ||||
|       "NoDelete": true, | ||||
|       "Columns": [ | ||||
|         { | ||||
|           "Name": "UserID", | ||||
|           "Type": "string", | ||||
|           "SqlName": "user_id", | ||||
|           "PK": true | ||||
|         }, | ||||
|         { | ||||
|           "Name": "Email", | ||||
|           "Type": "string", | ||||
|           "SqlName": "email", | ||||
|           "NoUpdate": true | ||||
|         }, | ||||
|         { | ||||
|           "Name": "Name", | ||||
|           "Type": "string", | ||||
|           "SqlName": "name", | ||||
|           "NoInsert": true | ||||
|         }, | ||||
|         { | ||||
|           "Name": "Admin", | ||||
|           "Type": "bool", | ||||
|           "SqlName": "admin", | ||||
|           "NoInsert": true, | ||||
|           "NoUpdate": true | ||||
|         }, { | ||||
|           "Name": "SSN", | ||||
|           "Type": "string", | ||||
|           "SqlName": "SSN", | ||||
|           "NoUpdate": true | ||||
|         } | ||||
|       ] | ||||
|     }, | ||||
|     { | ||||
|       "Name": "users_view", | ||||
|       "Type": "UserView", | ||||
|       "NoInsert": true, | ||||
|       "NoUpdate": true, | ||||
|       "NoDelete": true, | ||||
|       "Columns": [ | ||||
|         { | ||||
|           "Name": "UserID", | ||||
|           "Type": "string", | ||||
|           "SqlName": "user_id", | ||||
|           "PK": true | ||||
|         }, | ||||
|         { | ||||
|           "Name": "Email", | ||||
|           "Type": "string", | ||||
|           "SqlName": "email" | ||||
|         }, | ||||
|         { | ||||
|           "Name": "Name", | ||||
|           "Type": "string", | ||||
|           "SqlName": "name" | ||||
|         } | ||||
|       ] | ||||
|     } | ||||
|   ] | ||||
| } | ||||
							
								
								
									
										45
									
								
								sqliteutil/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										45
									
								
								sqliteutil/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,45 @@ | ||||
| # sqliteutil | ||||
|  | ||||
| ## Transactions | ||||
|  | ||||
| Simplify postgres transactions using `WithTx` for serializable transactions, | ||||
| or `WithTxDefault` for the default isolation level. Use the `SerialTxRunner` | ||||
| type to get automatic retries of serialization errors. | ||||
|  | ||||
| ## Migrations | ||||
|  | ||||
| Put your migrations into a directory, for example `migrations`, ordered by name | ||||
| (YYYY-MM-DD prefix, for example). Embed the directory and pass it to the | ||||
| `Migrate` function: | ||||
|  | ||||
| ```Go | ||||
| //go:embed migrations | ||||
| var migrations embed.FS | ||||
|  | ||||
| func init() { | ||||
|     Migrate(db, migrations) // Check the error, of course. | ||||
| } | ||||
| ``` | ||||
|  | ||||
| ## Testing | ||||
|  | ||||
| In order to test this packge, we need to create a test user and database: | ||||
|  | ||||
| ``` | ||||
| sudo su postgres | ||||
| psql | ||||
|  | ||||
| CREATE DATABASE test; | ||||
| CREATE USER test WITH ENCRYPTED PASSWORD 'test'; | ||||
| GRANT ALL PRIVILEGES ON DATABASE test TO test; | ||||
|  | ||||
| use test | ||||
|  | ||||
| GRANT ALL ON SCHEMA public TO test; | ||||
| ``` | ||||
|  | ||||
| Check that you can connect via the command line: | ||||
|  | ||||
| ``` | ||||
| psql -h 127.0.0.1 -U test --password test | ||||
| ``` | ||||
							
								
								
									
										5
									
								
								sqliteutil/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								sqliteutil/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| module git.crumpington.com/lib/sqliteutil | ||||
|  | ||||
| go 1.23.2 | ||||
|  | ||||
| require github.com/mattn/go-sqlite3 v1.14.24 // indirect | ||||
							
								
								
									
										2
									
								
								sqliteutil/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								sqliteutil/go.sum
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= | ||||
| github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= | ||||
							
								
								
									
										82
									
								
								sqliteutil/migrate.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										82
									
								
								sqliteutil/migrate.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,82 @@ | ||||
| package sqliteutil | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"embed" | ||||
| 	"errors" | ||||
| 	"fmt" | ||||
| 	"path/filepath" | ||||
| 	"sort" | ||||
| ) | ||||
|  | ||||
| const initMigrationTableQuery = ` | ||||
| CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);` | ||||
|  | ||||
| const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)` | ||||
|  | ||||
| const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)` | ||||
|  | ||||
| func Migrate(db *sql.DB, migrationFS embed.FS) error { | ||||
| 	return WithTx(db, func(tx *sql.Tx) error { | ||||
| 		if _, err := tx.Exec(initMigrationTableQuery); err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		dirs, err := migrationFS.ReadDir(".") | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		if len(dirs) != 1 { | ||||
| 			return errors.New("expected a single migrations directory") | ||||
| 		} | ||||
|  | ||||
| 		if !dirs[0].IsDir() { | ||||
| 			return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name()) | ||||
| 		} | ||||
|  | ||||
| 		dirName := dirs[0].Name() | ||||
| 		files, err := migrationFS.ReadDir(dirName) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
|  | ||||
| 		// Sort sql files by name. | ||||
| 		sort.Slice(files, func(i, j int) bool { | ||||
| 			return files[i].Name() < files[j].Name() | ||||
| 		}) | ||||
|  | ||||
| 		for _, dirEnt := range files { | ||||
| 			if !dirEnt.Type().IsRegular() { | ||||
| 				return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name()) | ||||
| 			} | ||||
|  | ||||
| 			var ( | ||||
| 				name   = dirEnt.Name() | ||||
| 				exists bool | ||||
| 			) | ||||
|  | ||||
| 			err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
|  | ||||
| 			if exists { | ||||
| 				continue | ||||
| 			} | ||||
|  | ||||
| 			migration, err := migrationFS.ReadFile(filepath.Join(dirName, name)) | ||||
| 			if err != nil { | ||||
| 				return err | ||||
| 			} | ||||
|  | ||||
| 			if _, err := tx.Exec(string(migration)); err != nil { | ||||
| 				return fmt.Errorf("migration %s failed: %v", name, err) | ||||
| 			} | ||||
|  | ||||
| 			if _, err := tx.Exec(insertMigrationQuery, name); err != nil { | ||||
| 				return err | ||||
| 			} | ||||
| 		} | ||||
| 		return nil | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										44
									
								
								sqliteutil/migrate_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										44
									
								
								sqliteutil/migrate_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,44 @@ | ||||
| package sqliteutil | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
| 	"embed" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| //go:embed test-migrations | ||||
| var testMigrationFS embed.FS | ||||
|  | ||||
| func TestMigrate(t *testing.T) { | ||||
| 	db, err := sql.Open("sqlite3", ":memory:") | ||||
| 	if err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	if err := Migrate(db, testMigrationFS); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	// Shouldn't have any effect. | ||||
| 	if err := Migrate(db, testMigrationFS); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
|  | ||||
| 	query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)` | ||||
| 	var exists bool | ||||
|  | ||||
| 	if err = db.QueryRow(query, 1).Scan(&exists); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if exists { | ||||
| 		t.Fatal("1 shouldn't exist") | ||||
| 	} | ||||
|  | ||||
| 	if err = db.QueryRow(query, 2).Scan(&exists); err != nil { | ||||
| 		t.Fatal(err) | ||||
| 	} | ||||
| 	if !exists { | ||||
| 		t.Fatal("2 should exist") | ||||
| 	} | ||||
|  | ||||
| } | ||||
							
								
								
									
										9
									
								
								sqliteutil/test-migrations/000.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										9
									
								
								sqliteutil/test-migrations/000.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,9 @@ | ||||
| CREATE TABLE users( | ||||
|        UserID BIGINT NOT NULL PRIMARY KEY, | ||||
|        Email  TEXT   NOT NULL UNIQUE); | ||||
|  | ||||
| CREATE TABLE user_notes( | ||||
|        UserID BIGINT NOT NULL REFERENCES users(UserID), | ||||
|        NoteID BIGINT NOT NULL, | ||||
|        Note   Text   NOT NULL, | ||||
|        PRIMARY KEY(UserID,NoteID)); | ||||
							
								
								
									
										1
									
								
								sqliteutil/test-migrations/001.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								sqliteutil/test-migrations/001.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com'); | ||||
							
								
								
									
										1
									
								
								sqliteutil/test-migrations/002.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								sqliteutil/test-migrations/002.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| DELETE FROM users WHERE UserID=1; | ||||
							
								
								
									
										28
									
								
								sqliteutil/tx.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										28
									
								
								sqliteutil/tx.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,28 @@ | ||||
| package sqliteutil | ||||
|  | ||||
| import ( | ||||
| 	"database/sql" | ||||
|  | ||||
| 	_ "github.com/mattn/go-sqlite3" | ||||
| ) | ||||
|  | ||||
| // This is a convenience function to run a function within a transaction. | ||||
| func WithTx(db *sql.DB, fn func(*sql.Tx) error) error { | ||||
| 	// Start a transaction. | ||||
| 	tx, err := db.Begin() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	err = fn(tx) | ||||
|  | ||||
| 	if err == nil { | ||||
| 		err = tx.Commit() | ||||
| 	} | ||||
|  | ||||
| 	if err != nil { | ||||
| 		_ = tx.Rollback() | ||||
| 	} | ||||
|  | ||||
| 	return err | ||||
| } | ||||
							
								
								
									
										2
									
								
								tagengine/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										2
									
								
								tagengine/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,2 @@ | ||||
| # tagengine | ||||
|  | ||||
							
								
								
									
										3
									
								
								tagengine/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								tagengine/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| module git.crumpington.com/lib/tagengine | ||||
|  | ||||
| go 1.23.2 | ||||
							
								
								
									
										0
									
								
								tagengine/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								tagengine/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										30
									
								
								tagengine/ngram.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								tagengine/ngram.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | ||||
| package tagengine | ||||
|  | ||||
| import "unicode" | ||||
|  | ||||
| func ngramLength(s string) int { | ||||
| 	N := len(s) | ||||
| 	i := 0 | ||||
| 	count := 0 | ||||
|  | ||||
| 	for { | ||||
| 		// Eat spaces. | ||||
| 		for i < N && unicode.IsSpace(rune(s[i])) { | ||||
| 			i++ | ||||
| 		} | ||||
|  | ||||
| 		// Done? | ||||
| 		if i == N { | ||||
| 			break | ||||
| 		} | ||||
|  | ||||
| 		// Non-space! | ||||
| 		count++ | ||||
|  | ||||
| 		// Eat non-spaces. | ||||
| 		for i < N && !unicode.IsSpace(rune(s[i])) { | ||||
| 			i++ | ||||
| 		} | ||||
| 	} | ||||
| 	return count | ||||
| } | ||||
							
								
								
									
										31
									
								
								tagengine/ngram_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										31
									
								
								tagengine/ngram_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,31 @@ | ||||
| package tagengine | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestNGramLength(t *testing.T) { | ||||
| 	type Case struct { | ||||
| 		Input  string | ||||
| 		Length int | ||||
| 	} | ||||
|  | ||||
| 	cases := []Case{ | ||||
| 		{"a b c", 3}, | ||||
| 		{"  xyz\nlkj  dflaj a", 4}, | ||||
| 		{"a", 1}, | ||||
| 		{" a", 1}, | ||||
| 		{"a", 1}, | ||||
| 		{" a\n", 1}, | ||||
| 		{" a ", 1}, | ||||
| 		{"\tx\ny\nz q ", 4}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		length := ngramLength(tc.Input) | ||||
| 		if length != tc.Length { | ||||
| 			log.Fatalf("%s: %d != %d", tc.Input, length, tc.Length) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										79
									
								
								tagengine/node.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										79
									
								
								tagengine/node.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,79 @@ | ||||
| package tagengine | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| type node struct { | ||||
| 	Token    string | ||||
| 	Matches  []*Rule // If a list of tokens reaches this node, it matches these. | ||||
| 	Children map[string]*node | ||||
| } | ||||
|  | ||||
| func (n *node) AddRule(r *Rule) { | ||||
| 	n.addRule(r, 0) | ||||
| } | ||||
|  | ||||
| func (n *node) addRule(r *Rule, idx int) { | ||||
| 	if len(r.Includes) == idx { | ||||
| 		n.Matches = append(n.Matches, r) | ||||
| 		return | ||||
| 	} | ||||
|  | ||||
| 	token := r.Includes[idx] | ||||
|  | ||||
| 	child, ok := n.Children[token] | ||||
| 	if !ok { | ||||
| 		child = &node{ | ||||
| 			Token:    token, | ||||
| 			Children: map[string]*node{}, | ||||
| 		} | ||||
| 		n.Children[token] = child | ||||
| 	} | ||||
|  | ||||
| 	child.addRule(r, idx+1) | ||||
| } | ||||
|  | ||||
| // Note that tokens must be sorted. This is the case for tokens created from | ||||
| // the tokenize function. | ||||
| func (n *node) Match(tokens []string) (rules []*Rule) { | ||||
| 	return n.match(tokens, rules) | ||||
| } | ||||
|  | ||||
| func (n *node) match(tokens []string, rules []*Rule) []*Rule { | ||||
| 	// Check for a match. | ||||
| 	if n.Matches != nil { | ||||
| 		rules = append(rules, n.Matches...) | ||||
| 	} | ||||
|  | ||||
| 	if len(tokens) == 0 { | ||||
| 		return rules | ||||
| 	} | ||||
|  | ||||
| 	// Attempt to match children. | ||||
| 	for i := 0; i < len(tokens); i++ { | ||||
| 		token := tokens[i] | ||||
| 		if child, ok := n.Children[token]; ok { | ||||
| 			rules = child.match(tokens[i+1:], rules) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return rules | ||||
| } | ||||
|  | ||||
| func (n *node) Dump() { | ||||
| 	n.dump(0) | ||||
| } | ||||
|  | ||||
| func (n *node) dump(depth int) { | ||||
| 	indent := strings.Repeat(" ", 2*depth) | ||||
| 	tag := "" | ||||
| 	for _, m := range n.Matches { | ||||
| 		tag += " " + m.Tag | ||||
| 	} | ||||
| 	fmt.Printf("%s%s%s\n", indent, n.Token, tag) | ||||
| 	for _, child := range n.Children { | ||||
| 		child.dump(depth + 1) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										159
									
								
								tagengine/rule.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										159
									
								
								tagengine/rule.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,159 @@ | ||||
| package tagengine | ||||
|  | ||||
| type Rule struct { | ||||
| 	// The purpose of a Rule is to attach it's Tag to matching text. | ||||
| 	Tag string | ||||
|  | ||||
| 	// Includes is a list of strings that must be found in the input in order to | ||||
| 	// match. | ||||
| 	Includes []string | ||||
|  | ||||
| 	// Excludes is a list of strings that can exclude a match for this rule. | ||||
| 	Excludes []string | ||||
|  | ||||
| 	// Blocks: If this rule is matched, then it will block matches of any tags | ||||
| 	// listed here. | ||||
| 	Blocks []string | ||||
|  | ||||
| 	// The Score encodes the complexity of the Rule. A higher score indicates a | ||||
| 	// more specific match. A Rule more includes, or includes with multiple words | ||||
| 	// should havee a higher Score than a Rule with fewer includes or less | ||||
| 	// complex includes. | ||||
| 	Score int | ||||
|  | ||||
| 	excludes map[string]struct{} | ||||
| } | ||||
|  | ||||
| func NewRule(tag string) Rule { | ||||
| 	return Rule{Tag: tag} | ||||
| } | ||||
|  | ||||
| func (r Rule) Inc(l ...string) Rule { | ||||
| 	return Rule{ | ||||
| 		Tag:      r.Tag, | ||||
| 		Includes: append(r.Includes, l...), | ||||
| 		Excludes: r.Excludes, | ||||
| 		Blocks:   r.Blocks, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r Rule) Exc(l ...string) Rule { | ||||
| 	return Rule{ | ||||
| 		Tag:      r.Tag, | ||||
| 		Includes: r.Includes, | ||||
| 		Excludes: append(r.Excludes, l...), | ||||
| 		Blocks:   r.Blocks, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r Rule) Block(l ...string) Rule { | ||||
| 	return Rule{ | ||||
| 		Tag:      r.Tag, | ||||
| 		Includes: r.Includes, | ||||
| 		Excludes: r.Excludes, | ||||
| 		Blocks:   append(r.Blocks, l...), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (rule *Rule) normalize(sanitize func(string) string) { | ||||
| 	for i, token := range rule.Includes { | ||||
| 		rule.Includes[i] = sanitize(token) | ||||
| 	} | ||||
| 	for i, token := range rule.Excludes { | ||||
| 		rule.Excludes[i] = sanitize(token) | ||||
| 	} | ||||
|  | ||||
| 	sortTokens(rule.Includes) | ||||
| 	sortTokens(rule.Excludes) | ||||
|  | ||||
| 	rule.excludes = map[string]struct{}{} | ||||
| 	for _, s := range rule.Excludes { | ||||
| 		rule.excludes[s] = struct{}{} | ||||
| 	} | ||||
|  | ||||
| 	rule.Score = rule.computeScore() | ||||
| } | ||||
|  | ||||
| func (r Rule) maxNGram() int { | ||||
| 	max := 0 | ||||
| 	for _, s := range r.Includes { | ||||
| 		n := ngramLength(s) | ||||
| 		if n > max { | ||||
| 			max = n | ||||
| 		} | ||||
| 	} | ||||
| 	for _, s := range r.Excludes { | ||||
| 		n := ngramLength(s) | ||||
| 		if n > max { | ||||
| 			max = n | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return max | ||||
| } | ||||
|  | ||||
| func (r Rule) isExcluded(tokens []string) bool { | ||||
| 	// This is most often the case. | ||||
| 	if len(r.excludes) == 0 { | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	for _, s := range tokens { | ||||
| 		if _, ok := r.excludes[s]; ok { | ||||
| 			return true | ||||
| 		} | ||||
| 	} | ||||
| 	return false | ||||
| } | ||||
|  | ||||
| func (r Rule) computeScore() (score int) { | ||||
| 	for _, token := range r.Includes { | ||||
| 		n := ngramLength(token) | ||||
| 		score += n * (n + 1) / 2 | ||||
| 	} | ||||
| 	return score | ||||
| } | ||||
|  | ||||
| func ruleLess(lhs, rhs *Rule) bool { | ||||
| 	// If scores differ, sort by score. | ||||
| 	if lhs.Score != rhs.Score { | ||||
| 		return lhs.Score < rhs.Score | ||||
| 	} | ||||
|  | ||||
| 	// If include depth differs, sort by depth. | ||||
| 	lDepth := len(lhs.Includes) | ||||
| 	rDepth := len(rhs.Includes) | ||||
|  | ||||
| 	if lDepth != rDepth { | ||||
| 		return lDepth < rDepth | ||||
| 	} | ||||
|  | ||||
| 	// If exclude depth differs, sort by depth. | ||||
| 	lDepth = len(lhs.Excludes) | ||||
| 	rDepth = len(rhs.Excludes) | ||||
|  | ||||
| 	if lDepth != rDepth { | ||||
| 		return lDepth < rDepth | ||||
| 	} | ||||
|  | ||||
| 	// Sort alphabetically by includes. | ||||
| 	for i := range lhs.Includes { | ||||
| 		if lhs.Includes[i] != rhs.Includes[i] { | ||||
| 			return lhs.Includes[i] < rhs.Includes[i] | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Sort by alphabetically by excludes. | ||||
| 	for i := range lhs.Excludes { | ||||
| 		if lhs.Excludes[i] != rhs.Excludes[i] { | ||||
| 			return lhs.Excludes[i] < rhs.Excludes[i] | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Sort by tag. | ||||
| 	if lhs.Tag != rhs.Tag { | ||||
| 		return lhs.Tag < rhs.Tag | ||||
| 	} | ||||
|  | ||||
| 	return false | ||||
| } | ||||
							
								
								
									
										58
									
								
								tagengine/rulegroup.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										58
									
								
								tagengine/rulegroup.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,58 @@ | ||||
| package tagengine | ||||
|  | ||||
| // A RuleGroup can be converted into a list of rules. Each rule will point to | ||||
| // the same tag, and have the same exclude set and blocks. | ||||
| type RuleGroup struct { | ||||
| 	Tag      string | ||||
| 	Includes [][]string | ||||
| 	Excludes []string | ||||
| 	Blocks   []string | ||||
| } | ||||
|  | ||||
| func NewRuleGroup(tag string) RuleGroup { | ||||
| 	return RuleGroup{ | ||||
| 		Tag:      tag, | ||||
| 		Includes: [][]string{}, | ||||
| 		Excludes: []string{}, | ||||
| 		Blocks:   []string{}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (g RuleGroup) Inc(l ...string) RuleGroup { | ||||
| 	return RuleGroup{ | ||||
| 		Tag:      g.Tag, | ||||
| 		Includes: append(g.Includes, l), | ||||
| 		Excludes: g.Excludes, | ||||
| 		Blocks:   g.Blocks, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (g RuleGroup) Exc(l ...string) RuleGroup { | ||||
| 	return RuleGroup{ | ||||
| 		Tag:      g.Tag, | ||||
| 		Includes: g.Includes, | ||||
| 		Excludes: append(g.Excludes, l...), | ||||
| 		Blocks:   g.Blocks, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (g RuleGroup) Block(l ...string) RuleGroup { | ||||
| 	return RuleGroup{ | ||||
| 		Tag:      g.Tag, | ||||
| 		Includes: g.Includes, | ||||
| 		Excludes: g.Excludes, | ||||
| 		Blocks:   append(g.Blocks, l...), | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (g RuleGroup) ToList() (l []Rule) { | ||||
| 	for _, includes := range g.Includes { | ||||
| 		l = append(l, Rule{ | ||||
| 			Tag:      g.Tag, | ||||
| 			Excludes: g.Excludes, | ||||
| 			Includes: includes, | ||||
| 			Blocks:   g.Blocks, | ||||
| 		}) | ||||
| 	} | ||||
| 	return | ||||
| } | ||||
							
								
								
									
										162
									
								
								tagengine/ruleset.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										162
									
								
								tagengine/ruleset.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,162 @@ | ||||
| package tagengine | ||||
|  | ||||
| import ( | ||||
| 	"sort" | ||||
| ) | ||||
|  | ||||
| type RuleSet struct { | ||||
| 	root     *node | ||||
| 	maxNgram int | ||||
| 	sanitize func(string) string | ||||
| 	rules    []*Rule | ||||
| } | ||||
|  | ||||
| func NewRuleSet() *RuleSet { | ||||
| 	return &RuleSet{ | ||||
| 		root: &node{ | ||||
| 			Token:    "/", | ||||
| 			Children: map[string]*node{}, | ||||
| 		}, | ||||
| 		sanitize: BasicSanitizer, | ||||
| 		rules:    []*Rule{}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func NewRuleSetFromList(rules []Rule) *RuleSet { | ||||
| 	rs := NewRuleSet() | ||||
| 	rs.AddRule(rules...) | ||||
| 	return rs | ||||
| } | ||||
|  | ||||
| func (t *RuleSet) Add(ruleOrGroup ...interface{}) { | ||||
| 	for _, ix := range ruleOrGroup { | ||||
| 		switch x := ix.(type) { | ||||
| 		case Rule: | ||||
| 			t.AddRule(x) | ||||
| 		case RuleGroup: | ||||
| 			t.AddRuleGroup(x) | ||||
| 		default: | ||||
| 			panic("Add expects either Rule or RuleGroup objects.") | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (t *RuleSet) AddRule(rules ...Rule) { | ||||
| 	for _, rule := range rules { | ||||
| 		rule := rule | ||||
|  | ||||
| 		// Make sure rule is well-formed. | ||||
| 		rule.normalize(t.sanitize) | ||||
|  | ||||
| 		// Update maxNgram. | ||||
| 		N := rule.maxNGram() | ||||
| 		if N > t.maxNgram { | ||||
| 			t.maxNgram = N | ||||
| 		} | ||||
|  | ||||
| 		t.rules = append(t.rules, &rule) | ||||
| 		t.root.AddRule(&rule) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (t *RuleSet) AddRuleGroup(ruleGroups ...RuleGroup) { | ||||
| 	for _, rg := range ruleGroups { | ||||
| 		t.AddRule(rg.ToList()...) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| // MatchRules will return a list of all matching rules. The rules are sorted by | ||||
| // the match's Score. The best match will be first. | ||||
| func (t *RuleSet) MatchRules(input string) (rules []*Rule) { | ||||
| 	input = t.sanitize(input) | ||||
| 	tokens := Tokenize(input, t.maxNgram) | ||||
|  | ||||
| 	rules = t.root.Match(tokens) | ||||
| 	if len(rules) == 0 { | ||||
| 		return rules | ||||
| 	} | ||||
|  | ||||
| 	// Check excludes. | ||||
| 	l := rules[:0] | ||||
| 	for _, r := range rules { | ||||
| 		if !r.isExcluded(tokens) { | ||||
| 			l = append(l, r) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	rules = l | ||||
|  | ||||
| 	// Sort rules descending. | ||||
| 	sort.Slice(rules, func(i, j int) bool { | ||||
| 		return ruleLess(rules[j], rules[i]) | ||||
| 	}) | ||||
|  | ||||
| 	return rules | ||||
| } | ||||
|  | ||||
| type Match struct { | ||||
| 	Tag string | ||||
|  | ||||
| 	// Confidence is used to sort all matches, and is normalized so the sum of | ||||
| 	// Confidence values for all matches is 1. Confidence is relative to the | ||||
| 	// number of matches and the size of matches in terms of number of tokens. | ||||
| 	Confidence float64 // In the range (0,1]. | ||||
| } | ||||
|  | ||||
| // Return a list of matches with confidence. This is useful if you'd like to | ||||
| // find the best matching rule out of all the matched rules. | ||||
| // | ||||
| // If you just want to find all matching rules, then use MatchRules. | ||||
| func (t *RuleSet) Match(input string) []Match { | ||||
| 	rules := t.MatchRules(input) | ||||
| 	if len(rules) == 0 { | ||||
| 		return []Match{} | ||||
| 	} | ||||
| 	if len(rules) == 1 { | ||||
| 		return []Match{{ | ||||
| 			Tag:        rules[0].Tag, | ||||
| 			Confidence: 1, | ||||
| 		}} | ||||
| 	} | ||||
|  | ||||
| 	// Create list of blocked tags. | ||||
| 	blocks := map[string]struct{}{} | ||||
| 	for _, rule := range rules { | ||||
| 		for _, tag := range rule.Blocks { | ||||
| 			blocks[tag] = struct{}{} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Remove rules for blocked tags. | ||||
| 	iOut := 0 | ||||
| 	for _, rule := range rules { | ||||
| 		if _, ok := blocks[rule.Tag]; ok { | ||||
| 			continue | ||||
| 		} | ||||
| 		rules[iOut] = rule | ||||
| 		iOut++ | ||||
| 	} | ||||
| 	rules = rules[:iOut] | ||||
|  | ||||
| 	// Matches by index. | ||||
| 	matches := map[string]int{} | ||||
| 	out := []Match{} | ||||
| 	sum := float64(0) | ||||
|  | ||||
| 	for _, rule := range rules { | ||||
| 		idx, ok := matches[rule.Tag] | ||||
| 		if !ok { | ||||
| 			idx = len(matches) | ||||
| 			matches[rule.Tag] = idx | ||||
| 			out = append(out, Match{Tag: rule.Tag}) | ||||
| 		} | ||||
| 		out[idx].Confidence += float64(rule.Score) | ||||
| 		sum += float64(rule.Score) | ||||
| 	} | ||||
|  | ||||
| 	for i := range out { | ||||
| 		out[i].Confidence /= sum | ||||
| 	} | ||||
|  | ||||
| 	return out | ||||
| } | ||||
							
								
								
									
										84
									
								
								tagengine/ruleset_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								tagengine/ruleset_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,84 @@ | ||||
| package tagengine | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestRulesSet(t *testing.T) { | ||||
| 	rs := NewRuleSet() | ||||
| 	rs.AddRule(Rule{ | ||||
| 		Tag:      "cc/2", | ||||
| 		Includes: []string{"cola", "coca"}, | ||||
| 	}) | ||||
| 	rs.AddRule(Rule{ | ||||
| 		Tag:      "cc/0", | ||||
| 		Includes: []string{"coca cola"}, | ||||
| 	}) | ||||
| 	rs.AddRule(Rule{ | ||||
| 		Tag:      "cz/2", | ||||
| 		Includes: []string{"coca", "zero"}, | ||||
| 	}) | ||||
| 	rs.AddRule(Rule{ | ||||
| 		Tag:      "cc0/3", | ||||
| 		Includes: []string{"zero", "coca", "cola"}, | ||||
| 	}) | ||||
| 	rs.AddRule(Rule{ | ||||
| 		Tag:      "cc0/3.1", | ||||
| 		Includes: []string{"coca", "cola", "zero"}, | ||||
| 		Excludes: []string{"pepsi"}, | ||||
| 	}) | ||||
| 	rs.AddRule(Rule{ | ||||
| 		Tag:      "spa", | ||||
| 		Includes: []string{"spa"}, | ||||
| 		Blocks:   []string{"cc/0", "cc0/3", "cc0/3.1"}, | ||||
| 	}) | ||||
|  | ||||
| 	type TestCase struct { | ||||
| 		Input   string | ||||
| 		Matches []Match | ||||
| 	} | ||||
|  | ||||
| 	cases := []TestCase{ | ||||
| 		{ | ||||
| 			Input: "coca-cola zero", | ||||
| 			Matches: []Match{ | ||||
| 				{"cc0/3.1", 0.3}, | ||||
| 				{"cc0/3", 0.3}, | ||||
| 				{"cz/2", 0.2}, | ||||
| 				{"cc/2", 0.2}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			Input: "coca cola", | ||||
| 			Matches: []Match{ | ||||
| 				{"cc/0", 0.6}, | ||||
| 				{"cc/2", 0.4}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			Input: "coca cola zero pepsi", | ||||
| 			Matches: []Match{ | ||||
| 				{"cc0/3", 0.3}, | ||||
| 				{"cc/0", 0.3}, | ||||
| 				{"cz/2", 0.2}, | ||||
| 				{"cc/2", 0.2}, | ||||
| 			}, | ||||
| 		}, { | ||||
| 			Input:   "fanta orange", | ||||
| 			Matches: []Match{}, | ||||
| 		}, { | ||||
| 			Input: "coca-cola zero / fanta / spa", | ||||
| 			Matches: []Match{ | ||||
| 				{"cz/2", 0.4}, | ||||
| 				{"cc/2", 0.4}, | ||||
| 				{"spa", 0.2}, | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		matches := rs.Match(tc.Input) | ||||
| 		if !reflect.DeepEqual(matches, tc.Matches) { | ||||
| 			t.Fatalf("%v != %v", matches, tc.Matches) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										20
									
								
								tagengine/sanitize.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										20
									
								
								tagengine/sanitize.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,20 @@ | ||||
| package tagengine | ||||
|  | ||||
| import ( | ||||
| 	"strings" | ||||
|  | ||||
| 	"git.crumpington.com/lib/tagengine/sanitize" | ||||
| ) | ||||
|  | ||||
| // The basic sanitizer: | ||||
| // * lower-case | ||||
| // * put spaces around numbers | ||||
| // * put slaces around punctuation | ||||
| // * collapse multiple spaces | ||||
| func BasicSanitizer(s string) string { | ||||
| 	s = strings.ToLower(s) | ||||
| 	s = sanitize.SpaceNumbers(s) | ||||
| 	s = sanitize.SpacePunctuation(s) | ||||
| 	s = sanitize.CollapseSpaces(s) | ||||
| 	return s | ||||
| } | ||||
							
								
								
									
										91
									
								
								tagengine/sanitize/sanitize.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										91
									
								
								tagengine/sanitize/sanitize.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,91 @@ | ||||
| package sanitize | ||||
|  | ||||
| import ( | ||||
| 	"strings" | ||||
| 	"unicode" | ||||
| ) | ||||
|  | ||||
| func SpaceNumbers(s string) string { | ||||
| 	if len(s) == 0 { | ||||
| 		return s | ||||
| 	} | ||||
|  | ||||
| 	isDigit := func(b rune) bool { | ||||
| 		switch b { | ||||
| 		case '0', '1', '2', '3', '4', '5', '6', '7', '8', '9': | ||||
| 			return true | ||||
| 		} | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	b := strings.Builder{} | ||||
|  | ||||
| 	var first rune | ||||
| 	for _, c := range s { | ||||
| 		first = c | ||||
| 		break | ||||
| 	} | ||||
|  | ||||
| 	digit := isDigit(first) | ||||
|  | ||||
| 	// Range over runes. | ||||
| 	for _, c := range s { | ||||
| 		thisDigit := isDigit(c) | ||||
| 		if thisDigit != digit { | ||||
| 			b.WriteByte(' ') | ||||
| 			digit = thisDigit | ||||
| 		} | ||||
| 		b.WriteRune(c) | ||||
| 	} | ||||
|  | ||||
| 	return b.String() | ||||
| } | ||||
|  | ||||
| func SpacePunctuation(s string) string { | ||||
| 	needsSpace := func(r rune) bool { | ||||
| 		switch r { | ||||
| 		case '`', '~', '!', '@', '#', '%', '^', '&', '*', '(', ')', | ||||
| 			'-', '_', '+', '=', '[', '{', ']', '}', '\\', '|', | ||||
| 			':', ';', '"', '\'', ',', '<', '.', '>', '?', '/': | ||||
| 			return true | ||||
| 		} | ||||
| 		return false | ||||
| 	} | ||||
|  | ||||
| 	b := strings.Builder{} | ||||
|  | ||||
| 	// Range over runes. | ||||
| 	for _, r := range s { | ||||
| 		if needsSpace(r) { | ||||
| 			b.WriteRune(' ') | ||||
| 			b.WriteRune(r) | ||||
| 			b.WriteRune(' ') | ||||
| 		} else { | ||||
| 			b.WriteRune(r) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	return b.String() | ||||
| } | ||||
|  | ||||
| func CollapseSpaces(s string) string { | ||||
| 	// Trim leading and trailing spaces. | ||||
| 	s = strings.TrimSpace(s) | ||||
|  | ||||
| 	b := strings.Builder{} | ||||
| 	wasSpace := false | ||||
|  | ||||
| 	// Range over runes. | ||||
| 	for _, c := range s { | ||||
| 		if unicode.IsSpace(c) { | ||||
| 			wasSpace = true | ||||
| 			continue | ||||
| 		} else if wasSpace { | ||||
| 			wasSpace = false | ||||
| 			b.WriteRune(' ') | ||||
| 		} | ||||
| 		b.WriteRune(c) | ||||
| 	} | ||||
|  | ||||
| 	return b.String() | ||||
| } | ||||
							
								
								
									
										30
									
								
								tagengine/sanitize_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										30
									
								
								tagengine/sanitize_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,30 @@ | ||||
| package tagengine | ||||
|  | ||||
| import "testing" | ||||
|  | ||||
| func TestSanitize(t *testing.T) { | ||||
| 	sanitize := BasicSanitizer | ||||
|  | ||||
| 	type Case struct { | ||||
| 		In  string | ||||
| 		Out string | ||||
| 	} | ||||
|  | ||||
| 	cases := []Case{ | ||||
| 		{"", ""}, | ||||
| 		{"123abc", "123 abc"}, | ||||
| 		{"abc123", "abc 123"}, | ||||
| 		{"abc123xyz", "abc 123 xyz"}, | ||||
| 		{"1f2", "1 f 2"}, | ||||
| 		{" abc", "abc"}, | ||||
| 		{" ; KitKat/m&m's (bottle) @ ", "; kitkat / m & m ' s ( bottle ) @"}, | ||||
| 		{"€", "€"}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		out := sanitize(tc.In) | ||||
| 		if out != tc.Out { | ||||
| 			t.Fatalf("%v != %v", out, tc.Out) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										63
									
								
								tagengine/tokenize.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										63
									
								
								tagengine/tokenize.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,63 @@ | ||||
| package tagengine | ||||
|  | ||||
| import ( | ||||
| 	"sort" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| var ignoreTokens = map[string]struct{}{} | ||||
|  | ||||
| func init() { | ||||
| 	// These on their own are ignored. | ||||
| 	tokens := []string{ | ||||
| 		"`", `~`, `!`, `@`, `#`, `%`, `^`, `&`, `*`, `(`, `)`, | ||||
| 		`-`, `_`, `+`, `=`, `[`, `{`, `]`, `}`, `\`, `|`, | ||||
| 		`:`, `;`, `"`, `'`, `,`, `<`, `.`, `>`, `?`, `/`, | ||||
| 	} | ||||
| 	for _, s := range tokens { | ||||
| 		ignoreTokens[s] = struct{}{} | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func Tokenize( | ||||
| 	input string, | ||||
| 	maxNgram int, | ||||
| ) ( | ||||
| 	tokens []string, | ||||
| ) { | ||||
| 	// Avoid duplicate ngrams. | ||||
| 	ignored := map[string]bool{} | ||||
|  | ||||
| 	fields := strings.Fields(input) | ||||
|  | ||||
| 	if len(fields) < maxNgram { | ||||
| 		maxNgram = len(fields) | ||||
| 	} | ||||
|  | ||||
| 	for i := 1; i < maxNgram+1; i++ { | ||||
| 		jMax := len(fields) - i + 1 | ||||
|  | ||||
| 		for j := 0; j < jMax; j++ { | ||||
| 			ngram := strings.Join(fields[j:i+j], " ") | ||||
| 			if _, ok := ignoreTokens[ngram]; !ok { | ||||
| 				if _, ok := ignored[ngram]; !ok { | ||||
| 					tokens = append(tokens, ngram) | ||||
| 					ignored[ngram] = true | ||||
| 				} | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	sortTokens(tokens) | ||||
|  | ||||
| 	return tokens | ||||
| } | ||||
|  | ||||
| func sortTokens(tokens []string) { | ||||
| 	sort.Slice(tokens, func(i, j int) bool { | ||||
| 		if len(tokens[i]) != len(tokens[j]) { | ||||
| 			return len(tokens[i]) < len(tokens[j]) | ||||
| 		} | ||||
| 		return tokens[i] < tokens[j] | ||||
| 	}) | ||||
| } | ||||
							
								
								
									
										55
									
								
								tagengine/tokenize_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										55
									
								
								tagengine/tokenize_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,55 @@ | ||||
| package tagengine | ||||
|  | ||||
| import ( | ||||
| 	"reflect" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| func TestTokenize(t *testing.T) { | ||||
| 	type Case struct { | ||||
| 		Input    string | ||||
| 		MaxNgram int | ||||
| 		Output   []string | ||||
| 	} | ||||
|  | ||||
| 	cases := []Case{ | ||||
| 		{ | ||||
| 			Input:    "a bb c d", | ||||
| 			MaxNgram: 3, | ||||
| 			Output: []string{ | ||||
| 				"a", "c", "d", "bb", | ||||
| 				"c d", "a bb", "bb c", | ||||
| 				"a bb c", "bb c d", | ||||
| 			}, | ||||
| 		}, { | ||||
| 			Input:    "a b", | ||||
| 			MaxNgram: 3, | ||||
| 			Output: []string{ | ||||
| 				"a", "b", "a b", | ||||
| 			}, | ||||
| 		}, { | ||||
| 			Input:    "- b c d", | ||||
| 			MaxNgram: 3, | ||||
| 			Output: []string{ | ||||
| 				"b", "c", "d", | ||||
| 				"- b", "b c", "c d", | ||||
| 				"- b c", "b c d", | ||||
| 			}, | ||||
| 		}, { | ||||
| 			Input:    "a a b c d c d", | ||||
| 			MaxNgram: 3, | ||||
| 			Output: []string{ | ||||
| 				"a", "b", "c", "d", | ||||
| 				"a a", "a b", "b c", "c d", "d c", | ||||
| 				"a a b", "a b c", "b c d", "c d c", "d c d", | ||||
| 			}, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		output := Tokenize(tc.Input, tc.MaxNgram) | ||||
| 		if !reflect.DeepEqual(output, tc.Output) { | ||||
| 			t.Fatalf("%s: %#v", tc.Input, output) | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										5
									
								
								webutil/README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										5
									
								
								webutil/README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,5 @@ | ||||
| # webutil | ||||
|  | ||||
| ## Roadmap | ||||
|  | ||||
| * logging middleware | ||||
							
								
								
									
										10
									
								
								webutil/go.mod
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										10
									
								
								webutil/go.mod
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,10 @@ | ||||
| module git.crumpington.com/lib/webutil | ||||
|  | ||||
| go 1.23.2 | ||||
|  | ||||
| require golang.org/x/crypto v0.28.0 | ||||
|  | ||||
| require ( | ||||
| 	golang.org/x/net v0.21.0 // indirect | ||||
| 	golang.org/x/text v0.19.0 // indirect | ||||
| ) | ||||
							
								
								
									
										6
									
								
								webutil/go.sum
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								webutil/go.sum
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,6 @@ | ||||
| golang.org/x/crypto v0.28.0 h1:GBDwsMXVQi34v5CCYUm2jkJvu4cbtru2U4TN2PSyQnw= | ||||
| golang.org/x/crypto v0.28.0/go.mod h1:rmgy+3RHxRZMyY0jjAJShp2zgEdOqj2AO7U0pYmeQ7U= | ||||
| golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= | ||||
| golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= | ||||
| golang.org/x/text v0.19.0 h1:kTxAhCbGbxhK0IwgSKiMO5awPoDQ0RpfiVYBfK860YM= | ||||
| golang.org/x/text v0.19.0/go.mod h1:BuEKDfySbSR4drPmRPG/7iBdf8hvFMuRexcpahXilzY= | ||||
							
								
								
									
										24
									
								
								webutil/listenandserve.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										24
									
								
								webutil/listenandserve.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,24 @@ | ||||
| package webutil | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
|  | ||||
| 	"golang.org/x/crypto/acme/autocert" | ||||
| ) | ||||
|  | ||||
| // Serve requests using the given http.Server. If srv.Addr has the format | ||||
| // `hostname:https`, then use autocert to manage certificates for the domain. | ||||
| // | ||||
| // For http on port 80, you can use :http. | ||||
| func ListenAndServe(srv *http.Server) error { | ||||
| 	if strings.HasSuffix(srv.Addr, ":https") { | ||||
| 		hostname := strings.TrimSuffix(srv.Addr, ":https") | ||||
| 		if len(hostname) == 0 { | ||||
| 			return errors.New("https requires a hostname") | ||||
| 		} | ||||
| 		return srv.Serve(autocert.NewListener(hostname)) | ||||
| 	} | ||||
| 	return srv.ListenAndServe() | ||||
| } | ||||
							
								
								
									
										47
									
								
								webutil/middleware-logging.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										47
									
								
								webutil/middleware-logging.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,47 @@ | ||||
| package webutil | ||||
|  | ||||
| import ( | ||||
| 	"log" | ||||
| 	"net/http" | ||||
| 	"os" | ||||
| 	"time" | ||||
| ) | ||||
|  | ||||
| var _log = log.New(os.Stderr, "", 0) | ||||
|  | ||||
| type responseWriterWrapper struct { | ||||
| 	http.ResponseWriter | ||||
| 	httpStatus   int | ||||
| 	responseSize int | ||||
| } | ||||
|  | ||||
| func (w *responseWriterWrapper) WriteHeader(status int) { | ||||
| 	w.httpStatus = status | ||||
| 	w.ResponseWriter.WriteHeader(status) | ||||
| } | ||||
|  | ||||
| func (w *responseWriterWrapper) Write(b []byte) (int, error) { | ||||
| 	if w.httpStatus == 0 { | ||||
| 		w.httpStatus = 200 | ||||
| 	} | ||||
| 	w.responseSize += len(b) | ||||
| 	return w.ResponseWriter.Write(b) | ||||
| } | ||||
|  | ||||
| func WithLogging(inner http.HandlerFunc) http.HandlerFunc { | ||||
| 	return func(w http.ResponseWriter, r *http.Request) { | ||||
| 		t := time.Now() | ||||
| 		wrapper := responseWriterWrapper{w, 0, 0} | ||||
|  | ||||
| 		inner(&wrapper, r) | ||||
| 		_log.Printf("%s \"%s %s %s\" %d %d %v\n", | ||||
| 			r.RemoteAddr, | ||||
| 			r.Method, | ||||
| 			r.URL.Path, | ||||
| 			r.Proto, | ||||
| 			wrapper.httpStatus, | ||||
| 			wrapper.responseSize, | ||||
| 			time.Since(t), | ||||
| 		) | ||||
| 	} | ||||
| } | ||||
							
								
								
									
										100
									
								
								webutil/template.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								webutil/template.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,100 @@ | ||||
| package webutil | ||||
|  | ||||
| import ( | ||||
| 	"embed" | ||||
| 	"html/template" | ||||
| 	"io/fs" | ||||
| 	"log" | ||||
| 	"path" | ||||
| 	"strings" | ||||
| ) | ||||
|  | ||||
| // ParseTemplateSet parses sets of templates from an embed.FS. | ||||
| // | ||||
| // Each directory constitutes a set of templates that are parsed together. | ||||
| // | ||||
| // Structure (within a directory): | ||||
| //   - share/* are always parsed. | ||||
| //   - base.html will be parsed with each other file in same dir | ||||
| // | ||||
| // Call a template with m[path].Execute(w, data) (root dir name is excluded). | ||||
| // | ||||
| // For example, if you have | ||||
| //   - /user/share/* | ||||
| //   - /user/base.html | ||||
| //   - /user/home.html | ||||
| // | ||||
| // Then you call m["/user/home.html"].Execute(w, data). | ||||
| func ParseTemplateSet(funcs template.FuncMap, fs embed.FS) map[string]*template.Template { | ||||
| 	m := map[string]*template.Template{} | ||||
| 	rootDir := readDir(fs, ".")[0].Name() | ||||
| 	loadTemplateDir(fs, funcs, m, rootDir, rootDir) | ||||
| 	return m | ||||
| } | ||||
|  | ||||
| func loadTemplateDir( | ||||
| 	fs embed.FS, | ||||
| 	funcs template.FuncMap, | ||||
| 	m map[string]*template.Template, | ||||
| 	dirPath string, | ||||
| 	rootDir string, | ||||
| ) map[string]*template.Template { | ||||
| 	t := template.New("") | ||||
| 	if funcs != nil { | ||||
| 		t = t.Funcs(funcs) | ||||
| 	} | ||||
|  | ||||
| 	shareDir := path.Join(dirPath, "share") | ||||
| 	if _, err := fs.ReadDir(shareDir); err == nil { | ||||
| 		log.Printf("Parsing %s...", path.Join(shareDir, "*")) | ||||
| 		t = template.Must(t.ParseFS(fs, path.Join(shareDir, "*"))) | ||||
| 	} | ||||
|  | ||||
| 	if data, _ := fs.ReadFile(path.Join(dirPath, "base.html")); data != nil { | ||||
| 		log.Printf("Parsing %s...", path.Join(dirPath, "base.html")) | ||||
| 		t = template.Must(t.Parse(string(data))) | ||||
| 	} | ||||
|  | ||||
| 	for _, ent := range readDir(fs, dirPath) { | ||||
| 		if ent.Type().IsDir() { | ||||
| 			if ent.Name() != "share" { | ||||
| 				m = loadTemplateDir(fs, funcs, m, path.Join(dirPath, ent.Name()), rootDir) | ||||
| 			} | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if !ent.Type().IsRegular() { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		if ent.Name() == "base.html" { | ||||
| 			continue | ||||
| 		} | ||||
|  | ||||
| 		filePath := path.Join(dirPath, ent.Name()) | ||||
| 		log.Printf("Parsing %s...", filePath) | ||||
|  | ||||
| 		key := strings.TrimPrefix(path.Join(dirPath, ent.Name()), rootDir) | ||||
| 		tt := template.Must(t.Clone()) | ||||
| 		tt = template.Must(tt.Parse(readFile(fs, filePath))) | ||||
| 		m[key] = tt | ||||
| 	} | ||||
|  | ||||
| 	return m | ||||
| } | ||||
|  | ||||
| func readDir(fs embed.FS, dirPath string) []fs.DirEntry { | ||||
| 	ents, err := fs.ReadDir(dirPath) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	return ents | ||||
| } | ||||
|  | ||||
| func readFile(fs embed.FS, path string) string { | ||||
| 	data, err := fs.ReadFile(path) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	return string(data) | ||||
| } | ||||
							
								
								
									
										49
									
								
								webutil/template_test.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										49
									
								
								webutil/template_test.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,49 @@ | ||||
| package webutil | ||||
|  | ||||
| import ( | ||||
| 	"bytes" | ||||
| 	"embed" | ||||
| 	"html/template" | ||||
| 	"strings" | ||||
| 	"testing" | ||||
| ) | ||||
|  | ||||
| //go:embed all:test-templates | ||||
| var testFS embed.FS | ||||
|  | ||||
| func TestParseTemplateSet(t *testing.T) { | ||||
| 	funcs := template.FuncMap{"join": strings.Join} | ||||
| 	m := ParseTemplateSet(funcs, testFS) | ||||
|  | ||||
| 	type TestCase struct { | ||||
| 		Key  string | ||||
| 		Data any | ||||
| 		Out  string | ||||
| 	} | ||||
|  | ||||
| 	cases := []TestCase{ | ||||
| 		{ | ||||
| 			Key:  "/home.html", | ||||
| 			Data: "DATA", | ||||
| 			Out:  "<p>HOME!</p>", | ||||
| 		}, { | ||||
| 			Key:  "/about.html", | ||||
| 			Data: "DATA", | ||||
| 			Out:  "<p><b>DATA</b></p>", | ||||
| 		}, { | ||||
| 			Key:  "/contact.html", | ||||
| 			Data: []string{"a", "b", "c"}, | ||||
| 			Out:  "<p>a,b,c</p>", | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	for _, tc := range cases { | ||||
| 		b := &bytes.Buffer{} | ||||
| 		m[tc.Key].Execute(b, tc.Data) | ||||
| 		out := strings.TrimSpace(b.String()) | ||||
| 		if out != tc.Out { | ||||
| 			t.Fatalf("%s != %s", out, tc.Out) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| } | ||||
							
								
								
									
										1
									
								
								webutil/test-templates/about.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								webutil/test-templates/about.html
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| {{define "body"}}{{template "bold" .}}{{end}} | ||||
							
								
								
									
										1
									
								
								webutil/test-templates/base.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								webutil/test-templates/base.html
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| <p>{{block "body" .}}default{{end}}</p> | ||||
							
								
								
									
										1
									
								
								webutil/test-templates/contact.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								webutil/test-templates/contact.html
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| {{define "body"}}{{join . ","}}{{end}} | ||||
							
								
								
									
										1
									
								
								webutil/test-templates/home.html
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								webutil/test-templates/home.html
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1 @@ | ||||
| {{define "body"}}HOME!{{end}} | ||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user