package migrate_test

import (
	"database/sql"
	"errors"
	"flag"
	"fmt"
	"os"
	"os/exec"
	"strings"
	"testing"

	_ "github.com/lib/pq"
	_ "github.com/mattn/go-sqlite3"
	"github.com/remind101/migrate"
	"github.com/stretchr/testify/assert"
)

// Names of the databases the suite knows how to run against. They are the
// keys of the databases factory map, and each value matches the driver name
// passed to sql.Open for that database.
const (
	Sqlite   = "sqlite3"
	Postgres = "postgres"
)

// A flag to determine what database to run the suite against. It defaults to
// SQLite; newDB fails the test if the value is not a key of databases.
var database = flag.String("test.database", Sqlite, "The name of the database to run against. (sqlite3, postgres).")

// testMigrations is the shared fixture used by most tests. Migration 1
// creates the people table; migration 2 adds a first_name column to it.
var testMigrations = []migrate.Migration{
	{
		ID: 1,
		// Create the table; Down drops it again.
		Up: func(tx *sql.Tx) error {
			_, err := tx.Exec("CREATE TABLE people (id int)")
			return err
		},
		Down: func(tx *sql.Tx) error {
			_, err := tx.Exec("DROP TABLE people")
			return err
		},
	},
	{
		ID: 2,
		// For simple SQL migrations, you can use the migrate.Queries
		// helper.
		Up: migrate.Queries([]string{
			"ALTER TABLE people ADD COLUMN first_name text",
		}),
		Down: func(tx *sql.Tx) error {
			// It's not possible to remove a column with
			// sqlite (at least in the versions this suite was written
			// against), so Down does not undo the ALTER. It only runs a
			// query that fails if the people table is missing.
			_, err := tx.Exec("SELECT 1 FROM people")
			return err
		},
	},
}

// TestMigrate runs every test migration up and checks the version table and
// schema. It then runs them all down and checks that the database is empty
// again.
func TestMigrate(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	// Up: both migrations are recorded and the table has both columns.
	assert.NoError(t, migrate.Exec(db, migrate.Up, testMigrations...))
	assert.Equal(t, []int{1, 2}, appliedMigrations(t, db))
	assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)

	// Down: no versions remain and no user tables are left behind.
	assert.NoError(t, migrate.Exec(db, migrate.Down, testMigrations...))
	assert.Equal(t, []int{}, appliedMigrations(t, db))
	assertSchema(t, ``, db)
}

// TestMigrate_Individual applies the test migrations one Exec call at a time.
// After each step it checks the version table and the schema.
func TestMigrate_Individual(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	steps := []struct {
		migration migrate.Migration
		applied   []int
		schema    string
	}{
		{testMigrations[0], []int{1}, `
people
CREATE TABLE people (id int)
`},
		{testMigrations[1], []int{1, 2}, `
people
CREATE TABLE people (id int, first_name text)
`},
	}

	for _, step := range steps {
		assert.NoError(t, migrate.Exec(db, migrate.Up, step.migration))
		assert.Equal(t, step.applied, appliedMigrations(t, db))
		assertSchema(t, step.schema, db)
	}
}

// TestMigrate_AlreadyRan checks that re-running an applied migration is a
// no-op. The second Exec must succeed and change neither the recorded
// versions nor the schema.
func TestMigrate_AlreadyRan(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	const wantSchema = `
people
CREATE TABLE people (id int)
`
	createPeople := testMigrations[0]

	// Pass 0 applies the migration; pass 1 must leave everything unchanged.
	for pass := 0; pass < 2; pass++ {
		assert.NoError(t, migrate.Exec(db, migrate.Up, createPeople))
		assert.Equal(t, []int{1}, appliedMigrations(t, db))
		assertSchema(t, wantSchema, db)
	}
}

// TestMigrate_SingleTransactionMode_Rollback runs every migration in one
// transaction. When the last migration fails, none of the earlier ones may
// remain applied.
func TestMigrate_SingleTransactionMode_Rollback(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	migrator := migrate.NewMigrator(db)
	migrator.TransactionMode = migrate.SingleTransaction

	// failing always errors, which should abort the shared transaction.
	failing := migrate.Migration{
		ID: 3,
		Up: func(tx *sql.Tx) error {
			return errors.New("Rollback")
		},
	}

	err := migrator.Exec(migrate.Up, testMigrations[0], testMigrations[1], failing)
	assert.Error(t, err)
	assert.Equal(t, []int{}, appliedMigrations(t, db))
	assertSchema(t, ``, db)
}

