clair/vendor/github.com/remind101/migrate/migrate_test.go

344 lines
7.4 KiB
Go
Raw Normal View History

package migrate_test
import (
	"database/sql"
	"errors"
	"flag"
	"fmt"
	"os"
	"os/exec"
	"strings"
	"sync/atomic"
	"testing"

	_ "github.com/lib/pq"
	_ "github.com/mattn/go-sqlite3"
	"github.com/remind101/migrate"
	"github.com/stretchr/testify/assert"
)
// Driver names understood by the suite; they double as the accepted values of
// the -test.database flag and as keys of the databases factory map.
const (
	Sqlite   = "sqlite3"
	Postgres = "postgres"
)

// database selects the backend the suite runs against. It defaults to SQLite;
// pass -test.database=postgres to target Postgres instead.
var database = flag.String(
	"test.database",
	Sqlite,
	"The name of the database to run against. (sqlite3, postgres).",
)
// testMigrations is the two-step fixture shared by most tests:
//  1. creates the people table (and drops it again on Down);
//  2. adds a first_name column to people.
var testMigrations = []migrate.Migration{
	{
		ID: 1,
		Up: func(tx *sql.Tx) error {
			_, execErr := tx.Exec("CREATE TABLE people (id int)")
			return execErr
		},
		Down: func(tx *sql.Tx) error {
			_, execErr := tx.Exec("DROP TABLE people")
			return execErr
		},
	},
	{
		ID: 2,
		// Plain-SQL migrations can be written with the migrate.Queries helper
		// instead of a hand-rolled func.
		Up: migrate.Queries([]string{
			"ALTER TABLE people ADD COLUMN first_name text",
		}),
		// SQLite (at least the versions this suite targets) cannot remove a
		// column, so Down leaves first_name in place and only runs a trivial
		// query against the table.
		Down: func(tx *sql.Tx) error {
			_, execErr := tx.Exec("SELECT 1 FROM people")
			return execErr
		},
	},
}
// TestMigrate applies every fixture migration, verifies the recorded versions
// and the resulting schema, then rolls everything back and verifies both are
// empty again.
func TestMigrate(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	// Up: both migrations are applied.
	assert.NoError(t, migrate.Exec(db, migrate.Up, testMigrations...))
	assert.Equal(t, []int{1, 2}, appliedMigrations(t, db))
	assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)

	// Down: everything is undone.
	assert.NoError(t, migrate.Exec(db, migrate.Down, testMigrations...))
	assert.Equal(t, []int{}, appliedMigrations(t, db))
	assertSchema(t, ``, db)
}
// TestMigrate_Individual applies the fixture migrations one Exec call at a
// time and checks the versions and schema after each step.
func TestMigrate_Individual(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	steps := []struct {
		migration migrate.Migration
		versions  []int
		schema    string
	}{
		{testMigrations[0], []int{1}, `
people
CREATE TABLE people (id int)
`},
		{testMigrations[1], []int{1, 2}, `
people
CREATE TABLE people (id int, first_name text)
`},
	}

	for _, step := range steps {
		assert.NoError(t, migrate.Exec(db, migrate.Up, step.migration))
		assert.Equal(t, step.versions, appliedMigrations(t, db))
		assertSchema(t, step.schema, db)
	}
}
// TestMigrate_AlreadyRan runs the same migration twice and checks that the
// second run succeeds without recording it again or touching the schema.
func TestMigrate_AlreadyRan(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	first := testMigrations[0]
	const wantSchema = `
people
CREATE TABLE people (id int)
`

	for attempt := 0; attempt < 2; attempt++ {
		assert.NoError(t, migrate.Exec(db, migrate.Up, first))
		assert.Equal(t, []int{1}, appliedMigrations(t, db))
		assertSchema(t, wantSchema, db)
	}
}
// TestMigrate_SingleTransactionMode_Rollback checks that in SingleTransaction
// mode a failure in the last migration also undoes the ones before it.
func TestMigrate_SingleTransactionMode_Rollback(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	failing := migrate.Migration{
		ID: 3,
		Up: func(tx *sql.Tx) error {
			return errors.New("Rollback")
		},
	}

	migrator := migrate.NewMigrator(db)
	migrator.TransactionMode = migrate.SingleTransaction

	err := migrator.Exec(migrate.Up, testMigrations[0], testMigrations[1], failing)
	assert.Error(t, err)
	// The whole batch shares one transaction, so nothing may be recorded or
	// created.
	assert.Equal(t, []int{}, appliedMigrations(t, db))
	assertSchema(t, ``, db)
}
// TestMigrate_SingleTransactionMode_Commit checks that a successful batch run
// in SingleTransaction mode is committed as a whole.
func TestMigrate_SingleTransactionMode_Commit(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	m := migrate.NewMigrator(db)
	m.TransactionMode = migrate.SingleTransaction

	assert.NoError(t, m.Exec(migrate.Up, testMigrations...))
	assert.Equal(t, []int{1, 2}, appliedMigrations(t, db))
	assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
}
// TestMigrate_Order passes the migrations out of order and checks that they
// are still applied by ascending ID.
func TestMigrate_Order(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	reversed := []migrate.Migration{testMigrations[1], testMigrations[0]}

	assert.NoError(t, migrate.Exec(db, migrate.Up, reversed...))
	assert.Equal(t, []int{1, 2}, appliedMigrations(t, db))
	assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
}
// TestMigrate_Rollback runs a migration whose second statement fails. It
// checks that the first statement's effect is rolled back, that no version is
// recorded, and that the returned error is a *migrate.MigrationError.
func TestMigrate_Rollback(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	broken := migrate.Migration{
		ID: 1,
		Up: func(tx *sql.Tx) error {
			// This statement is valid on its own...
			if _, err := tx.Exec("CREATE TABLE people (id int)"); err != nil {
				return err
			}
			// ...but this one references a table that does not exist.
			_, err := tx.Exec("ALTER TABLE foo ADD COLUMN first_name text")
			return err
		},
	}

	err := migrate.Exec(db, migrate.Up, broken)
	assert.Error(t, err)
	assert.Equal(t, []int{}, appliedMigrations(t, db))
	// A leftover people table would mean the transaction was not rolled back.
	assertSchema(t, ``, db)
	assert.IsType(t, &migrate.MigrationError{}, err)
}
// TestMigrate_Locking applies the fixture migrations and then runs migration 3
// from two goroutines at the same time. It checks that both calls succeed and
// that migration 3's Up body runs exactly once. On Postgres the migrator is
// created with migrate.NewPostgresMigrator; otherwise migrate.NewMigrator is
// used.
func TestMigrate_Locking(t *testing.T) {
	db := newDB(t)
	defer db.Close()

	migrator := migrate.NewMigrator(db)
	if *database == Postgres {
		migrator = migrate.NewPostgresMigrator(db)
	}

	err := migrator.Exec(migrate.Up, testMigrations...)
	assert.NoError(t, err)
	assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
	assert.Equal(t, []int{1, 2}, appliedMigrations(t, db))

	// called counts how many times migration 3's Up body runs. It is updated
	// atomically because Up may be invoked from both goroutines below. If
	// locking ever fails, the test then reports a count of 2 instead of racing
	// on a plain int and possibly losing an increment.
	var called int32

	// Migration 3 inserts a single row and records that it ran.
	migration := migrate.Migration{
		ID: 3,
		Up: func(tx *sql.Tx) error {
			atomic.AddInt32(&called, 1)
			_, err := tx.Exec(`INSERT INTO people (id, first_name) VALUES (1, 'Eric')`)
			return err
		},
	}

	m1 := make(chan error)
	m2 := make(chan error)

	// Start two migrations in parallel.
	go func() {
		m1 <- migrator.Exec(migrate.Up, migration)
	}()
	go func() {
		m2 <- migrator.Exec(migrate.Up, migration)
	}()

	assert.NoError(t, <-m1)
	assert.NoError(t, <-m2)
	assert.Equal(t, int32(1), atomic.LoadInt32(&called))
	assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
	assert.Equal(t, []int{1, 2, 3}, appliedMigrations(t, db))
}
// assertSchema compares the schema reported by sqliteSchema with
// expectedSchema, ignoring whitespace around expectedSchema. The check only
// runs when the suite targets SQLite; for any other backend it does nothing.
func assertSchema(t testing.TB, expectedSchema string, db *sql.DB) {
	if *database != Sqlite {
		return
	}

	got, err := sqliteSchema(db)
	if err != nil {
		t.Fatal(err)
	}
	want := strings.TrimSpace(expectedSchema)
	assert.Equal(t, want, got)
}
// sqliteSchema returns the CREATE statement of every table in the SQLite
// database, skipping the migrations bookkeeping table (migrate.DefaultTable).
// Each table is rendered as "name\nsql". Entries are ordered by table name and
// separated by a blank line. The result is "" when there are no such tables.
// Query, scan, and iteration errors are returned to the caller.
func sqliteSchema(db *sql.DB) (string, error) {
	rows, err := db.Query(`SELECT name, sql FROM sqlite_master
WHERE type='table'
ORDER BY name;`)
	if err != nil {
		return "", err
	}
	defer rows.Close()

	var tables []string
	for rows.Next() {
		// stmt rather than "sql" so the database/sql package is not shadowed.
		var name, stmt string
		if err := rows.Scan(&name, &stmt); err != nil {
			return "", err
		}
		if name == migrate.DefaultTable {
			continue
		}
		tables = append(tables, fmt.Sprintf("%s\n%s", name, stmt))
	}
	// Next also returns false when iteration fails. Surface that error rather
	// than returning a silently truncated schema.
	if err := rows.Err(); err != nil {
		return "", err
	}
	return strings.Join(tables, "\n\n"), nil
}
// appliedMigrations returns the versions stored in the migrations table
// (migrate.DefaultTable), in ascending order. When nothing has been applied it
// returns an empty, non-nil slice, so the result compares equal to []int{} in
// assertions. Any query, scan, or iteration error fails the test immediately.
func appliedMigrations(t testing.TB, db *sql.DB) []int {
	// ORDER BY makes the result deterministic. Without it the database may
	// return rows in any order, which would make comparisons against ordered
	// slices such as []int{1, 2} flaky.
	rows, err := db.Query("SELECT version FROM " + migrate.DefaultTable + " ORDER BY version")
	if err != nil {
		t.Fatal(err)
	}
	defer rows.Close()

	ids := []int{}
	for rows.Next() {
		var id int
		if err := rows.Scan(&id); err != nil {
			t.Fatal(err)
		}
		ids = append(ids, id)
	}
	// Next returns false on iteration errors too; do not report a partial list.
	if err := rows.Err(); err != nil {
		t.Fatal(err)
	}
	return ids
}
// databases maps each supported -test.database value to a factory that resets
// a test database of that kind and returns a handle to it.
var databases = map[string]func() (*sql.DB, error){
	Postgres: func() (*sql.DB, error) {
		const dbname = "migrate_test"

		// run builds a command whose output goes to the test process's own
		// stdout/stderr, so output from the Postgres CLI tools is visible.
		run := func(prog string, args ...string) *exec.Cmd {
			c := exec.Command(prog, args...)
			c.Stdout = os.Stdout
			c.Stderr = os.Stderr
			return c
		}

		// Best-effort reset: the result of dropdb is deliberately ignored
		// (presumably because it fails when the database does not exist yet).
		_ = run("dropdb", dbname).Run()
		if err := run("createdb", dbname).Run(); err != nil {
			return nil, err
		}
		dsn := fmt.Sprintf("postgres://localhost/%s?sslmode=disable", dbname)
		return sql.Open("postgres", dsn)
	},
	Sqlite: func() (*sql.DB, error) {
		// Start from a fresh file. The error from os.Remove is ignored, e.g.
		// when the file does not exist yet.
		_ = os.Remove("migrate_test.db")
		// NOTE(review): "wrc" is not one of SQLite's documented URI modes
		// (ro, rw, rwc, memory). Confirm whether the driver honours these
		// query parameters for a DSN without a "file:" prefix.
		return sql.Open("sqlite3", "migrate_test.db?cache=shared&mode=wrc")
	},
}
// newDB returns a database handle for the backend selected by -test.database,
// produced by the matching factory in databases. It fails the test
// immediately if the backend name is unknown or the factory returns an error.
func newDB(t testing.TB) *sql.DB {
	open, ok := databases[*database]
	if !ok {
		t.Fatalf("Unknown database: %s", *database)
	}
	db, err := open()
	if err != nil {
		t.Fatal(err)
	}
	return db
}