diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..9bdf060 --- /dev/null +++ b/errors.go @@ -0,0 +1,8 @@ +package sqliteutil + +import "github.com/mattn/go-sqlite3" + +func ErrIsConstraint(err error) bool { + e, ok := err.(sqlite3.Error) + return ok && e.Code == 19 +} diff --git a/migrate.go b/migrate.go new file mode 100644 index 0000000..ad0701f --- /dev/null +++ b/migrate.go @@ -0,0 +1,82 @@ +package sqliteutil + +import ( + "database/sql" + "embed" + "errors" + "fmt" + "path/filepath" + "sort" +) + +const initMigrationTableQuery = ` +CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);` + +const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)` + +const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)` + +func Migrate(db *sql.DB, migrationFS embed.FS) error { + return WithTx(db, func(tx *sql.Tx) error { + if _, err := tx.Exec(initMigrationTableQuery); err != nil { + return err + } + + dirs, err := migrationFS.ReadDir(".") + if err != nil { + return err + } + if len(dirs) != 1 { + return errors.New("expected a single migrations directory") + } + + if !dirs[0].IsDir() { + return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name()) + } + + dirName := dirs[0].Name() + files, err := migrationFS.ReadDir(dirName) + if err != nil { + return err + } + + // Sort sql files by name. + sort.Slice(files, func(i, j int) bool { + return files[i].Name() < files[j].Name() + }) + + for _, dirEnt := range files { + if !dirEnt.Type().IsRegular() { + return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name()) + } + + var ( + name = dirEnt.Name() + exists bool + ) + + err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists) + if err != nil { + return err + } + + if exists { + continue + } + + migration, err := migrationFS.ReadFile(filepath.Join(dirName, name)) + if err != nil { + return err + } + + if _, err := tx.Exec(string(migration)); err != nil { + return fmt.Errorf("migration %s failed: %v", name, err) + } + + if _, err := tx.Exec(insertMigrationQuery, name); err != nil { + return err + } + } + return nil + }) +} diff --git a/migrate_test.go b/migrate_test.go new file mode 100644 index 0000000..f78d4af --- /dev/null +++ b/migrate_test.go @@ -0,0 +1,44 @@ +package sqliteutil + +import ( + "database/sql" + "embed" + "testing" +) + +//go:embed test-migrations +var testMigrationFS embed.FS + +func TestMigrate(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + t.Fatal(err) + } + + if err := Migrate(db, testMigrationFS); err != nil { + t.Fatal(err) + } + + // Shouldn't have any effect. + if err := Migrate(db, testMigrationFS); err != nil { + t.Fatal(err) + } + + query := `SELECT EXISTS(SELECT 1 FROM users WHERE UserID=$1)` + var exists bool + + if err = db.QueryRow(query, 1).Scan(&exists); err != nil { + t.Fatal(err) + } + if exists { + t.Fatal("1 shouldn't exist") + } + + if err = db.QueryRow(query, 2).Scan(&exists); err != nil { + t.Fatal(err) + } + if !exists { + t.Fatal("2 should exist") + } + +} diff --git a/test-migrations/000.sql b/test-migrations/000.sql new file mode 100644 index 0000000..ecb559c --- /dev/null +++ b/test-migrations/000.sql @@ -0,0 +1,9 @@ +CREATE TABLE users( + UserID BIGINT NOT NULL PRIMARY KEY, + Email TEXT NOT NULL UNIQUE); + +CREATE TABLE user_notes( + UserID BIGINT NOT NULL REFERENCES users(UserID), + NoteID BIGINT NOT NULL, + Note Text NOT NULL, + PRIMARY KEY(UserID,NoteID)); diff --git a/test-migrations/001.sql b/test-migrations/001.sql new file mode 100644 index 0000000..e424c57 --- /dev/null +++ b/test-migrations/001.sql @@ -0,0 +1 @@ +INSERT INTO users(UserID, Email) VALUES (1, 'a@b.com'), (2, 'c@d.com'); diff --git a/test-migrations/002.sql b/test-migrations/002.sql new file mode 100644 index 0000000..ba414d2 --- /dev/null +++ b/test-migrations/002.sql @@ -0,0 +1 @@ +DELETE FROM users WHERE UserID=1; diff --git a/tx.go b/tx.go new file mode 100644 index 0000000..37f6f33 --- /dev/null +++ b/tx.go @@ -0,0 +1,28 @@ +package sqliteutil + +import ( + "database/sql" + + _ "github.com/mattn/go-sqlite3" +) + +// This is a convenience function to run a function within a transaction. +func WithTx(db *sql.DB, fn func(*sql.Tx) error) error { + // Start a transaction. + tx, err := db.Begin() + if err != nil { + return err + } + + err = fn(tx) + + if err == nil { + err = tx.Commit() + } + + if err != nil { + _ = tx.Rollback() + } + + return err +}