package sqliteutil

import (
	"database/sql"
	"embed"
	"errors"
	"fmt"
	"path/filepath"
	"sort"
)

// SQL statements used to keep track of which migration files have already
// been applied. Applied migrations are recorded by file name in the
// "migrations" table.
const (
	// initMigrationTableQuery creates the bookkeeping table if it is missing.
	initMigrationTableQuery = `
CREATE TABLE IF NOT EXISTS migrations(filename TEXT NOT NULL PRIMARY KEY);`

	// insertMigrationQuery records a single migration file name as applied.
	insertMigrationQuery = `INSERT INTO migrations(filename) VALUES($1)`

	// checkMigrationAppliedQuery returns one row with one column that is
	// truthy when the given file name is already recorded.
	checkMigrationAppliedQuery = `SELECT EXISTS(SELECT 1 FROM migrations WHERE filename=$1)`
)

// Migrate applies every not-yet-applied SQL migration file from migrationFS
// inside a single transaction obtained via WithTx.
//
// migrationFS must contain exactly one top-level entry, and it must be a
// directory. Every entry of that directory must be a regular file; its whole
// contents are executed as SQL. Files are processed in ascending name order.
// A file whose name is already present in the migrations table is skipped.
// Otherwise its name is inserted after it runs successfully, so later calls
// do not re-apply it.
//
// The first failure aborts the run and the error is returned. Whether
// earlier statements are rolled back depends on WithTx (presumably it rolls
// back on a non-nil error — confirm there). Driver errors are wrapped with
// %w so callers can inspect them with errors.Is / errors.As.
func Migrate(db *sql.DB, migrationFS embed.FS) error {
	return WithTx(db, func(tx *sql.Tx) error {
		if _, err := tx.Exec(initMigrationTableQuery); err != nil {
			return fmt.Errorf("creating migrations table: %w", err)
		}

		dirs, err := migrationFS.ReadDir(".")
		if err != nil {
			return fmt.Errorf("reading migration FS root: %w", err)
		}
		if len(dirs) != 1 {
			return errors.New("expected a single migrations directory")
		}
		if !dirs[0].IsDir() {
			return fmt.Errorf("unexpected non-directory in migration FS: %s", dirs[0].Name())
		}

		dirName := dirs[0].Name()
		files, err := migrationFS.ReadDir(dirName)
		if err != nil {
			return fmt.Errorf("reading migrations directory %s: %w", dirName, err)
		}

		// Apply migrations in lexical file-name order. Sort explicitly rather
		// than relying on the FS implementation's ordering.
		sort.Slice(files, func(i, j int) bool {
			return files[i].Name() < files[j].Name()
		})

		for _, dirEnt := range files {
			name := dirEnt.Name()
			if !dirEnt.Type().IsRegular() {
				return fmt.Errorf("unexpected non-regular file in migration fs: %s", name)
			}

			var exists bool
			if err := tx.QueryRow(checkMigrationAppliedQuery, name).Scan(&exists); err != nil {
				return fmt.Errorf("checking whether migration %s was applied: %w", name, err)
			}
			if exists {
				continue
			}

			// embed.FS paths are always slash-separated, even on Windows, where
			// filepath.Join would produce a backslash path that ReadFile cannot
			// find. Normalize the joined path to forward slashes.
			migration, err := migrationFS.ReadFile(filepath.ToSlash(filepath.Join(dirName, name)))
			if err != nil {
				return fmt.Errorf("reading migration %s: %w", name, err)
			}

			if _, err := tx.Exec(string(migration)); err != nil {
				return fmt.Errorf("migration %s failed: %w", name, err)
			}

			if _, err := tx.Exec(insertMigrationQuery, name); err != nil {
				return fmt.Errorf("recording migration %s: %w", name, err)
			}
		}
		return nil
	})
}