This commit is contained in:
jdl
2024-11-11 06:36:55 +01:00
parent d0587cc585
commit c5419d662e
102 changed files with 4181 additions and 0 deletions

45
pgutil/README.md Normal file
View File

@@ -0,0 +1,45 @@
# pgutil
## Transactions
Simplify postgres transactions using `WithTx` for serializable transactions,
or `WithTxDefault` for the default isolation level. Use the `SerialTxRunner`
type to get automatic retries of serialization errors.
## Migrations
Put your migrations into a directory, for example `migrations`, ordered by name
(YYYY-MM-DD prefix, for example). Embed the directory and pass it to the
`Migrate` function:
```Go
//go:embed migrations
var migrations embed.FS
func init() {
Migrate(db, migrations) // Check the error, of course.
}
```
## Testing
In order to test this packge, we need to create a test user and database:
```
sudo su postgres
psql
CREATE DATABASE test;
CREATE USER test WITH ENCRYPTED PASSWORD 'test';
GRANT ALL PRIVILEGES ON DATABASE test TO test;
use test
GRANT ALL ON SCHEMA public TO test;
```
Check that you can connect via the command line:
```
psql -h 127.0.0.1 -U test --password test
```

42
pgutil/dropall.go Normal file
View File

@@ -0,0 +1,42 @@
package pgutil
import (
"database/sql"
"log"
)
const dropTablesQueryQuery = `
SELECT 'DROP TABLE IF EXISTS "' || tablename || '" CASCADE;'
FROM
pg_tables
WHERE
schemaname='public'`
// Deletes all tables in the database. Useful for testing.
func DropAllTables(db *sql.DB) error {
rows, err := db.Query(dropTablesQueryQuery)
if err != nil {
return err
}
queries := []string{}
for rows.Next() {
var s string
if err := rows.Scan(&s); err != nil {
return err
}
queries = append(queries, s)
}
if len(queries) > 0 {
log.Printf("DROPPING ALL (%d) TABLES", len(queries))
}
for _, query := range queries {
if _, err := db.Exec(query); err != nil {
return err
}
}
return nil
}

31
pgutil/errors.go Normal file
View File

@@ -0,0 +1,31 @@
package pgutil
import (
"errors"
"github.com/lib/pq"
)
func ErrIsDuplicateKey(err error) bool {
return ErrHasCode(err, "23505")
}
func ErrIsForeignKey(err error) bool {
return ErrHasCode(err, "23503")
}
func ErrIsSerializationFaiilure(err error) bool {
return ErrHasCode(err, "40001")
}
func ErrHasCode(err error, code string) bool {
if err == nil {
return false
}
var pErr *pq.Error
if errors.As(err, &pErr) {
return pErr.Code == pq.ErrorCode(code)
}
return false
}

36
pgutil/errors_test.go Normal file
View File

@@ -0,0 +1,36 @@
package pgutil
import (
"database/sql"
"testing"
)
func TestErrors(t *testing.T) {
db, err := sql.Open(
"postgres",
"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
if err != nil {
t.Fatal(err)
}
if err := DropAllTables(db); err != nil {
t.Fatal(err)
}
if err := Migrate(db, testMigrationFS); err != nil {
t.Fatal(err)
}
_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (2, 'q@r.com')`)
if !ErrIsDuplicateKey(err) {
t.Fatal(err)
}
_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (3, 'c@d.com')`)
if !ErrIsDuplicateKey(err) {
t.Fatal(err)
}
_, err = db.Exec(`INSERT INTO user_notes(UserID, NoteID, Note) VALUES (4, 1, 'hello')`)
if !ErrIsForeignKey(err) {
t.Fatal(err)
}
}

5
pgutil/go.mod Normal file
View File

@@ -0,0 +1,5 @@
module git.crumpington.com/git/pgutil
go 1.23.2
require github.com/lib/pq v1.10.9

2
pgutil/go.sum Normal file
View File

@@ -0,0 +1,2 @@
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=

82
pgutil/migrate.go Normal file
View File

