83 lines
1.7 KiB
Go
83 lines
1.7 KiB
Go
|
package pgutil
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"embed"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"path/filepath"
|
||
|
"sort"
|
||
|
)
|
||
|
|
||
|
// SQL statements used by Migrate to maintain the "migrations" bookkeeping
// table, which holds one row per applied migration file name.
const (
	// initMigrationTableQuery creates the bookkeeping table if it is missing.
	initMigrationTableQuery = `
CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);`

	// insertMigrationQuery records the migration file name bound to $1.
	insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)`

	// checkMigrationAppliedQuery yields a boolean telling whether the file
	// name bound to $1 is already recorded.
	checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)`
)
|
||
|
|
||
|
func Migrate(db *sql.DB, migrationFS embed.FS) error {
|
||
|
return WithTx(db, func(tx *sql.Tx) error {
|
||
|
if _, err := tx.Exec(initMigrationTableQuery); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
dirs, err := migrationFS.ReadDir(".")
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
if len(dirs) != 1 {
|
||
|
return errors.New("expected a single migrations directory")
|
||
|
}
|
||
|
|
||
|
if !dirs[0].IsDir() {
|
||
|
return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name())
|
||
|
}
|
||
|
|
||
|
dirName := dirs[0].Name()
|
||
|
files, err := migrationFS.ReadDir(dirName)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Sort sql files by name.
|
||
|
sort.Slice(files, func(i, j int) bool {
|
||
|
return files[i].Name() < files[j].Name()
|
||
|
})
|
||
|
|
||
|
for _, dirEnt := range files {
|
||
|
if !dirEnt.Type().IsRegular() {
|
||
|
return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name())
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
name = dirEnt.Name()
|
||
|
exists bool
|
||
|
)
|
||
|
|
||
|
err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if exists {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
migration, err := migrationFS.ReadFile(filepath.Join(dirName, name))
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if _, err := tx.Exec(string(migration)); err != nil {
|
||
|
return fmt.Errorf("migration %s failed: %v", name, err)
|
||
|
}
|
||
|
|
||
|
if _, err := tx.Exec(insertMigrationQuery, name); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
})
|
||
|
}
|