// Package migrate is a dead simple Go package for performing SQL
// migrations using database/sql.
package migrate

import (
	"context"
	"database/sql"
	"fmt"
	"hash/crc32"
	"sort"
	"sync"
)

// MigrationDirection selects whether migrations are applied (Up) or
// reverted (Down). Code in this package treats any value other than Up as
// Down.
type MigrationDirection int

// Directions accepted by Exec.
const (
	// Up applies migrations.
	Up MigrationDirection = 0

	// Down reverts migrations.
	Down MigrationDirection = 1
)

// TransactionMode controls how Migrator.Exec groups migrations into
// database transactions.
type TransactionMode int

const (
	// IndividualTransactions runs each migration in its own transaction. If
	// a migration fails, only that migration is rolled back; migrations that
	// already succeeded stay committed.
	IndividualTransactions TransactionMode = 0

	// SingleTransaction runs every migration inside one shared transaction.
	// If any migration fails, all of them are rolled back.
	SingleTransaction TransactionMode = 1
)

// MigrationError is an error that gets returned when an individual migration
// fails. It embeds the failing Migration, so fields such as ID are directly
// accessible on the error.
type MigrationError struct {
	Migration

	// Err is the underlying error.
	Err error
}

// Error implements the error interface.
func (e *MigrationError) Error() string {
	return fmt.Sprintf("migration %d failed: %v", e.ID, e.Err)
}

// Unwrap returns the underlying error so that errors.Is and errors.As can
// look through a *MigrationError to the cause.
func (e *MigrationError) Unwrap() error {
	return e.Err
}

// DefaultTable is the name of the table that records which migrations have
// been run when Migrator.Table is left empty.
const (
	DefaultTable = "schema_migrations"
)

// Migration is a single versioned schema change that can be applied (Up) or
// reverted (Down).
type Migration struct {
	// ID is a unique, numeric identifier for this migration. It is the
	// version recorded in the migrations table and determines run order.
	ID int

	// Up is called to apply the migration and Down is called to revert it.
	// Each receives the transaction the migration runs in; a non-nil error
	// marks the migration as failed.
	Up, Down func(*sql.Tx) error
}

// byID orders migrations by ascending ID. It implements sort.Interface.
type byID []Migration

// Len reports how many migrations are in the slice.
func (s byID) Len() int {
	return len(s)
}

// Less reports whether the migration at index i has a smaller ID than the
// one at index j.
func (s byID) Less(i, j int) bool {
	return s[i].ID < s[j].ID
}

// Swap exchanges the migrations at indexes i and j.
func (s byID) Swap(i, j int) {
	s[j], s[i] = s[i], s[j]
}

// Migrator performs migrations. Construct one with NewMigrator or
// NewPostgresMigrator. The zero value has no database handle and no Locker,
// so calling Exec on it would panic.
type Migrator struct {
	// Table is the table to store what migrations have been run. The zero
	// value means DefaultTable is used.
	Table string

	// Locker is held for the whole of Exec so that only one migration run
	// proceeds at a time. NewMigrator installs an in-process sync.Mutex,
	// which only excludes other goroutines in this process.
	// NewPostgresMigrator installs a Postgres advisory lock, which is held
	// in the database.
	sync.Locker

	// The TransactionMode to use. The zero value is IndividualTransactions,
	// which runs each migration in its own transaction.
	TransactionMode TransactionMode

	// db is the database the migrations and the tracking table live in.
	db *sql.DB
}

// postgresLocker implements the sync.Locker interface using Postgres
// session-level advisory locks (pg_advisory_lock / pg_advisory_unlock).
//
// Advisory locks belong to the database session that acquired them, and a
// *sql.DB is a pool of sessions. Lock therefore pins one pooled connection
// and keeps it until Unlock, so the unlock runs in the same session that took
// the lock. (Running both through db.Exec could unlock on a different
// session: pg_advisory_unlock then just returns false, and the real lock
// stays held on, or silently disappears with, some pooled connection.)
//
// While the lock is held, its connection is checked out of the pool. Any work
// done on the same *sql.DB in that window needs another connection, so the
// pool must allow at least two open connections.
type postgresLocker struct {
	key uint32
	db  *sql.DB

	// mu serializes Lock/Unlock pairs within this process and guards conn.
	mu   sync.Mutex
	conn *sql.Conn // session holding the advisory lock; nil while unlocked
}