@@ -0,0 +1,82 @@
package pgutil
import (
"database/sql"
"embed"
"errors"
"fmt"
"path/filepath"
"sort"
)
const initMigrationTableQuery = `
CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);`
const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)`
const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)`
func Migrate(db *sql.DB, migrationFS embed.FS) error {
return WithTx(db, func(tx *sql.Tx) error {
if _, err := tx.Exec(initMigrationTableQuery); err != nil {
return err
}
dirs, err := migrationFS.ReadDir(".")
if err != nil {
return err
}
if len(dirs) != 1 {
return errors.New("expected a single migrations directory")
}
if !dirs[0].IsDir() {
return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name())
}
dirName := dirs[0].Name()
files, err := migrationFS.ReadDir(dirName)
if err != nil {
return err
}
// Sort sql files by name.
sort.Slice(files, func(i, j int) bool {
return files[i].Name() < files[j].Name()
})
for _, dirEnt := range files {
if !dirEnt.Type().IsRegular() {
return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name())
}
var (
name = dirEnt.Name()
exists bool
)
err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists)
if err != nil {
return err
}
if exists {
continue
}
migration, err := migrationFS.ReadFile(filepath.Join(dirName, name))
if err != nil {
return err
}
if _, err := tx.Exec(string(migration)); err != nil {
return fmt.Errorf("migration %s failed: %v", name, err)
}
if _, err := tx.Exec(insertMigrationQuery, name); err != nil {
return err
}
}
return nil
})
}

50
pgutil/migrate_test.go Normal file
View File

