package sqliteutil

import (
	"database/sql"
	"embed"
	"errors"
	"fmt"
	"path/filepath"
	"sort"
)

// initMigrationTableQuery creates the bookkeeping table that records, by
// filename, which migrations have already been applied.
const initMigrationTableQuery = `
CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);`

// insertMigrationQuery records a migration filename as applied.
const insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)`

// checkMigrationAppliedQuery reports whether a migration filename has
// already been recorded in the migrations table.
const checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)`

// Migrate applies every not-yet-recorded migration file from migrationFS
// to db.
//
// migrationFS must contain exactly one top-level entry, and that entry must
// be a directory. Every entry inside that directory must be a regular file;
// the files are processed in ascending filename order. A file whose name is
// already present in the migrations table is skipped. Otherwise its contents
// are executed with a single tx.Exec call and its name is then inserted into
// the migrations table.
//
// All statements run inside the callback handed to WithTx.
// NOTE(review): WithTx is defined elsewhere in this package; presumably it
// commits when the callback returns nil and rolls back otherwise — confirm.
//
// Migrate returns the first error encountered, wrapped with a short
// description of the step that failed.
func Migrate(db *sql.DB, migrationFS embed.FS) error {
	return WithTx(db, func(tx *sql.Tx) error {
		if _, err := tx.Exec(initMigrationTableQuery); err != nil {
			return fmt.Errorf("creating migrations table: %w", err)
		}

		dirs, err := migrationFS.ReadDir(".")
		if err != nil {
			return fmt.Errorf("reading migration FS root: %w", err)
		}
		if len(dirs) != 1 {
			return errors.New("expected a single migrations directory")
		}
		if !dirs[0].IsDir() {
			return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name())
		}
		dirName := dirs[0].Name()

		files, err := migrationFS.ReadDir(dirName)
		if err != nil {
			return fmt.Errorf("reading migration directory %s: %w", dirName, err)
		}

		// Sort sql files by name so migrations are applied in a
		// deterministic order.
		sort.Slice(files, func(i, j int) bool {
			return files[i].Name() < files[j].Name()
		})

		for _, dirEnt := range files {
			if !dirEnt.Type().IsRegular() {
				return fmt.Errorf("unexpected non-regular file in migration fs: %s", dirEnt.Name())
			}
			name := dirEnt.Name()

			var exists bool
			if err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists); err != nil {
				return fmt.Errorf("checking whether migration %s is applied: %w", name, err)
			}
			if exists {
				continue
			}

			// embed.FS paths are always slash-separated, regardless of
			// the host OS. filepath.Join uses the OS separator, so convert
			// the result back to slashes or the lookup fails on Windows.
			migrationPath := filepath.ToSlash(filepath.Join(dirName, name))
			migration, err := migrationFS.ReadFile(migrationPath)
			if err != nil {
				return fmt.Errorf("reading migration %s: %w", name, err)
			}

			if _, err := tx.Exec(string(migration)); err != nil {
				// %w keeps the driver error reachable via errors.Is/As.
				return fmt.Errorf("migration %s failed: %w", name, err)
			}
			if _, err := tx.Exec(insertMigrationQuery, name); err != nil {
				return fmt.Errorf("recording migration %s: %w", name, err)
			}
		}
		return nil
	})
}