// newPostgresLocker returns a new sync.Locker that obtains locks with
// pg_advisory_lock. The lock key is the CRC-32 (IEEE) checksum of
// "migrations", so every locker created here contends for the same lock.
func newPostgresLocker(db *sql.DB) sync.Locker {
	return &postgresLocker{
		key: crc32.ChecksumIEEE([]byte("migrations")),
		db:  db,
	}
}

// Lock obtains the advisory lock, blocking until it is available. It panics
// if no connection can be obtained or the lock statement fails, because
// sync.Locker offers no way to return an error.
func (l *postgresLocker) Lock() {
	l.mu.Lock()

	ctx := context.Background()
	conn, err := l.db.Conn(ctx)
	if err != nil {
		l.mu.Unlock()
		panic(fmt.Sprintf("migrate: %v", err))
	}
	if err := l.exec(ctx, conn, "lock"); err != nil {
		conn.Close()
		l.mu.Unlock()
		panic(fmt.Sprintf("migrate: %v", err))
	}
	l.conn = conn
}

// Unlock removes the advisory lock on the session that acquired it and
// returns that connection to the pool. It panics if the lock is not held or
// the unlock statement fails.
func (l *postgresLocker) Unlock() {
	conn := l.conn
	if conn == nil {
		panic("migrate: unlock of unlocked postgresLocker")
	}
	l.conn = nil
	// Deferred calls run LIFO: the connection goes back to the pool before
	// the local mutex is released.
	defer l.mu.Unlock()
	defer conn.Close()

	if err := l.exec(context.Background(), conn, "unlock"); err != nil {
		panic(fmt.Sprintf("migrate: %v", err))
	}
}

// exec runs "SELECT pg_advisory_<fn>(key)" on conn, where fn is "lock" or
// "unlock".
func (l *postgresLocker) exec(ctx context.Context, conn *sql.Conn, fn string) error {
	_, err := conn.ExecContext(ctx, fmt.Sprintf("SELECT pg_advisory_%s(%d)", fn, l.key))
	return err
}

// NewMigrator returns a Migrator that runs migrations against db. The
// returned Migrator tracks applied migrations in DefaultTable and uses
// IndividualTransactions. It serializes Exec calls within this process with
// a sync.Mutex.
func NewMigrator(db *sql.DB) *Migrator {
	m := &Migrator{db: db}
	m.Locker = &sync.Mutex{}
	return m
}

// NewPostgresMigrator returns a Migrator configured like NewMigrator, except
// that it serializes Exec with a Postgres advisory lock taken through db
// instead of an in-process mutex. This ensures only one migration run
// happens at a time against that database.
func NewPostgresMigrator(db *sql.DB) *Migrator {
	migrator := NewMigrator(db)
	migrator.Locker = newPostgresLocker(db)
	return migrator
}

// Exec runs the migrations in direction dir while holding m's Locker.
//
// It first creates the tracking table if it does not exist. It then runs the
// migrations in ID order (ascending for Up, descending otherwise).
//
// With IndividualTransactions, every migration gets its own transaction,
// committed as soon as it succeeds. With SingleTransaction, all migrations
// share one transaction, committed only after every migration succeeded.
// On the first failure the active transaction is rolled back and the error
// is returned.
func (m *Migrator) Exec(dir MigrationDirection, migrations ...Migration) error {
	m.Lock()
	defer m.Unlock()

	ddl := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version integer primary key not null)", m.table())
	if _, err := m.db.Exec(ddl); err != nil {
		return err
	}

	perMigration := m.TransactionMode == IndividualTransactions
	shared := m.TransactionMode == SingleTransaction

	var tx *sql.Tx
	if shared {
		begun, err := m.db.Begin()
		if err != nil {
			return err
		}
		tx = begun
	}

	ordered := sortMigrations(dir, migrations)
	for i := range ordered {
		if perMigration {
			begun, err := m.db.Begin()
			if err != nil {
				return err
			}
			tx = begun
		}

		if err := m.runMigration(tx, dir, ordered[i]); err != nil {
			tx.Rollback()
			return err
		}

		if !perMigration {
			continue
		}
		if err := tx.Commit(); err != nil {
			return err
		}
	}

	if !shared {
		return nil
	}
	return tx.Commit()
}

