package pgutil

import (
	"database/sql"
	"math/rand"
	"time"
)

// WithTx runs fn inside a transaction at the SERIALIZABLE isolation level.
// Postgres does not use serializable transactions by default, so this wrapper
// sets the level explicitly before calling fn; commit and rollback are handled
// by WithTxDefault.
//
// Note: this may return a retriable serialization error (see
// ErrIsSerializationFailure); SerialTxRunner retries such errors for you.
// The error is returned unwrapped so that such classification keeps working.
func WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
	return WithTxDefault(db, func(tx *sql.Tx) error {
		// Postgres only honours SET TRANSACTION before the first query of a
		// transaction, so it must be issued before fn runs any statements.
		if _, err := tx.Exec("SET TRANSACTION ISOLATION LEVEL SERIALIZABLE"); err != nil {
			return err
		}
		return fn(tx)
	})
}

// WithTxDefault runs fn inside a transaction that uses the connection's default
// isolation level.
//
// If fn returns nil, the transaction is committed and the Commit error (if
// any) is returned. If fn returns an error, the transaction is rolled back and
// fn's error is returned unchanged. If fn panics, the transaction is rolled
// back and the panic is propagated to the caller.
func WithTxDefault(db *sql.DB, fn func(*sql.Tx) error) error {
	tx, err := db.Begin()
	if err != nil {
		return err
	}
	// Deferring Rollback releases the transaction on every exit path,
	// including a panic inside fn. Once Commit has been called (whether it
	// succeeded or failed) Rollback merely returns sql.ErrTxDone, so its error
	// is deliberately ignored: this is best-effort cleanup.
	defer func() { _ = tx.Rollback() }()

	if err = fn(tx); err != nil {
		return err
	}
	return tx.Commit()
}

// SerialTxRunner attempts serializable transactions in a loop. If a
// transaction fails due to a serialization error (see
// ErrIsSerializationFailure), the runner sleeps and retries with exponential
// backoff, doubling the sleep each time. It stops when the next sleep would
// exceed MaxTimeout, and then returns the last error.
//
// For example, if MinTimeout is 100 ms and MaxTimeout is 800 ms, it sleeps for
// roughly 100, 200, 400, and 800 ms between attempts (five attempts in total).
//
// Up to 10% random jitter is added to each sleep.
//
// A non-positive MinTimeout (including the zero value of SerialTxRunner)
// disables retrying: the first error is returned as-is.
type SerialTxRunner struct {
	MinTimeout time.Duration // first backoff sleep; must be > 0 to enable retries
	MaxTimeout time.Duration // longest backoff sleep that will still be taken
}

// WithTx runs fn in a serializable transaction via the package-level WithTx,
// retrying serialization failures with backoff as described on
// SerialTxRunner. Any other error is returned immediately without retrying.
func (r SerialTxRunner) WithTx(db *sql.DB, fn func(*sql.Tx) error) error {
	timeout := r.MinTimeout
	for {
		err := WithTx(db, fn)
		if err == nil {
			return nil
		}
		// timeout <= 0 covers a non-positive MinTimeout (doubling could never
		// grow it past MaxTimeout, so the loop would never end) as well as
		// overflow of timeout after repeated doubling.
		if timeout <= 0 || timeout > r.MaxTimeout || !ErrIsSerializationFailure(err) {
			return err
		}
		sleep := timeout
		// rand.Int63n panics on a non-positive bound, so jitter is only added
		// when a tenth of the timeout is at least one nanosecond.
		if tenth := int64(timeout / 10); tenth > 0 {
			sleep += time.Duration(rand.Int63n(tenth))
		}
		time.Sleep(sleep)
		timeout *= 2
	}
}