// TestMigrate_SingleTransactionMode_Commit checks that running all test
// migrations in a single transaction commits every one of them.
func TestMigrate_SingleTransactionMode_Commit(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	migrator := migrate.NewMigrator(db)
	migrator.TransactionMode = migrate.SingleTransaction

	assert.NoError(t, migrator.Exec(migrate.Up, testMigrations...))
	assert.Equal(t, []int{1, 2}, appliedMigrations(t, db))
	assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
}

// TestMigrate_Order passes the migrations newest-first. Exec must still
// apply them in ID order, since migration 2 alters the table that
// migration 1 creates.
func TestMigrate_Order(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	reversed := []migrate.Migration{testMigrations[1], testMigrations[0]}

	assert.NoError(t, migrate.Exec(db, migrate.Up, reversed...))
	assert.Equal(t, []int{1, 2}, appliedMigrations(t, db))
	assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
}

// TestMigrate_Rollback checks that a migration failing partway through is
// rolled back as a whole. It also checks that the error is reported as a
// *migrate.MigrationError.
func TestMigrate_Rollback(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	// The first statement succeeds. The second targets a table that does
	// not exist, so it fails and the whole migration must be undone.
	statements := []string{
		"CREATE TABLE people (id int)",
		"ALTER TABLE foo ADD COLUMN first_name text",
	}
	migration := migrate.Migration{
		ID: 1,
		Up: func(tx *sql.Tx) error {
			for _, stmt := range statements {
				if _, err := tx.Exec(stmt); err != nil {
					return err
				}
			}
			return nil
		},
	}

	err := migrate.Exec(db, migrate.Up, migration)
	assert.Error(t, err)
	assert.Equal(t, []int{}, appliedMigrations(t, db))
	// A leftover people table here would mean the transaction was not rolled back.
	assertSchema(t, ``, db)
	assert.IsType(t, &migrate.MigrationError{}, err)
}

// TestMigrate_Locking starts two concurrent Exec calls for the same pending
// migration. Both calls must succeed, and the migration's Up function must
// run exactly once. On Postgres the Postgres-specific migrator is used.
func TestMigrate_Locking(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	migrator := migrate.NewMigrator(db)
	if *database == Postgres {
		migrator = migrate.NewPostgresMigrator(db)
	}

	err := migrator.Exec(migrate.Up, testMigrations...)
	assert.NoError(t, err)
	assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
	assert.Equal(t, []int{1, 2}, appliedMigrations(t, db))

	// calls records each invocation of migration 3's Up. It is a channel
	// rather than a shared int so the count stays race-free even if locking
	// is broken and both goroutines run Up at the same time.
	calls := make(chan struct{}, 2)
	migration := migrate.Migration{
		ID: 3,
		Up: func(tx *sql.Tx) error {
			select {
			case calls <- struct{}{}:
			default:
				// The buffer is full, so enough calls were already counted to
				// fail the assertion below. Never block here.
			}
			_, err := tx.Exec(`INSERT INTO people (id, first_name) VALUES (1, 'Eric')`)
			return err
		},
	}

	// Buffered so each goroutine can hand off its result and exit without
	// waiting on the reader.
	m1 := make(chan error, 1)
	m2 := make(chan error, 1)

	// Start two migrations in parallel.
	go func() { m1 <- migrator.Exec(migrate.Up, migration) }()
	go func() { m2 <- migrator.Exec(migrate.Up, migration) }()

	assert.NoError(t, <-m1)
	assert.NoError(t, <-m2)
	// Both Exec calls have returned, so every send on calls has happened.
	assert.Equal(t, 1, len(calls))

	assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
	assert.Equal(t, []int{1, 2, 3}, appliedMigrations(t, db))
}

