wip
This commit is contained in:
45
pgutil/README.md
Normal file
45
pgutil/README.md
Normal file
@@ -0,0 +1,45 @@
|
||||
# pgutil
|
||||
|
||||
## Transactions
|
||||
|
||||
Simplify postgres transactions using `WithTx` for serializable transactions,
|
||||
or `WithTxDefault` for the default isolation level. Use the `SerialTxRunner`
|
||||
type to get automatic retries of serialization errors.
|
||||
|
||||
## Migrations
|
||||
|
||||
Put your migrations into a directory, for example `migrations`, ordered by name
|
||||
(YYYY-MM-DD prefix, for example). Embed the directory and pass it to the
|
||||
`Migrate` function:
|
||||
|
||||
```Go
|
||||
//go:embed migrations
|
||||
var migrations embed.FS
|
||||
|
||||
func init() {
|
||||
Migrate(db, migrations) // Check the error, of course.
|
||||
}
|
||||
```
|
||||
|
||||
## Testing
|
||||
|
||||
In order to test this packge, we need to create a test user and database:
|
||||
|
||||
```
|
||||
sudo su postgres
|
||||
psql
|
||||
|
||||
CREATE DATABASE test;
|
||||
CREATE USER test WITH ENCRYPTED PASSWORD 'test';
|
||||
GRANT ALL PRIVILEGES ON DATABASE test TO test;
|
||||
|
||||
use test
|
||||
|
||||
GRANT ALL ON SCHEMA public TO test;
|
||||
```
|
||||
|
||||
Check that you can connect via the command line:
|
||||
|
||||
```
|
||||
psql -h 127.0.0.1 -U test --password test
|
||||
```
|
42
pgutil/dropall.go
Normal file
42
pgutil/dropall.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"log"
|
||||
)
|
||||
|
||||
const dropTablesQueryQuery = `
|
||||
SELECT 'DROP TABLE IF EXISTS "' || tablename || '" CASCADE;'
|
||||
FROM
|
||||
pg_tables
|
||||
WHERE
|
||||
schemaname='public'`
|
||||
|
||||
// Deletes all tables in the database. Useful for testing.
|
||||
func DropAllTables(db *sql.DB) error {
|
||||
rows, err := db.Query(dropTablesQueryQuery)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
queries := []string{}
|
||||
for rows.Next() {
|
||||
var s string
|
||||
if err := rows.Scan(&s); err != nil {
|
||||
return err
|
||||
}
|
||||
queries = append(queries, s)
|
||||
}
|
||||
|
||||
if len(queries) > 0 {
|
||||
log.Printf("DROPPING ALL (%d) TABLES", len(queries))
|
||||
}
|
||||
|
||||
for _, query := range queries {
|
||||
if _, err := db.Exec(query); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
31
pgutil/errors.go
Normal file
31
pgutil/errors.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
func ErrIsDuplicateKey(err error) bool {
|
||||
return ErrHasCode(err, "23505")
|
||||
}
|
||||
|
||||
func ErrIsForeignKey(err error) bool {
|
||||
return ErrHasCode(err, "23503")
|
||||
}
|
||||
|
||||
func ErrIsSerializationFaiilure(err error) bool {
|
||||
return ErrHasCode(err, "40001")
|
||||
}
|
||||
|
||||
func ErrHasCode(err error, code string) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
var pErr *pq.Error
|
||||
if errors.As(err, &pErr) {
|
||||
return pErr.Code == pq.ErrorCode(code)
|
||||
}
|
||||
return false
|
||||
}
|
36
pgutil/errors_test.go
Normal file
36
pgutil/errors_test.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrors(t *testing.T) {
|
||||
db, err := sql.Open(
|
||||
"postgres",
|
||||
"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := DropAllTables(db); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := Migrate(db, testMigrationFS); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (2, 'q@r.com')`)
|
||||
if !ErrIsDuplicateKey(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = db.Exec(`INSERT INTO users(UserID, Email) VALUES (3, 'c@d.com')`)
|
||||
if !ErrIsDuplicateKey(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = db.Exec(`INSERT INTO user_notes(UserID, NoteID, Note) VALUES (4, 1, 'hello')`)
|
||||
if !ErrIsForeignKey(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
5
pgutil/go.mod
Normal file
5
pgutil/go.mod
Normal file
@@ -0,0 +1,5 @@
|
||||
module git.crumpington.com/git/pgutil
|
||||
|
||||
go 1.23.2
|
||||
|
||||
require github.com/lib/pq v1.10.9
|
2
pgutil/go.sum
Normal file
2
pgutil/go.sum
Normal file
@@ -0,0 +1,2 @@
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
82
pgutil/migrate.go
Normal file
82
pgutil/migrate.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
)
|
||||
|
||||
const initMigrationTableQuery = `
|
||||
CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);`
|
||||
|
||||
const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)`
|
||||
|
||||
const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)`
|
||||
|
||||
func Migrate(db *sql.DB, migrationFS embed.FS) error {
|
||||
return WithTx(db, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec(initMigrationTableQuery); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dirs, err := migrationFS.ReadDir(".")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(dirs) != 1 {
|
||||
return errors.New("expected a single migrations directory")
|
||||
}
|
||||
|
||||
if !dirs[0].IsDir() {
|
||||
return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name())
|
||||
}
|
||||
|
||||
dirName := dirs[0].Name()
|
||||
files, err := migrationFS.ReadDir(dirName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Sort sql files by name.
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].Name() < files[j].Name()
|
||||
})
|
||||
|
||||
for _, dirEnt := range files {
|
||||
if !dirEnt.Type().IsRegular() {
|
||||
return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name())
|
||||
}
|
||||
|
||||
var (
|
||||
name = dirEnt.Name()
|
||||
exists bool
|
||||
)
|
||||
|
||||
err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
|
||||
migration, err := migrationFS.ReadFile(filepath.Join(dirName, name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(string(migration)); err != nil {
|
||||
return fmt.Errorf("migration %s failed: %v", name, err)
|
||||
}
|
||||
|
||||
if _, err := tx.Exec(insertMigrationQuery, name); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
50
pgutil/migrate_test.go
Normal file
50
pgutil/migrate_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"testing"
|
||||
)
|
||||
|
||||
//go:embed test-migrations
|
||||
var testMigrationFS embed.FS
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
db, err := sql.Open(
|
||||
"postgres",
|
||||
"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := DropAllTables(db); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := Migrate(db, testMigrationFS); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Shouldn't have any effect.
|
||||
if err := Migrate(db, testMigrationFS); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)`
|
||||
var exists bool
|
||||
|
||||
if err = db.QueryRow(query, 1).Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if exists {
|
||||
t.Fatal("1 shouldn't exist")
|
||||
}
|
||||
|
||||
if err = db.QueryRow(query, 2).Scan(&exists); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !exists {
|
||||
t.Fatal("2 should exist")
|
||||
}
|
||||
|
||||
}
|
9
pgutil/test-migrations/000.sql
Normal file
9
pgutil/test-migrations/000.sql
Normal file
@@ -0,0 +1,9 @@
|
||||
CREATE TABLE users(
|
||||
UserID BIGINT NOT NULL PRIMARY KEY,
|
||||
Email TEXT NOT NULL UNIQUE);
|
||||
|
||||
CREATE TABLE user_notes(
|
||||
UserID BIGINT NOT NULL REFERENCES users(UserID),
|
||||
NoteID BIGINT NOT NULL,
|
||||
Note Text NOT NULL,
|
||||
PRIMARY KEY(UserID,NoteID));
|
1
pgutil/test-migrations/001.sql
Normal file
1
pgutil/test-migrations/001.sql
Normal file
@@ -0,0 +1 @@
|
||||
INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com');
|
1
pgutil/test-migrations/002.sql
Normal file
1
pgutil/test-migrations/002.sql
Normal file
@@ -0,0 +1 @@
|
||||
DELETE FROM users WHERE UserID=1;
|
70
pgutil/tx.go
Normal file
70
pgutil/tx.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Postgres doesn't use serializable transactions by default. This wrapper will
|
||||
// run the enclosed function within a serializable. Note: this may return an
|
||||
// retriable serialization error (see ErrIsSerializationFaiilure).
|
||||
func WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
|
||||
return WithTxDefault(db, func(tx *sql.Tx) error {
|
||||
if _, err := tx.Exec("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE"); err != nil {
|
||||
return err
|
||||
}
|
||||
return fn(tx)
|
||||
})
|
||||
}
|
||||
|
||||
// This is a convenience function to provide a transaction wrapper with the
|
||||
// default isolation level.
|
||||
func WithTxDefault(db *sql.DB, fn func(*sql.Tx) error) error {
|
||||
// Start a transaction.
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = fn(tx)
|
||||
|
||||
if err == nil {
|
||||
err = tx.Commit()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// SerialTxRunner attempts serializable transactions in a loop. If a
|
||||
// transaction fails due to a serialization error, then the runner will retry
|
||||
// with exponential backoff, until the sleep time reaches MaxTimeout.
|
||||
//
|
||||
// For example, if MinTimeout is 100 ms, and MaxTimeout is 800 ms, it may sleep
|
||||
// for ~100, 200, 400, and 800 ms between retries.
|
||||
//
|
||||
// 10% jitter is added to the sleep time.
|
||||
type SerialTxRunner struct {
|
||||
MinTimeout time.Duration
|
||||
MaxTimeout time.Duration
|
||||
}
|
||||
|
||||
func (r SerialTxRunner) WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
|
||||
timeout := r.MinTimeout
|
||||
for {
|
||||
err := WithTx(db, fn)
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if timeout > r.MaxTimeout || !ErrIsSerializationFaiilure(err) {
|
||||
return err
|
||||
}
|
||||
sleepTimeout := timeout + time.Duration(rand.Int63n(int64(timeout/10)))
|
||||
time.Sleep(sleepTimeout)
|
||||
timeout *= 2
|
||||
}
|
||||
}
|
120
pgutil/tx_test.go
Normal file
120
pgutil/tx_test.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package pgutil
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestExecuteTx verifies transaction retry using the classic
|
||||
// example of write skew in bank account balance transfers.
|
||||
func TestWithTx(t *testing.T) {
|
||||
db, err := sql.Open(
|
||||
"postgres",
|
||||
"host=127.0.0.1 dbname=test sslmode=disable user=test password=test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := DropAllTables(db); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
defer db.Close()
|
||||
|
||||
initStmt := `
|
||||
CREATE TABLE t (acct INT PRIMARY KEY, balance INT);
|
||||
INSERT INTO t (acct, balance) VALUES (1, 100), (2, 100);
|
||||
`
|
||||
if _, err := db.Exec(initStmt); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
type queryI interface {
|
||||
Query(string, ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
getBalances := func(q queryI) (bal1, bal2 int, err error) {
|
||||
var rows *sql.Rows
|
||||
rows, err = q.Query(`SELECT balance FROM t WHERE acct IN (1, 2);`)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
defer rows.Close()
|
||||
balances := []*int{&bal1, &bal2}
|
||||
i := 0
|
||||
for ; rows.Next(); i++ {
|
||||
if err = rows.Scan(balances[i]); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
if i != 2 {
|
||||
err = fmt.Errorf("expected two balances; got %d", i)
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
txRunner := SerialTxRunner{100 * time.Millisecond, 800 * time.Millisecond}
|
||||
|
||||
runTxn := func(wg *sync.WaitGroup, iter *int) <-chan error {
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
*iter = 0
|
||||
errCh <- txRunner.WithTx(db, func(tx *sql.Tx) error {
|
||||
*iter++
|
||||
bal1, bal2, err := getBalances(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// If this is the first iteration, wait for the other tx to
|
||||
// also read.
|
||||
if *iter == 1 {
|
||||
wg.Done()
|
||||
wg.Wait()
|
||||
}
|
||||
// Now, subtract from one account and give to the other.
|
||||
if bal1 > bal2 {
|
||||
if _, err := tx.Exec(`
|
||||
UPDATE t SET balance=balance-100 WHERE acct=1;
|
||||
UPDATE t SET balance=balance+100 WHERE acct=2;
|
||||
`); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := tx.Exec(`
|
||||
UPDATE t SET balance=balance+100 WHERE acct=1;
|
||||
UPDATE t SET balance=balance-100 WHERE acct=2;
|
||||
`); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}()
|
||||
return errCh
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
var iters1, iters2 int
|
||||
txn1Err := runTxn(&wg, &iters1)
|
||||
txn2Err := runTxn(&wg, &iters2)
|
||||
if err := <-txn1Err; err != nil {
|
||||
t.Errorf("expected success in txn1; got %s", err)
|
||||
}
|
||||
if err := <-txn2Err; err != nil {
|
||||
t.Errorf("expected success in txn2; got %s", err)
|
||||
}
|
||||
if iters1+iters2 <= 2 {
|
||||
t.Errorf("expected retries between the competing transactions; "+
|
||||
"got txn1=%d, txn2=%d", iters1, iters2)
|
||||
}
|
||||
bal1, bal2, err := getBalances(db)
|
||||
if err != nil || bal1 != 100 || bal2 != 100 {
|
||||
t.Errorf("expected balances to be restored without error; "+
|
||||
"got acct1=%d, acct2=%d: %s", bal1, bal2, err)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user