@@ -0,0 +1,50 @@
package pgutil
import (
"database/sql"
"embed"
"testing"
)
//go:embed test-migrations
var testMigrationFS embed.FS
func TestMigrate(t *testing.T) {
db, err := sql.Open(
"postgres",
"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
if err != nil {
t.Fatal(err)
}
if err := DropAllTables(db); err != nil {
t.Fatal(err)
}
if err := Migrate(db, testMigrationFS); err != nil {
t.Fatal(err)
}
// Shouldn't have any effect.
if err := Migrate(db, testMigrationFS); err != nil {
t.Fatal(err)
}
query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)`
var exists bool
if err = db.QueryRow(query, 1).Scan(&exists); err != nil {
t.Fatal(err)
}
if exists {
t.Fatal("1 shouldn't exist")
}
if err = db.QueryRow(query, 2).Scan(&exists); err != nil {
t.Fatal(err)
}
if !exists {
t.Fatal("2 should exist")
}
}

View File

@@ -0,0 +1,9 @@
CREATE TABLE users(
UserID BIGINT NOT NULL PRIMARY KEY,
Email TEXT NOT NULL UNIQUE);
CREATE TABLE user_notes(
UserID BIGINT NOT NULL REFERENCES users(UserID),
NoteID BIGINT NOT NULL,
Note Text NOT NULL,
PRIMARY KEY(UserID,NoteID));

View File

@@ -0,0 +1 @@
INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com');

View File

@@ -0,0 +1 @@
DELETE FROM users WHERE UserID=1;

70
pgutil/tx.go Normal file
View File

@@ -0,0 +1,70 @@
package pgutil
import (
"database/sql"
"math/rand"
"time"
)
// Postgres doesn't use serializable transactions by default. This wrapper will
// run the enclosed function within a serializable. Note: this may return an
// retriable serialization error (see ErrIsSerializationFaiilure).
func WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
return WithTxDefault(db, func(tx *sql.Tx) error {
if _, err := tx.Exec("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE"); err != nil {
return err
}
return fn(tx)
})
}
// This is a convenience function to provide a transaction wrapper with the
// default isolation level.
func WithTxDefault(db *sql.DB, fn func(*sql.Tx) error) error {
// Start a transaction.
tx, err := db.Begin()
if err != nil {
return err
}
err = fn(tx)
if err == nil {
err = tx.Commit()
}
if err != nil {
_ = tx.Rollback()
}
return err
}
// SerialTxRunner attempts serializable transactions in a loop. If a
// transaction fails due to a serialization error, then the runner will retry
// with exponential backoff, until the sleep time reaches MaxTimeout.
//
// For example, if MinTimeout is 100 ms, and MaxTimeout is 800 ms, it may sleep
// for ~100, 200, 400, and 800 ms between retries.
//
// 10% jitter is added to the sleep time.
type SerialTxRunner struct {
MinTimeout time.Duration
MaxTimeout time.Duration
}
func (r SerialTxRunner) WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
timeout := r.MinTimeout
for {
err := WithTx(db, fn)
if err == nil {
return nil
}
if timeout > r.MaxTimeout || !ErrIsSerializationFaiilure(err) {
return err
}
sleepTimeout := timeout + time.Duration(rand.Int63n(int64(timeout/10)))
time.Sleep(sleepTimeout)
timeout *= 2
}
}

120
pgutil/tx_test.go Normal file
View File

@@ -0,0 +1,120 @@
package pgutil
import (
"database/sql"
"fmt"
"sync"
"testing"
"time"
)
// TestExecuteTx verifies transaction retry using the classic
// example of write skew in bank account balance transfers.
func TestWithTx(t *testing.T) {
db, err := sql.Open(
"postgres",
"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
if err != nil {
t.Fatal(err)
}
if err := DropAllTables(db); err != nil {
t.Fatal(err)
}
defer db.Close()
initStmt := `
CREATE TABLE t (acct INT PRIMARY KEY, balance INT);
INSERT INTO t (acct, balance) VALUES (1, 100), (2, 100);
`
if _, err := db.Exec(initStmt); err != nil {
t.Fatal(err)
}
type queryI interface {
Query(string, ...interface{}) (*sql.Rows, error)
}
getBalances := func(q queryI) (bal1, bal2 int, err error) {
var rows *sql.Rows
rows, err = q.Query(`SELECT balance FROM t WHERE acct IN (1, 2);`)
if err != nil {
return
}
defer rows.Close()
balances := []*int{&bal1, &bal2}
i := 0
for ; rows.Next(); i++ {
if err = rows.Scan(balances[i]); err != nil {
return
}
}
if i != 2 {
err = fmt.Errorf("expected two balances; got %d", i)
return
}
return
}
txRunner := SerialTxRunner{100 * time.Millisecond, 800 * time.Millisecond}
runTxn := func(wg *sync.WaitGroup, iter *int) <-chan error {
errCh := make(chan error, 1)
go func() {
*iter = 0
errCh <- txRunner.WithTx(db, func(tx *sql.Tx) error {
*iter++
bal1, bal2, err := getBalances(tx)
if err != nil {
return err
}
// If this is the first iteration, wait for the other tx to
// also read.
if *iter == 1 {
wg.Done()
wg.Wait()
}
// Now, subtract from one account and give to the other.
if bal1 > bal2 {
if _, err := tx.Exec(`
UPDATE t SET balance=balance-100 WHERE acct=1;
UPDATE t SET balance=balance+100 WHERE acct=2;
`); err != nil {
return err
}
} else {
if _, err := tx.Exec(`
UPDATE t SET balance=balance+100 WHERE acct=1;
UPDATE t SET balance=balance-100 WHERE acct=2;
`); err != nil {
return err
}
}
return nil
})
}()
return errCh
}
var wg sync.WaitGroup
wg.Add(2)
var iters1, iters2 int
txn1Err := runTxn(&wg, &iters1)
txn2Err := runTxn(&wg, &iters2)
if err := <-txn1Err; err != nil {
t.Errorf("expected success in txn1; got %s", err)
}
if err := <-txn2Err; err != nil {
t.Errorf("expected success in txn2; got %s", err)
}
if iters1+iters2 <= 2 {
t.Errorf("expected retries between the competing transactions; "+
"got txn1=%d, txn2=%d", iters1, iters2)
}
bal1, bal2, err := getBalances(db)
if err != nil || bal1 != 100 || bal2 != 100 {
t.Errorf("expected balances to be restored without error; "+
"got acct1=%d, acct2=%d: %s", bal1, bal2, err)
}
}