// assertSchema compares the user tables of db against expectedSchema, with
// surrounding whitespace in expectedSchema ignored. The expected format is
// the one sqliteSchema produces. Schema introspection exists only for
// SQLite, so for any other database the check is skipped.
func assertSchema(t testing.TB, expectedSchema string, db *sql.DB) {
	t.Helper()

	if *database != Sqlite {
		return
	}

	schema, err := sqliteSchema(db)
	if err != nil {
		t.Fatal(err)
	}
	assert.Equal(t, strings.TrimSpace(expectedSchema), schema)
}

// sqliteSchema renders every table in a SQLite database except the
// migrations bookkeeping table (migrate.DefaultTable). Tables appear in name
// order, each as "<name>\n<CREATE statement>", separated by a blank line. An
// empty string means there are no user tables.
func sqliteSchema(db *sql.DB) (string, error) {
	rows, err := db.Query(`SELECT name, sql FROM sqlite_master
WHERE type='table'
ORDER BY name;`)
	if err != nil {
		return "", err
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		// ddl rather than sql, to avoid shadowing the database/sql package.
		var name, ddl string
		if err := rows.Scan(&name, &ddl); err != nil {
			return "", err
		}
		if name == migrate.DefaultTable {
			continue
		}
		tables = append(tables, name+"\n"+ddl)
	}
	// rows.Next returns false on error as well as at the end of the results;
	// surface any iteration error instead of returning a truncated schema.
	if err := rows.Err(); err != nil {
		return "", err
	}
	return strings.Join(tables, "\n\n"), nil
}

// appliedMigrations returns the versions recorded in migrate.DefaultTable in
// ascending order. It returns a non-nil empty slice when none are recorded,
// so callers can compare against []int{}. Any query or scan error fails the
// test.
func appliedMigrations(t testing.TB, db *sql.DB) []int {
	t.Helper()

	// ORDER BY makes the result deterministic; callers compare against
	// ordered slices. NOTE(review): assumes version is a numeric column —
	// a text column would sort lexically.
	rows, err := db.Query("SELECT version FROM " + migrate.DefaultTable + " ORDER BY version")
	if err != nil {
		t.Fatal(err)
	}
	defer rows.Close()

	ids := []int{}
	for rows.Next() {
		var id int
		if err := rows.Scan(&id); err != nil {
			t.Fatal(err)
		}
		ids = append(ids, id)
	}
	// Surface iteration errors instead of returning a partial list.
	if err := rows.Err(); err != nil {
		t.Fatal(err)
	}

	return ids
}

// databases maps each supported database name to a factory. Each factory
// resets the test database to an empty state and opens a connection to it.
var databases = map[string]func() (*sql.DB, error){
	Postgres: func() (*sql.DB, error) {
		const name = "migrate_test"

		// command builds a command whose output goes straight to this
		// process's stdout/stderr.
		command := func(bin string, args ...string) *exec.Cmd {
			cmd := exec.Command(bin, args...)
			cmd.Stderr = os.Stderr
			cmd.Stdout = os.Stdout
			return cmd
		}

		// Best effort: dropdb is expected to fail when the database does not
		// exist yet (e.g. on a first run). Only createdb must succeed.
		_ = command("dropdb", name).Run()
		if err := command("createdb", name).Run(); err != nil {
			return nil, err
		}

		return sql.Open("postgres", fmt.Sprintf("postgres://localhost/%s?sslmode=disable", name))
	},
	Sqlite: func() (*sql.DB, error) {
		// Start from a fresh file. A missing file is fine, so the error is
		// deliberately ignored.
		_ = os.Remove("migrate_test.db")
		// mode=rwc is SQLite's "read/write, create if missing" URI mode; the
		// previous value "wrc" is not a valid mode. NOTE(review): without a
		// "file:" prefix the driver may not apply these query parameters —
		// confirm if shared-cache behavior matters.
		return sql.Open("sqlite3", "migrate_test.db?cache=shared&mode=rwc")
	},
}

// newDB opens a fresh, empty database of the kind selected by the
// -test.database flag. The test fails immediately if the flag names an
// unknown database or the database cannot be opened. Callers are responsible
// for closing the returned *sql.DB.
func newDB(t testing.TB) *sql.DB {
	t.Helper()

	open, ok := databases[*database]
	if !ok {
		t.Fatalf("Unknown database: %s", *database)
	}

	db, err := open()
	if err != nil {
		t.Fatal(err)
	}

	return db
}