// runMigration runs the given Migration in direction dir using tx.
//
// It does nothing when shouldMigrate reports there is no work: the ID is
// already recorded for Up, or not recorded for Down. Otherwise it calls the
// migration's Up or Down function, then inserts (Up) or deletes (Down) the
// migration's ID in the tracking table.
//
// Two cases are reported as a *MigrationError: a failure of the migration
// function itself, and a nil function for the requested direction.
//
// This function does not commit or roll back tx. That is the caller's
// responsibility, depending on whether an error is returned.
func (m *Migrator) runMigration(tx *sql.Tx, dir MigrationDirection, migration Migration) error {
	shouldMigrate, err := m.shouldMigrate(tx, migration.ID, dir)
	if err != nil {
		return err
	}

	if !shouldMigrate {
		return nil
	}

	// The version is formatted into the SQL instead of bound as a parameter
	// because placeholder syntax differs between drivers (sqlite3 vs
	// postgres). migration.ID is an int, so it cannot inject SQL. The table
	// name, however, is interpolated verbatim and must come from trusted
	// configuration.
	var (
		migrate func(tx *sql.Tx) error
		fnName  string
		query   string
	)
	switch dir {
	case Up:
		migrate, fnName = migration.Up, "Up"
		query = fmt.Sprintf("INSERT INTO %s (version) VALUES (%d)", m.table(), migration.ID)
	default:
		migrate, fnName = migration.Down, "Down"
		query = fmt.Sprintf("DELETE FROM %s WHERE version = %d", m.table(), migration.ID)
	}

	if migrate == nil {
		// Calling a nil func would panic mid-transaction. Report it as a
		// normal migration failure so the caller rolls back cleanly.
		return &MigrationError{Migration: migration, Err: fmt.Errorf("no %s function defined", fnName)}
	}

	if err := migrate(tx); err != nil {
		return &MigrationError{Migration: migration, Err: err}
	}

	_, err = tx.Exec(query)
	return err
}

// shouldMigrate reports whether the migration with the given id needs to run
// in direction dir. It checks, through tx, whether id is recorded in the
// migrations table:
//   - an Up migration is needed only when id is not yet recorded;
//   - a Down migration is needed only when it is.
// Any query error other than sql.ErrNoRows is returned.
func (m *Migrator) shouldMigrate(tx *sql.Tx, id int, dir MigrationDirection) (bool, error) {
	var version int
	row := tx.QueryRow(fmt.Sprintf("SELECT version FROM %s WHERE version = %d", m.table(), id))
	err := row.Scan(&version)

	switch {
	case err == sql.ErrNoRows:
		// Not recorded yet: only going Up has work to do.
		return dir == Up, nil
	case err != nil:
		return false, err
	default:
		// Already recorded: only going Down has work to do.
		return dir != Up, nil
	}
}

// table returns the name of the table that records applied migrations:
// m.Table when it is set, otherwise DefaultTable.
func (m *Migrator) table() string {
	if name := m.Table; name != "" {
		return name
	}
	return DefaultTable
}

// Exec is a convenience function that runs migrations in direction dir using
// a Migrator built by NewMigrator. That Migrator uses DefaultTable,
// IndividualTransactions and an in-process mutex.
func Exec(db *sql.DB, dir MigrationDirection, migrations ...Migration) error {
	migrator := NewMigrator(db)
	return migrator.Exec(dir, migrations...)
}

// Queries builds a migration function, suitable for Migration.Up or
// Migration.Down, that executes each query in order with its own tx.Exec
// call. It stops at and returns the first error.
func Queries(queries []string) func(*sql.Tx) error {
	run := func(tx *sql.Tx) error {
		for i := 0; i < len(queries); i++ {
			_, err := tx.Exec(queries[i])
			if err != nil {
				return err
			}
		}
		return nil
	}
	return run
}

// sortMigrations returns a copy of migrations sorted by ID. The caller's
// slice is not modified.
//
// When the direction is Up, the migrations are sorted by ID ascending; for
// any other direction (Down) they are sorted by ID descending.
//
// The sort is stable, so migrations that share an ID keep their input order.
// sort.Sort gives no such guarantee, which would make the order of duplicates
// depend on sort internals.
func sortMigrations(dir MigrationDirection, migrations []Migration) []Migration {
	sorted := make(byID, len(migrations))
	copy(sorted, migrations)

	switch dir {
	case Up:
		sort.Stable(sorted)
	default:
		sort.Stable(sort.Reverse(sorted))
	}

	return sorted
}