pgsql: Replace liamstask/goose by remind101/migrate

Fixes #93
pull/261/head
Quentin Machu 8 years ago
parent 90cc8243ba
commit b8865b2106

@ -0,0 +1,53 @@
// Copyright 2016 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package migrations
import (
"database/sql"
"github.com/remind101/migrate"
)
func init() {
// This migration removes the data maintained by the previous migration tool
// (liamstask/goose), and if it was present, mark the 00002_initial_schema
// migration as done.
RegisterMigration(migrate.Migration{
ID: 1,
Up: func(tx *sql.Tx) error {
// Verify that goose was in use before, otherwise skip this migration.
var e bool
err := tx.QueryRow("SELECT true FROM pg_class WHERE relname = $1", "goose_db_version").Scan(&e)
if err == sql.ErrNoRows {
return nil
}
if err != nil {
return err
}
// Delete goose's data.
_, err = tx.Exec("DROP TABLE goose_db_version CASCADE")
if err != nil {
return err
}
// Mark the '00002_initial_schema' as done.
_, err = tx.Exec("INSERT INTO schema_migrations (version) VALUES (2)")
return err
},
Down: migrate.Queries([]string{}),
})
}

@ -0,0 +1,128 @@
// Copyright 2016 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package migrations
import "github.com/remind101/migrate"
func init() {
// This migration creates the initial Clair's schema.
RegisterMigration(migrate.Migration{
ID: 2,
Up: migrate.Queries([]string{
`CREATE TABLE IF NOT EXISTS Namespace (
id SERIAL PRIMARY KEY,
name VARCHAR(128) NULL);`,
`CREATE TABLE IF NOT EXISTS Layer (
id SERIAL PRIMARY KEY,
name VARCHAR(128) NOT NULL UNIQUE,
engineversion SMALLINT NOT NULL,
parent_id INT NULL REFERENCES Layer ON DELETE CASCADE,
namespace_id INT NULL REFERENCES Namespace,
created_at TIMESTAMP WITH TIME ZONE);`,
`CREATE INDEX ON Layer (parent_id);`,
`CREATE INDEX ON Layer (namespace_id);`,
`CREATE TABLE IF NOT EXISTS Feature (
id SERIAL PRIMARY KEY,
namespace_id INT NOT NULL REFERENCES Namespace,
name VARCHAR(128) NOT NULL,
UNIQUE (namespace_id, name));`,
`CREATE TABLE IF NOT EXISTS FeatureVersion (
id SERIAL PRIMARY KEY,
feature_id INT NOT NULL REFERENCES Feature,
version VARCHAR(128) NOT NULL);`,
`CREATE INDEX ON FeatureVersion (feature_id);`,
`CREATE TYPE modification AS ENUM ('add', 'del');`,
`CREATE TABLE IF NOT EXISTS Layer_diff_FeatureVersion (
id SERIAL PRIMARY KEY,
layer_id INT NOT NULL REFERENCES Layer ON DELETE CASCADE,
featureversion_id INT NOT NULL REFERENCES FeatureVersion,
modification modification NOT NULL,
UNIQUE (layer_id, featureversion_id));`,
`CREATE INDEX ON Layer_diff_FeatureVersion (layer_id);`,
`CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id);`,
`CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id, layer_id);`,
`CREATE TYPE severity AS ENUM ('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1');`,
`CREATE TABLE IF NOT EXISTS Vulnerability (
id SERIAL PRIMARY KEY,
namespace_id INT NOT NULL REFERENCES Namespace,
name VARCHAR(128) NOT NULL,
description TEXT NULL,
link VARCHAR(128) NULL,
severity severity NOT NULL,
metadata TEXT NULL,
created_at TIMESTAMP WITH TIME ZONE,
deleted_at TIMESTAMP WITH TIME ZONE NULL);`,
`CREATE TABLE IF NOT EXISTS Vulnerability_FixedIn_Feature (
id SERIAL PRIMARY KEY,
vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE,
feature_id INT NOT NULL REFERENCES Feature,
version VARCHAR(128) NOT NULL,
UNIQUE (vulnerability_id, feature_id));`,
`CREATE INDEX ON Vulnerability_FixedIn_Feature (feature_id, vulnerability_id);`,
`CREATE TABLE IF NOT EXISTS Vulnerability_Affects_FeatureVersion (
id SERIAL PRIMARY KEY,
vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE,
featureversion_id INT NOT NULL REFERENCES FeatureVersion,
fixedin_id INT NOT NULL REFERENCES Vulnerability_FixedIn_Feature ON DELETE CASCADE,
UNIQUE (vulnerability_id, featureversion_id));`,
`CREATE INDEX ON Vulnerability_Affects_FeatureVersion (fixedin_id);`,
`CREATE INDEX ON Vulnerability_Affects_FeatureVersion (featureversion_id, vulnerability_id);`,
`CREATE TABLE IF NOT EXISTS KeyValue (
id SERIAL PRIMARY KEY,
key VARCHAR(128) NOT NULL UNIQUE,
value TEXT);`,
`CREATE TABLE IF NOT EXISTS Lock (
id SERIAL PRIMARY KEY,
name VARCHAR(64) NOT NULL UNIQUE,
owner VARCHAR(64) NOT NULL,
until TIMESTAMP WITH TIME ZONE);`,
`CREATE INDEX ON Lock (owner);`,
`CREATE TABLE IF NOT EXISTS Vulnerability_Notification (
id SERIAL PRIMARY KEY,
name VARCHAR(64) NOT NULL UNIQUE,
created_at TIMESTAMP WITH TIME ZONE,
notified_at TIMESTAMP WITH TIME ZONE NULL,
deleted_at TIMESTAMP WITH TIME ZONE NULL,
old_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE,
new_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE);`,
`CREATE INDEX ON Vulnerability_Notification (notified_at);`,
}),
Down: migrate.Queries([]string{
`DROP TABLE IF EXISTS
Namespace,
Layer,
Feature,
FeatureVersion,
Layer_diff_FeatureVersion,
Vulnerability,
Vulnerability_FixedIn_Feature,
Vulnerability_Affects_FeatureVersion,
Vulnerability_Notification,
KeyValue,
Lock
CASCADE;`,
}),
})
}

@ -1,174 +0,0 @@
-- Copyright 2015 clair authors
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- +goose Up
-- -----------------------------------------------------
-- Table Namespace
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Namespace (
id SERIAL PRIMARY KEY,
name VARCHAR(128) NULL);
-- -----------------------------------------------------
-- Table Layer
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Layer (
id SERIAL PRIMARY KEY,
name VARCHAR(128) NOT NULL UNIQUE,
engineversion SMALLINT NOT NULL,
parent_id INT NULL REFERENCES Layer ON DELETE CASCADE,
namespace_id INT NULL REFERENCES Namespace,
created_at TIMESTAMP WITH TIME ZONE);
CREATE INDEX ON Layer (parent_id);
CREATE INDEX ON Layer (namespace_id);
-- -----------------------------------------------------
-- Table Feature
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Feature (
id SERIAL PRIMARY KEY,
namespace_id INT NOT NULL REFERENCES Namespace,
name VARCHAR(128) NOT NULL,
UNIQUE (namespace_id, name));
-- -----------------------------------------------------
-- Table FeatureVersion
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS FeatureVersion (
id SERIAL PRIMARY KEY,
feature_id INT NOT NULL REFERENCES Feature,
version VARCHAR(128) NOT NULL);
CREATE INDEX ON FeatureVersion (feature_id);
-- -----------------------------------------------------
-- Table Layer_diff_FeatureVersion
-- -----------------------------------------------------
CREATE TYPE modification AS ENUM ('add', 'del');
CREATE TABLE IF NOT EXISTS Layer_diff_FeatureVersion (
id SERIAL PRIMARY KEY,
layer_id INT NOT NULL REFERENCES Layer ON DELETE CASCADE,
featureversion_id INT NOT NULL REFERENCES FeatureVersion,
modification modification NOT NULL,
UNIQUE (layer_id, featureversion_id));
CREATE INDEX ON Layer_diff_FeatureVersion (layer_id);
CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id);
CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id, layer_id);
-- -----------------------------------------------------
-- Table Vulnerability
-- -----------------------------------------------------
CREATE TYPE severity AS ENUM ('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1');
CREATE TABLE IF NOT EXISTS Vulnerability (
id SERIAL PRIMARY KEY,
namespace_id INT NOT NULL REFERENCES Namespace,
name VARCHAR(128) NOT NULL,
description TEXT NULL,
link VARCHAR(128) NULL,
severity severity NOT NULL,
metadata TEXT NULL,
created_at TIMESTAMP WITH TIME ZONE,
deleted_at TIMESTAMP WITH TIME ZONE NULL);
-- -----------------------------------------------------
-- Table Vulnerability_FixedIn_Feature
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Vulnerability_FixedIn_Feature (
id SERIAL PRIMARY KEY,
vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE,
feature_id INT NOT NULL REFERENCES Feature,
version VARCHAR(128) NOT NULL,
UNIQUE (vulnerability_id, feature_id));
CREATE INDEX ON Vulnerability_FixedIn_Feature (feature_id, vulnerability_id);
-- -----------------------------------------------------
-- Table Vulnerability_Affects_FeatureVersion
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Vulnerability_Affects_FeatureVersion (
id SERIAL PRIMARY KEY,
vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE,
featureversion_id INT NOT NULL REFERENCES FeatureVersion,
fixedin_id INT NOT NULL REFERENCES Vulnerability_FixedIn_Feature ON DELETE CASCADE,
UNIQUE (vulnerability_id, featureversion_id));
CREATE INDEX ON Vulnerability_Affects_FeatureVersion (fixedin_id);
CREATE INDEX ON Vulnerability_Affects_FeatureVersion (featureversion_id, vulnerability_id);
-- -----------------------------------------------------
-- Table KeyValue
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS KeyValue (
id SERIAL PRIMARY KEY,
key VARCHAR(128) NOT NULL UNIQUE,
value TEXT);
-- -----------------------------------------------------
-- Table Lock
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Lock (
id SERIAL PRIMARY KEY,
name VARCHAR(64) NOT NULL UNIQUE,
owner VARCHAR(64) NOT NULL,
until TIMESTAMP WITH TIME ZONE);
CREATE INDEX ON Lock (owner);
-- -----------------------------------------------------
-- Table VulnerabilityNotification
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Vulnerability_Notification (
id SERIAL PRIMARY KEY,
name VARCHAR(64) NOT NULL UNIQUE,
created_at TIMESTAMP WITH TIME ZONE,
notified_at TIMESTAMP WITH TIME ZONE NULL,
deleted_at TIMESTAMP WITH TIME ZONE NULL,
old_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE,
new_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE);
CREATE INDEX ON Vulnerability_Notification (notified_at);
-- +goose Down
DROP TABLE IF EXISTS Namespace,
Layer,
Feature,
FeatureVersion,
Layer_diff_FeatureVersion,
Vulnerability,
Vulnerability_FixedIn_Feature,
Vulnerability_Affects_FeatureVersion,
Vulnerability_Notification,
KeyValue,
Lock
CASCADE;

@ -0,0 +1,27 @@
// Copyright 2016 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package migrations regroups every migrations available to the pgsql database
// backend.
package migrations
import "github.com/remind101/migrate"
// Migrations contains every available migrations.
var Migrations []migrate.Migration
// RegisterMigration adds the specified migration to the available migrations.
func RegisterMigration(migration migrate.Migration) {
Migrations = append(Migrations, migration)
}

@ -20,20 +20,20 @@ import (
"fmt"
"io/ioutil"
"net/url"
"path/filepath"
"runtime"
"strings"
"time"
"bitbucket.org/liamstask/goose/lib/goose"
"gopkg.in/yaml.v2"
"github.com/coreos/pkg/capnslog"
"github.com/hashicorp/golang-lru"
"github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus"
"gopkg.in/yaml.v2"
"github.com/remind101/migrate"
"github.com/coreos/clair/config"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/migrations"
"github.com/coreos/clair/utils"
cerrors "github.com/coreos/clair/utils/errors"
)
@ -144,7 +144,7 @@ func openDatabase(registrableComponentConfig config.RegistrableComponentConfig)
// Create database.
if pg.config.ManageDatabaseLifecycle {
log.Info("pgsql: creating database")
if err := createDatabase(pgSourceURL, dbName); err != nil {
if err = createDatabase(pgSourceURL, dbName); err != nil {
return nil, err
}
}
@ -157,13 +157,13 @@ func openDatabase(registrableComponentConfig config.RegistrableComponentConfig)
}
// Verify database state.
if err := pg.DB.Ping(); err != nil {
if err = pg.DB.Ping(); err != nil {
pg.Close()
return nil, fmt.Errorf("pgsql: could not open database: %v", err)
}
// Run migrations.
if err := migrate(pg.config.Source); err != nil {
if err = migrateDatabase(pg.DB); err != nil {
pg.Close()
return nil, err
}
@ -214,29 +214,10 @@ func parseConnectionString(source string) (dbName string, pgSourceURL string, er
}
// migrate runs all available migrations on a pgSQL database.
func migrate(source string) error {
func migrateDatabase(db *sql.DB) error {
log.Info("running database migrations")
_, filename, _, _ := runtime.Caller(1)
migrationDir := filepath.Join(filepath.Dir(filename), "/migrations/")
conf := &goose.DBConf{
MigrationsDir: migrationDir,
Driver: goose.DBDriver{
Name: "postgres",
OpenStr: source,
Import: "github.com/lib/pq",
Dialect: &goose.PostgresDialect{},
},
}
// Determine the most recent revision available from the migrations folder.
target, err := goose.GetMostRecentDBVersion(conf.MigrationsDir)
if err != nil {
return fmt.Errorf("pgsql: could not get most recent migration: %v", err)
}
// Run migrations.
err = goose.RunMigrations(conf, conf.MigrationsDir, target)
err := migrate.NewPostgresMigrator(db).Exec(migrate.Up, migrations.Migrations...)
if err != nil {
return fmt.Errorf("pgsql: an error occured while running migrations: %v", err)
}

10
glide.lock generated

@ -1,10 +1,6 @@
hash: 6787deaa403181a20c0c0ba7e3135ba0f300281e7b3f98294593aa9f60563627
updated: 2016-06-07T06:56:38.153128536+02:00
hash: 90e72dd70a9102940fc6126afc2d7d73ee79faaf75cecb3710b15f5a121c67ce
updated: 2016-11-11T15:36:41.356471726+01:00
imports:
- name: bitbucket.org/liamstask/goose
version: 8488cc47d90c8a502b1c41a462a6d9cc8ee0a895
subpackages:
- lib/goose
- name: github.com/beorn7/perks
version: b965b613227fddccbfffe13eae360ed3fa822f8d
subpackages:
@ -82,6 +78,8 @@ imports:
- model
- name: github.com/prometheus/procfs
version: 406e5b7bfd8201a36e2bb5f7bdae0b03380c2ce8
- name: github.com/remind101/migrate
version: d22d647232c20dbea6d2aa1dda7f2737cccce614
- name: github.com/shiena/ansicolor
version: a422bbe96644373c5753384a59d678f7d261ff10
- name: github.com/stretchr/testify

@ -1,9 +1,5 @@
package: github.com/coreos/clair
import:
- package: bitbucket.org/liamstask/goose
version: 8488cc47d90c8a502b1c41a462a6d9cc8ee0a895
subpackages:
- lib/goose
- package: github.com/beorn7/perks
version: b965b613227fddccbfffe13eae360ed3fa822f8d
subpackages:
@ -94,3 +90,4 @@ import:
- package: github.com/fatih/color
version: ^0.1.0
- package: github.com/kr/text
- package: github.com/remind101/migrate

@ -1,4 +0,0 @@
.DS_Store
cmd/goose/goose*
*.swp
*.test

@ -1,7 +0,0 @@
Copyright (c) <2012> <Liam Staskawicz>
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -1,265 +0,0 @@
# goose
goose is a database migration tool.
You can manage your database's evolution by creating incremental SQL or Go scripts.
[![Build Status](https://drone.io/bitbucket.org/liamstask/goose/status.png)](https://drone.io/bitbucket.org/liamstask/goose/latest)
# Install
$ go get bitbucket.org/liamstask/goose/cmd/goose
This will install the `goose` binary to your `$GOPATH/bin` directory.
You can also build goose into your own applications by importing `bitbucket.org/liamstask/goose/lib/goose`. Documentation is available at [godoc.org](http://godoc.org/bitbucket.org/liamstask/goose/lib/goose).
NOTE: the API is still new, and may undergo some changes.
# Usage
goose provides several commands to help manage your database schema.
## create
Create a new Go migration.
$ goose create AddSomeColumns
$ goose: created db/migrations/20130106093224_AddSomeColumns.go
Edit the newly created script to define the behavior of your migration.
You can also create an SQL migration:
$ goose create AddSomeColumns sql
$ goose: created db/migrations/20130106093224_AddSomeColumns.sql
## up
Apply all available migrations.
$ goose up
$ goose: migrating db environment 'development', current version: 0, target: 3
$ OK 001_basics.sql
$ OK 002_next.sql
$ OK 003_and_again.go
### option: pgschema
Use the `pgschema` flag with the `up` command specify a postgres schema.
$ goose -pgschema=my_schema_name up
$ goose: migrating db environment 'development', current version: 0, target: 3
$ OK 001_basics.sql
$ OK 002_next.sql
$ OK 003_and_again.go
## down
Roll back a single migration from the current version.
$ goose down
$ goose: migrating db environment 'development', current version: 3, target: 2
$ OK 003_and_again.go
## redo
Roll back the most recently applied migration, then run it again.
$ goose redo
$ goose: migrating db environment 'development', current version: 3, target: 2
$ OK 003_and_again.go
$ goose: migrating db environment 'development', current version: 2, target: 3
$ OK 003_and_again.go
## status
Print the status of all migrations:
$ goose status
$ goose: status for environment 'development'
$ Applied At Migration
$ =======================================
$ Sun Jan 6 11:25:03 2013 -- 001_basics.sql
$ Sun Jan 6 11:25:03 2013 -- 002_next.sql
$ Pending -- 003_and_again.go
## dbversion
Print the current version of the database:
$ goose dbversion
$ goose: dbversion 002
`goose -h` provides more detailed info on each command.
# Migrations
goose supports migrations written in SQL or in Go - see the `goose create` command above for details on how to generate them.
## SQL Migrations
A sample SQL migration looks like:
```sql
-- +goose Up
CREATE TABLE post (
id int NOT NULL,
title text,
body text,
PRIMARY KEY(id)
);
-- +goose Down
DROP TABLE post;
```
Notice the annotations in the comments. Any statements following `-- +goose Up` will be executed as part of a forward migration, and any statements following `-- +goose Down` will be executed as part of a rollback.
By default, SQL statements are delimited by semicolons - in fact, query statements must end with a semicolon to be properly recognized by goose.
More complex statements (PL/pgSQL) that have semicolons within them must be annotated with `-- +goose StatementBegin` and `-- +goose StatementEnd` to be properly recognized. For example:
```sql
-- +goose Up
-- +goose StatementBegin
CREATE OR REPLACE FUNCTION histories_partition_creation( DATE, DATE )
returns void AS $$
DECLARE
create_query text;
BEGIN
FOR create_query IN SELECT
'CREATE TABLE IF NOT EXISTS histories_'
|| TO_CHAR( d, 'YYYY_MM' )
|| ' ( CHECK( created_at >= timestamp '''
|| TO_CHAR( d, 'YYYY-MM-DD 00:00:00' )
|| ''' AND created_at < timestamp '''
|| TO_CHAR( d + INTERVAL '1 month', 'YYYY-MM-DD 00:00:00' )
|| ''' ) ) inherits ( histories );'
FROM generate_series( $1, $2, '1 month' ) AS d
LOOP
EXECUTE create_query;
END LOOP; -- LOOP END
END; -- FUNCTION END
$$
language plpgsql;
-- +goose StatementEnd
```
## Go Migrations
A sample Go migration looks like:
```go
package main
import (
"database/sql"
"fmt"
)
func Up_20130106222315(txn *sql.Tx) {
fmt.Println("Hello from migration 20130106222315 Up!")
}
func Down_20130106222315(txn *sql.Tx) {
fmt.Println("Hello from migration 20130106222315 Down!")
}
```
`Up_20130106222315()` will be executed as part of a forward migration, and `Down_20130106222315()` will be executed as part of a rollback.
The numeric portion of the function name (`20130106222315`) must be the leading portion of migration's filename, such as `20130106222315_descriptive_name.go`. `goose create` does this by default.
A transaction is provided, rather than the DB instance directly, since goose also needs to record the schema version within the same transaction. Each migration should run as a single transaction to ensure DB integrity, so it's good practice anyway.
# Configuration
goose expects you to maintain a folder (typically called "db"), which contains the following:
* a `dbconf.yml` file that describes the database configurations you'd like to use
* a folder called "migrations" which contains `.sql` and/or `.go` scripts that implement your migrations
You may use the `-path` option to specify an alternate location for the folder containing your config and migrations.
A sample `dbconf.yml` looks like
```yml
development:
driver: postgres
open: user=liam dbname=tester sslmode=disable
```
Here, `development` specifies the name of the environment, and the `driver` and `open` elements are passed directly to database/sql to access the specified database.
You may include as many environments as you like, and you can use the `-env` command line option to specify which one to use. goose defaults to using an environment called `development`.
goose will expand environment variables in the `open` element. For an example, see the Heroku section below.
## Other Drivers
goose knows about some common SQL drivers, but it can still be used to run Go-based migrations with any driver supported by `database/sql`. An import path and known dialect are required.
Currently, available dialects are: "postgres", "mysql", or "sqlite3"
To run Go-based migrations with another driver, specify its import path and dialect, as shown below.
```yml
customdriver:
driver: custom
open: custom open string
import: github.com/custom/driver
dialect: mysql
```
NOTE: Because migrations written in SQL are executed directly by the goose binary, only drivers compiled into goose may be used for these migrations.
## Using goose with Heroku
These instructions assume that you're using [Keith Rarick's Heroku Go buildpack](https://github.com/kr/heroku-buildpack-go). First, add a file to your project called (e.g.) `install_goose.go` to trigger building of the goose executable during deployment, with these contents:
```go
// use build constraints to work around http://code.google.com/p/go/issues/detail?id=4210
// +build heroku
// note: need at least one blank line after build constraint
package main
import _ "bitbucket.org/liamstask/goose/cmd/goose"
```
[Set up your Heroku database(s) as usual.](https://devcenter.heroku.com/articles/heroku-postgresql)
Then make use of environment variable expansion in your `dbconf.yml`:
```yml
production:
driver: postgres
open: $DATABASE_URL
```
To run goose in production, use `heroku run`:
heroku run goose -env production up
# Contributors
Thank you!
* Josh Bleecher Snyder (josharian)
* Abigail Walthall (ghthor)
* Daniel Heath (danielrheath)
* Chris Baynes (chris_baynes)
* Michael Gerow (gerow)
* Vytautas Šaltenis (rtfb)
* James Cooper (coopernurse)
* Gyepi Sam (gyepisam)
* Matt Sherman (clipperhouse)
* runner_mei
* John Luebs (jkl1337)
* Luke Hutton (lukehutton)
* Kevin Gorjan (kevingorjan)
* Brendan Fosberry (Fozz)
* Nate Guerin (gusennan)

@ -1,27 +0,0 @@
package main
import (
"flag"
)
// shamelessly snagged from the go tool
// each command gets its own set of args,
// defines its own entry point, and provides its own help
type Command struct {
Run func(cmd *Command, args ...string)
Flag flag.FlagSet
Name string
Usage string
Summary string
Help string
}
func (c *Command) Exec(args []string) {
c.Flag.Usage = func() {
// helpFunc(c, c.Name)
}
c.Flag.Parse(args)
c.Run(c, c.Flag.Args()...)
}

@ -1,51 +0,0 @@
package main
import (
"bitbucket.org/liamstask/goose/lib/goose"
"fmt"
"log"
"os"
"path/filepath"
"time"
)
var createCmd = &Command{
Name: "create",
Usage: "",
Summary: "Create the scaffolding for a new migration",
Help: `create extended help here...`,
Run: createRun,
}
func createRun(cmd *Command, args ...string) {
if len(args) < 1 {
log.Fatal("goose create: migration name required")
}
migrationType := "go" // default to Go migrations
if len(args) >= 2 {
migrationType = args[1]
}
conf, err := dbConfFromFlags()
if err != nil {
log.Fatal(err)
}
if err = os.MkdirAll(conf.MigrationsDir, 0777); err != nil {
log.Fatal(err)
}
n, err := goose.CreateMigration(args[0], migrationType, conf.MigrationsDir, time.Now())
if err != nil {
log.Fatal(err)
}
a, e := filepath.Abs(n)
if e != nil {
log.Fatal(e)
}
fmt.Println("goose: created", a)
}

@ -1,29 +0,0 @@
package main
import (
"bitbucket.org/liamstask/goose/lib/goose"
"fmt"
"log"
)
var dbVersionCmd = &Command{
Name: "dbversion",
Usage: "",
Summary: "Print the current version of the database",
Help: `dbversion extended help here...`,
Run: dbVersionRun,
}
func dbVersionRun(cmd *Command, args ...string) {
conf, err := dbConfFromFlags()
if err != nil {
log.Fatal(err)
}
current, err := goose.GetDBVersion(conf)
if err != nil {
log.Fatal(err)
}
fmt.Printf("goose: dbversion %v\n", current)
}

@ -1,36 +0,0 @@
package main
import (
"bitbucket.org/liamstask/goose/lib/goose"
"log"
)
var downCmd = &Command{
Name: "down",
Usage: "",
Summary: "Roll back the version by 1",
Help: `down extended help here...`,
Run: downRun,
}
func downRun(cmd *Command, args ...string) {
conf, err := dbConfFromFlags()
if err != nil {
log.Fatal(err)
}
current, err := goose.GetDBVersion(conf)
if err != nil {
log.Fatal(err)
}
previous, err := goose.GetPreviousDBVersion(conf.MigrationsDir, current)
if err != nil {
log.Fatal(err)
}
if err = goose.RunMigrations(conf, conf.MigrationsDir, previous); err != nil {
log.Fatal(err)
}
}

@ -1,39 +0,0 @@
package main
import (
"bitbucket.org/liamstask/goose/lib/goose"
"log"
)
var redoCmd = &Command{
Name: "redo",
Usage: "",
Summary: "Re-run the latest migration",
Help: `redo extended help here...`,
Run: redoRun,
}
func redoRun(cmd *Command, args ...string) {
conf, err := dbConfFromFlags()
if err != nil {
log.Fatal(err)
}
current, err := goose.GetDBVersion(conf)
if err != nil {
log.Fatal(err)
}
previous, err := goose.GetPreviousDBVersion(conf.MigrationsDir, current)
if err != nil {
log.Fatal(err)
}
if err := goose.RunMigrations(conf, conf.MigrationsDir, previous); err != nil {
log.Fatal(err)
}
if err := goose.RunMigrations(conf, conf.MigrationsDir, current); err != nil {
log.Fatal(err)
}
}

@ -1,77 +0,0 @@
package main
import (
"bitbucket.org/liamstask/goose/lib/goose"
"database/sql"
"fmt"
"log"
"path/filepath"
"time"
)
var statusCmd = &Command{
Name: "status",
Usage: "",
Summary: "dump the migration status for the current DB",
Help: `status extended help here...`,
Run: statusRun,
}
type StatusData struct {
Source string
Status string
}
func statusRun(cmd *Command, args ...string) {
conf, err := dbConfFromFlags()
if err != nil {
log.Fatal(err)
}
// collect all migrations
min := int64(0)
max := int64((1 << 63) - 1)
migrations, e := goose.CollectMigrations(conf.MigrationsDir, min, max)
if e != nil {
log.Fatal(e)
}
db, e := goose.OpenDBFromDBConf(conf)
if e != nil {
log.Fatal("couldn't open DB:", e)
}
defer db.Close()
// must ensure that the version table exists if we're running on a pristine DB
if _, e := goose.EnsureDBVersion(conf, db); e != nil {
log.Fatal(e)
}
fmt.Printf("goose: status for environment '%v'\n", conf.Env)
fmt.Println(" Applied At Migration")
fmt.Println(" =======================================")
for _, m := range migrations {
printMigrationStatus(db, m.Version, filepath.Base(m.Source))
}
}
func printMigrationStatus(db *sql.DB, version int64, script string) {
var row goose.MigrationRecord
q := fmt.Sprintf("SELECT tstamp, is_applied FROM goose_db_version WHERE version_id=%d ORDER BY tstamp DESC LIMIT 1", version)
e := db.QueryRow(q).Scan(&row.TStamp, &row.IsApplied)
if e != nil && e != sql.ErrNoRows {
log.Fatal(e)
}
var appliedAt string
if row.IsApplied {
appliedAt = row.TStamp.Format(time.ANSIC)
} else {
appliedAt = "Pending"
}
fmt.Printf(" %-24s -- %v\n", appliedAt, script)
}

@ -1,31 +0,0 @@
package main
import (
"bitbucket.org/liamstask/goose/lib/goose"
"log"
)
var upCmd = &Command{
Name: "up",
Usage: "",
Summary: "Migrate the DB to the most recent version available",
Help: `up extended help here...`,
Run: upRun,
}
func upRun(cmd *Command, args ...string) {
conf, err := dbConfFromFlags()
if err != nil {
log.Fatal(err)
}
target, err := goose.GetMostRecentDBVersion(conf.MigrationsDir)
if err != nil {
log.Fatal(err)
}
if err := goose.RunMigrations(conf, conf.MigrationsDir, target); err != nil {
log.Fatal(err)
}
}

@ -1,78 +0,0 @@
package main
import (
"bitbucket.org/liamstask/goose/lib/goose"
"flag"
"fmt"
"os"
"strings"
"text/template"
)
// global options. available to any subcommands.
var flagPath = flag.String("path", "db", "folder containing db info")
var flagEnv = flag.String("env", "development", "which DB environment to use")
var flagPgSchema = flag.String("pgschema", "", "which postgres-schema to migrate (default = none)")
// helper to create a DBConf from the given flags
func dbConfFromFlags() (dbconf *goose.DBConf, err error) {
return goose.NewDBConf(*flagPath, *flagEnv, *flagPgSchema)
}
var commands = []*Command{
upCmd,
downCmd,
redoCmd,
statusCmd,
createCmd,
dbVersionCmd,
}
func main() {
flag.Usage = usage
flag.Parse()
args := flag.Args()
if len(args) == 0 || args[0] == "-h" {
flag.Usage()
return
}
var cmd *Command
name := args[0]
for _, c := range commands {
if strings.HasPrefix(c.Name, name) {
cmd = c
break
}
}
if cmd == nil {
fmt.Printf("error: unknown command %q\n", name)
flag.Usage()
os.Exit(1)
}
cmd.Exec(args[1:])
}
func usage() {
fmt.Print(usagePrefix)
flag.PrintDefaults()
usageTmpl.Execute(os.Stdout, commands)
}
var usagePrefix = `
goose is a database migration management system for Go projects.
Usage:
goose [options] <subcommand> [subcommand options]
Options:
`
var usageTmpl = template.Must(template.New("usage").Parse(
`
Commands:{{range .}}
{{.Name | printf "%-10s"}} {{.Summary}}{{end}}
`))

@ -1,22 +0,0 @@
test:
driver: postgres
open: user=liam dbname=tester sslmode=disable
development:
driver: postgres
open: user=liam dbname=tester sslmode=disable
production:
driver: postgres
open: user=liam dbname=tester sslmode=verify-full
customimport:
driver: customdriver
open: customdriver open
import: github.com/custom/driver
dialect: mysql
environment_variable_config:
driver: $DB_DRIVER
open: $DATABASE_URL

@ -1,11 +0,0 @@
-- +goose Up
CREATE TABLE post (
id int NOT NULL,
title text,
body text,
PRIMARY KEY(id)
);
-- +goose Down
DROP TABLE post;

@ -1,12 +0,0 @@
-- +goose Up
CREATE TABLE fancier_post (
id int NOT NULL,
title text,
body text,
created_on timestamp without time zone,
PRIMARY KEY(id)
);
-- +goose Down
DROP TABLE fancier_post;

@ -1,15 +0,0 @@
package main
import (
"database/sql"
"fmt"
)
func Up_20130106222315(txn *sql.Tx) {
fmt.Println("Hello from migration 20130106222315 Up!")
}
func Down_20130106222315(txn *sql.Tx) {
fmt.Println("Hello from migration 20130106222315 Down!")
}

@ -1,139 +0,0 @@
package goose
import (
"database/sql"
"errors"
"fmt"
"os"
"path/filepath"
"github.com/kylelemons/go-gypsy/yaml"
"github.com/lib/pq"
)
// DBDriver encapsulates the info needed to work with
// a specific database driver
type DBDriver struct {
Name string
OpenStr string
Import string
Dialect SqlDialect
}
type DBConf struct {
MigrationsDir string
Env string
Driver DBDriver
PgSchema string
}
// extract configuration details from the given file
func NewDBConf(p, env string, pgschema string) (*DBConf, error) {
cfgFile := filepath.Join(p, "dbconf.yml")
f, err := yaml.ReadFile(cfgFile)
if err != nil {
return nil, err
}
drv, err := f.Get(fmt.Sprintf("%s.driver", env))
if err != nil {
return nil, err
}
drv = os.ExpandEnv(drv)
open, err := f.Get(fmt.Sprintf("%s.open", env))
if err != nil {
return nil, err
}
open = os.ExpandEnv(open)
// Automatically parse postgres urls
if drv == "postgres" {
// Assumption: If we can parse the URL, we should
if parsedURL, err := pq.ParseURL(open); err == nil && parsedURL != "" {
open = parsedURL
}
}
d := newDBDriver(drv, open)
// allow the configuration to override the Import for this driver
if imprt, err := f.Get(fmt.Sprintf("%s.import", env)); err == nil {
d.Import = imprt
}
// allow the configuration to override the Dialect for this driver
if dialect, err := f.Get(fmt.Sprintf("%s.dialect", env)); err == nil {
d.Dialect = dialectByName(dialect)
}
if !d.IsValid() {
return nil, errors.New(fmt.Sprintf("Invalid DBConf: %v", d))
}
return &DBConf{
MigrationsDir: filepath.Join(p, "migrations"),
Env: env,
Driver: d,
PgSchema: pgschema,
}, nil
}
// Create a new DBDriver and populate driver specific
// fields for drivers that we know about.
// Further customization may be done in NewDBConf
func newDBDriver(name, open string) DBDriver {
d := DBDriver{
Name: name,
OpenStr: open,
}
switch name {
case "postgres":
d.Import = "github.com/lib/pq"
d.Dialect = &PostgresDialect{}
case "mymysql":
d.Import = "github.com/ziutek/mymysql/godrv"
d.Dialect = &MySqlDialect{}
case "mysql":
d.Import = "github.com/go-sql-driver/mysql"
d.Dialect = &MySqlDialect{}
case "sqlite3":
d.Import = "github.com/mattn/go-sqlite3"
d.Dialect = &Sqlite3Dialect{}
}
return d
}
// ensure we have enough info about this driver
func (drv *DBDriver) IsValid() bool {
return len(drv.Import) > 0 && drv.Dialect != nil
}
// OpenDBFromDBConf wraps database/sql.DB.Open() and configures
// the newly opened DB based on the given DBConf.
//
// Callers must Close() the returned DB.
func OpenDBFromDBConf(conf *DBConf) (*sql.DB, error) {
db, err := sql.Open(conf.Driver.Name, conf.Driver.OpenStr)
if err != nil {
return nil, err
}
// if a postgres schema has been specified, apply it
if conf.Driver.Name == "postgres" && conf.PgSchema != "" {
if _, err := db.Exec("SET search_path TO " + conf.PgSchema); err != nil {
return nil, err
}
}
return db, nil
}

@ -1,70 +0,0 @@
package goose
import (
"os"
"reflect"
"testing"
)
func TestBasics(t *testing.T) {
dbconf, err := NewDBConf("../../db-sample", "test", "")
if err != nil {
t.Fatal(err)
}
got := []string{dbconf.MigrationsDir, dbconf.Env, dbconf.Driver.Name, dbconf.Driver.OpenStr}
want := []string{"../../db-sample/migrations", "test", "postgres", "user=liam dbname=tester sslmode=disable"}
for i, s := range got {
if s != want[i] {
t.Errorf("Unexpected DBConf value. got %v, want %v", s, want[i])
}
}
}
func TestImportOverride(t *testing.T) {
dbconf, err := NewDBConf("../../db-sample", "customimport", "")
if err != nil {
t.Fatal(err)
}
got := dbconf.Driver.Import
want := "github.com/custom/driver"
if got != want {
t.Errorf("bad custom import. got %v want %v", got, want)
}
}
func TestDriverSetFromEnvironmentVariable(t *testing.T) {
databaseUrlEnvVariableKey := "DB_DRIVER"
databaseUrlEnvVariableVal := "sqlite3"
databaseOpenStringKey := "DATABASE_URL"
databaseOpenStringVal := "db.db"
os.Setenv(databaseUrlEnvVariableKey, databaseUrlEnvVariableVal)
os.Setenv(databaseOpenStringKey, databaseOpenStringVal)
dbconf, err := NewDBConf("../../db-sample", "environment_variable_config", "")
if err != nil {
t.Fatal(err)
}
got := reflect.TypeOf(dbconf.Driver.Dialect)
want := reflect.TypeOf(&Sqlite3Dialect{})
if got != want {
t.Errorf("Not able to read the driver type from environment variable."+
"got %v want %v", got, want)
}
gotOpenString := dbconf.Driver.OpenStr
wantOpenString := databaseOpenStringVal
if gotOpenString != wantOpenString {
t.Errorf("Not able to read the open string from the environment."+
"got %v want %v", gotOpenString, wantOpenString)
}
}

@ -1,123 +0,0 @@
package goose
import (
"database/sql"
"github.com/mattn/go-sqlite3"
)
// SqlDialect abstracts the details of specific SQL dialects
// for goose's few SQL specific statements
type SqlDialect interface {
createVersionTableSql() string // sql string to create the goose_db_version table
insertVersionSql() string // sql string to insert the initial version table row
dbVersionQuery(db *sql.DB) (*sql.Rows, error)
}
// drivers that we don't know about can ask for a dialect by name
func dialectByName(d string) SqlDialect {
switch d {
case "postgres":
return &PostgresDialect{}
case "mysql":
return &MySqlDialect{}
case "sqlite3":
return &Sqlite3Dialect{}
}
return nil
}
////////////////////////////
// Postgres
////////////////////////////
type PostgresDialect struct{}
func (pg PostgresDialect) createVersionTableSql() string {
return `CREATE TABLE goose_db_version (
id serial NOT NULL,
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default now(),
PRIMARY KEY(id)
);`
}
func (pg PostgresDialect) insertVersionSql() string {
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2);"
}
func (pg PostgresDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query("SELECT version_id, is_applied from goose_db_version ORDER BY id DESC")
// XXX: check for postgres specific error indicating the table doesn't exist.
// for now, assume any error is because the table doesn't exist,
// in which case we'll try to create it.
if err != nil {
return nil, ErrTableDoesNotExist
}
return rows, err
}
////////////////////////////
// MySQL
////////////////////////////
type MySqlDialect struct{}
func (m MySqlDialect) createVersionTableSql() string {
return `CREATE TABLE goose_db_version (
id serial NOT NULL,
version_id bigint NOT NULL,
is_applied boolean NOT NULL,
tstamp timestamp NULL default now(),
PRIMARY KEY(id)
);`
}
func (m MySqlDialect) insertVersionSql() string {
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (?, ?);"
}
func (m MySqlDialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query("SELECT version_id, is_applied from goose_db_version ORDER BY id DESC")
// XXX: check for mysql specific error indicating the table doesn't exist.
// for now, assume any error is because the table doesn't exist,
// in which case we'll try to create it.
if err != nil {
return nil, ErrTableDoesNotExist
}
return rows, err
}
////////////////////////////
// sqlite3
////////////////////////////
type Sqlite3Dialect struct{}
func (m Sqlite3Dialect) createVersionTableSql() string {
return `CREATE TABLE goose_db_version (
id INTEGER PRIMARY KEY AUTOINCREMENT,
version_id INTEGER NOT NULL,
is_applied INTEGER NOT NULL,
tstamp TIMESTAMP DEFAULT (datetime('now'))
);`
}
func (m Sqlite3Dialect) insertVersionSql() string {
return "INSERT INTO goose_db_version (version_id, is_applied) VALUES (?, ?);"
}
func (m Sqlite3Dialect) dbVersionQuery(db *sql.DB) (*sql.Rows, error) {
rows, err := db.Query("SELECT version_id, is_applied from goose_db_version ORDER BY id DESC")
switch err.(type) {
case sqlite3.Error:
return nil, ErrTableDoesNotExist
}
return rows, err
}

@ -1,413 +0,0 @@
package goose
import (
"database/sql"
"errors"
"fmt"
"log"
"os"
"path/filepath"
"sort"
"strconv"
"strings"
"text/template"
"time"
_ "github.com/go-sql-driver/mysql"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
_ "github.com/ziutek/mymysql/godrv"
)
var (
ErrTableDoesNotExist = errors.New("table does not exist")
ErrNoPreviousVersion = errors.New("no previous version found")
)
type MigrationRecord struct {
VersionId int64
TStamp time.Time
IsApplied bool // was this a result of up() or down()
}
type Migration struct {
Version int64
Next int64 // next version, or -1 if none
Previous int64 // previous version, -1 if none
Source string // path to .go or .sql script
}
type migrationSorter []*Migration
// helpers so we can use pkg sort
func (ms migrationSorter) Len() int { return len(ms) }
func (ms migrationSorter) Swap(i, j int) { ms[i], ms[j] = ms[j], ms[i] }
func (ms migrationSorter) Less(i, j int) bool { return ms[i].Version < ms[j].Version }
func newMigration(v int64, src string) *Migration {
return &Migration{v, -1, -1, src}
}
func RunMigrations(conf *DBConf, migrationsDir string, target int64) (err error) {
db, err := OpenDBFromDBConf(conf)
if err != nil {
return err
}
defer db.Close()
return RunMigrationsOnDb(conf, migrationsDir, target, db)
}
// Runs migration on a specific database instance.
func RunMigrationsOnDb(conf *DBConf, migrationsDir string, target int64, db *sql.DB) (err error) {
current, err := EnsureDBVersion(conf, db)
if err != nil {
return err
}
migrations, err := CollectMigrations(migrationsDir, current, target)
if err != nil {
return err
}
if len(migrations) == 0 {
fmt.Printf("goose: no migrations to run. current version: %d\n", current)
return nil
}
ms := migrationSorter(migrations)
direction := current < target
ms.Sort(direction)
fmt.Printf("goose: migrating db environment '%v', current version: %d, target: %d\n",
conf.Env, current, target)
for _, m := range ms {
switch filepath.Ext(m.Source) {
case ".go":
err = runGoMigration(conf, m.Source, m.Version, direction)
case ".sql":
err = runSQLMigration(conf, db, m.Source, m.Version, direction)
}
if err != nil {
return errors.New(fmt.Sprintf("FAIL %v, quitting migration", err))
}
fmt.Println("OK ", filepath.Base(m.Source))
}
return nil
}
// collect all the valid looking migration scripts in the
// migrations folder, and key them by version
func CollectMigrations(dirpath string, current, target int64) (m []*Migration, err error) {
// extract the numeric component of each migration,
// filter out any uninteresting files,
// and ensure we only have one file per migration version.
filepath.Walk(dirpath, func(name string, info os.FileInfo, err error) error {
if v, e := NumericComponent(name); e == nil {
for _, g := range m {
if v == g.Version {
log.Fatalf("more than one file specifies the migration for version %d (%s and %s)",
v, g.Source, filepath.Join(dirpath, name))
}
}
if versionFilter(v, current, target) {
m = append(m, newMigration(v, name))
}
}
return nil
})
return m, nil
}
func versionFilter(v, current, target int64) bool {
if target > current {
return v > current && v <= target
}
if target < current {
return v <= current && v > target
}
return false
}
func (ms migrationSorter) Sort(direction bool) {
// sort ascending or descending by version
if direction {
sort.Sort(ms)
} else {
sort.Sort(sort.Reverse(ms))
}
// now that we're sorted in the appropriate direction,
// populate next and previous for each migration
for i, m := range ms {
prev := int64(-1)
if i > 0 {
prev = ms[i-1].Version
ms[i-1].Next = m.Version
}
ms[i].Previous = prev
}
}
// look for migration scripts with names in the form:
// XXX_descriptivename.ext
// where XXX specifies the version number
// and ext specifies the type of migration
func NumericComponent(name string) (int64, error) {
base := filepath.Base(name)
if ext := filepath.Ext(base); ext != ".go" && ext != ".sql" {
return 0, errors.New("not a recognized migration file type")
}
idx := strings.Index(base, "_")
if idx < 0 {
return 0, errors.New("no separator found")
}
n, e := strconv.ParseInt(base[:idx], 10, 64)
if e == nil && n <= 0 {
return 0, errors.New("migration IDs must be greater than zero")
}
return n, e
}
// retrieve the current version for this DB.
// Create and initialize the DB version table if it doesn't exist.
func EnsureDBVersion(conf *DBConf, db *sql.DB) (int64, error) {
rows, err := conf.Driver.Dialect.dbVersionQuery(db)
if err != nil {
if err == ErrTableDoesNotExist {
return 0, createVersionTable(conf, db)
}
return 0, err
}
defer rows.Close()
// The most recent record for each migration specifies
// whether it has been applied or rolled back.
// The first version we find that has been applied is the current version.
toSkip := make([]int64, 0)
for rows.Next() {
var row MigrationRecord
if err = rows.Scan(&row.VersionId, &row.IsApplied); err != nil {
log.Fatal("error scanning rows:", err)
}
// have we already marked this version to be skipped?
skip := false
for _, v := range toSkip {
if v == row.VersionId {
skip = true
break
}
}
if skip {
continue
}
// if version has been applied we're done
if row.IsApplied {
return row.VersionId, nil
}
// latest version of migration has not been applied.
toSkip = append(toSkip, row.VersionId)
}
panic("failure in EnsureDBVersion()")
}
// Create the goose_db_version table
// and insert the initial 0 value into it
func createVersionTable(conf *DBConf, db *sql.DB) error {
txn, err := db.Begin()
if err != nil {
return err
}
d := conf.Driver.Dialect
if _, err := txn.Exec(d.createVersionTableSql()); err != nil {
txn.Rollback()
return err
}
version := 0
applied := true
if _, err := txn.Exec(d.insertVersionSql(), version, applied); err != nil {
txn.Rollback()
return err
}
return txn.Commit()
}
// wrapper for EnsureDBVersion for callers that don't already have
// their own DB instance
func GetDBVersion(conf *DBConf) (version int64, err error) {
db, err := OpenDBFromDBConf(conf)
if err != nil {
return -1, err
}
defer db.Close()
version, err = EnsureDBVersion(conf, db)
if err != nil {
return -1, err
}
return version, nil
}
func GetPreviousDBVersion(dirpath string, version int64) (previous int64, err error) {
previous = -1
sawGivenVersion := false
filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error {
if !info.IsDir() {
if v, e := NumericComponent(name); e == nil {
if v > previous && v < version {
previous = v
}
if v == version {
sawGivenVersion = true
}
}
}
return nil
})
if previous == -1 {
if sawGivenVersion {
// the given version is (likely) valid but we didn't find
// anything before it.
// 'previous' must reflect that no migrations have been applied.
previous = 0
} else {
err = ErrNoPreviousVersion
}
}
return
}
// helper to identify the most recent possible version
// within a folder of migration scripts
func GetMostRecentDBVersion(dirpath string) (version int64, err error) {
version = -1
filepath.Walk(dirpath, func(name string, info os.FileInfo, walkerr error) error {
if walkerr != nil {
return walkerr
}
if !info.IsDir() {
if v, e := NumericComponent(name); e == nil {
if v > version {
version = v
}
}
}
return nil
})
if version == -1 {
err = errors.New("no valid version found")
}
return
}
func CreateMigration(name, migrationType, dir string, t time.Time) (path string, err error) {
if migrationType != "go" && migrationType != "sql" {
return "", errors.New("migration type must be 'go' or 'sql'")
}
timestamp := t.Format("20060102150405")
filename := fmt.Sprintf("%v_%v.%v", timestamp, name, migrationType)
fpath := filepath.Join(dir, filename)
var tmpl *template.Template
if migrationType == "sql" {
tmpl = sqlMigrationTemplate
} else {
tmpl = goMigrationTemplate
}
path, err = writeTemplateToFile(fpath, tmpl, timestamp)
return
}
// Update the version table for the given migration,
// and finalize the transaction.
func FinalizeMigration(conf *DBConf, txn *sql.Tx, direction bool, v int64) error {
// XXX: drop goose_db_version table on some minimum version number?
stmt := conf.Driver.Dialect.insertVersionSql()
if _, err := txn.Exec(stmt, v, direction); err != nil {
txn.Rollback()
return err
}
return txn.Commit()
}
var goMigrationTemplate = template.Must(template.New("goose.go-migration").Parse(`
package main
import (
"database/sql"
)
// Up is executed when this migration is applied
func Up_{{ . }}(txn *sql.Tx) {
}
// Down is executed when this migration is rolled back
func Down_{{ . }}(txn *sql.Tx) {
}
`))
var sqlMigrationTemplate = template.Must(template.New("goose.sql-migration").Parse(`
-- +goose Up
-- SQL in section 'Up' is executed when this migration is applied
-- +goose Down
-- SQL section 'Down' is executed when this migration is rolled back
`))

@ -1,71 +0,0 @@
package goose
import (
"testing"
)
func TestMigrationMapSortUp(t *testing.T) {
ms := migrationSorter{}
// insert in any order
ms = append(ms, newMigration(20120000, "test"))
ms = append(ms, newMigration(20128000, "test"))
ms = append(ms, newMigration(20129000, "test"))
ms = append(ms, newMigration(20127000, "test"))
ms.Sort(true) // sort Upwards
sorted := []int64{20120000, 20127000, 20128000, 20129000}
validateMigrationSort(t, ms, sorted)
}
func TestMigrationMapSortDown(t *testing.T) {
ms := migrationSorter{}
// insert in any order
ms = append(ms, newMigration(20120000, "test"))
ms = append(ms, newMigration(20128000, "test"))
ms = append(ms, newMigration(20129000, "test"))
ms = append(ms, newMigration(20127000, "test"))
ms.Sort(false) // sort Downwards
sorted := []int64{20129000, 20128000, 20127000, 20120000}
validateMigrationSort(t, ms, sorted)
}
func validateMigrationSort(t *testing.T, ms migrationSorter, sorted []int64) {
for i, m := range ms {
if sorted[i] != m.Version {
t.Error("incorrect sorted version")
}
var next, prev int64
if i == 0 {
prev = -1
next = ms[i+1].Version
} else if i == len(ms)-1 {
prev = ms[i-1].Version
next = -1
} else {
prev = ms[i-1].Version
next = ms[i+1].Version
}
if m.Next != next {
t.Errorf("mismatched Next. v: %v, got %v, wanted %v\n", m, m.Next, next)
}
if m.Previous != prev {
t.Errorf("mismatched Previous v: %v, got %v, wanted %v\n", m, m.Previous, prev)
}
}
t.Log(ms)
}

@ -1,137 +0,0 @@
package goose
import (
"bytes"
"encoding/gob"
"fmt"
"io/ioutil"
"log"
"os"
"os/exec"
"path/filepath"
"text/template"
)
type templateData struct {
Version int64
Import string
Conf string // gob encoded DBConf
Direction bool
Func string
InsertStmt string
}
func init() {
gob.Register(PostgresDialect{})
gob.Register(MySqlDialect{})
gob.Register(Sqlite3Dialect{})
}
//
// Run a .go migration.
//
// In order to do this, we copy a modified version of the
// original .go migration, and execute it via `go run` along
// with a main() of our own creation.
//
func runGoMigration(conf *DBConf, path string, version int64, direction bool) error {
// everything gets written to a temp dir, and zapped afterwards
d, e := ioutil.TempDir("", "goose")
if e != nil {
log.Fatal(e)
}
defer os.RemoveAll(d)
directionStr := "Down"
if direction {
directionStr = "Up"
}
var bb bytes.Buffer
if err := gob.NewEncoder(&bb).Encode(conf); err != nil {
return err
}
// XXX: there must be a better way of making this byte array
// available to the generated code...
// but for now, print an array literal of the gob bytes
var sb bytes.Buffer
sb.WriteString("[]byte{ ")
for _, b := range bb.Bytes() {
sb.WriteString(fmt.Sprintf("0x%02x, ", b))
}
sb.WriteString("}")
td := &templateData{
Version: version,
Import: conf.Driver.Import,
Conf: sb.String(),
Direction: direction,
Func: fmt.Sprintf("%v_%v", directionStr, version),
InsertStmt: conf.Driver.Dialect.insertVersionSql(),
}
main, e := writeTemplateToFile(filepath.Join(d, "goose_main.go"), goMigrationDriverTemplate, td)
if e != nil {
log.Fatal(e)
}
outpath := filepath.Join(d, filepath.Base(path))
if _, e = copyFile(outpath, path); e != nil {
log.Fatal(e)
}
cmd := exec.Command("go", "run", main, outpath)
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
if e = cmd.Run(); e != nil {
log.Fatal("`go run` failed: ", e)
}
return nil
}
//
// template for the main entry point to a go-based migration.
// this gets linked against the substituted versions of the user-supplied
// scripts in order to execute a migration via `go run`
//
var goMigrationDriverTemplate = template.Must(template.New("goose.go-driver").Parse(`
package main
import (
"log"
"bytes"
"encoding/gob"
_ "{{.Import}}"
"bitbucket.org/liamstask/goose/lib/goose"
)
func main() {
var conf goose.DBConf
buf := bytes.NewBuffer({{ .Conf }})
if err := gob.NewDecoder(buf).Decode(&conf); err != nil {
log.Fatal("gob.Decode - ", err)
}
db, err := goose.OpenDBFromDBConf(&conf)
if err != nil {
log.Fatal("failed to open DB:", err)
}
defer db.Close()
txn, err := db.Begin()
if err != nil {
log.Fatal("db.Begin:", err)
}
{{ .Func }}(txn)
err = goose.FinalizeMigration(&conf, txn, {{ .Direction }}, {{ .Version }})
if err != nil {
log.Fatal("Commit() failed:", err)
}
}
`))

@ -1,168 +0,0 @@
package goose
import (
"bufio"
"bytes"
"database/sql"
"io"
"log"
"os"
"path/filepath"
"strings"
)
const sqlCmdPrefix = "-- +goose "
// Checks the line to see if the line has a statement-ending semicolon
// or if the line contains a double-dash comment.
func endsWithSemicolon(line string) bool {
prev := ""
scanner := bufio.NewScanner(strings.NewReader(line))
scanner.Split(bufio.ScanWords)
for scanner.Scan() {
word := scanner.Text()
if strings.HasPrefix(word, "--") {
break
}
prev = word
}
return strings.HasSuffix(prev, ";")
}
// Split the given sql script into individual statements.
//
// The base case is to simply split on semicolons, as these
// naturally terminate a statement.
//
// However, more complex cases like pl/pgsql can have semicolons
// within a statement. For these cases, we provide the explicit annotations
// 'StatementBegin' and 'StatementEnd' to allow the script to
// tell us to ignore semicolons.
func splitSQLStatements(r io.Reader, direction bool) (stmts []string) {
var buf bytes.Buffer
scanner := bufio.NewScanner(r)
// track the count of each section
// so we can diagnose scripts with no annotations
upSections := 0
downSections := 0
statementEnded := false
ignoreSemicolons := false
directionIsActive := false
for scanner.Scan() {
line := scanner.Text()
// handle any goose-specific commands
if strings.HasPrefix(line, sqlCmdPrefix) {
cmd := strings.TrimSpace(line[len(sqlCmdPrefix):])
switch cmd {
case "Up":
directionIsActive = (direction == true)
upSections++
break
case "Down":
directionIsActive = (direction == false)
downSections++
break
case "StatementBegin":
if directionIsActive {
ignoreSemicolons = true
}
break
case "StatementEnd":
if directionIsActive {
statementEnded = (ignoreSemicolons == true)
ignoreSemicolons = false
}
break
}
}
if !directionIsActive {
continue
}
if _, err := buf.WriteString(line + "\n"); err != nil {
log.Fatalf("io err: %v", err)
}
// Wrap up the two supported cases: 1) basic with semicolon; 2) psql statement
// Lines that end with semicolon that are in a statement block
// do not conclude statement.
if (!ignoreSemicolons && endsWithSemicolon(line)) || statementEnded {
statementEnded = false
stmts = append(stmts, buf.String())
buf.Reset()
}
}
if err := scanner.Err(); err != nil {
log.Fatalf("scanning migration: %v", err)
}
// diagnose likely migration script errors
if ignoreSemicolons {
log.Println("WARNING: saw '-- +goose StatementBegin' with no matching '-- +goose StatementEnd'")
}
if bufferRemaining := strings.TrimSpace(buf.String()); len(bufferRemaining) > 0 {
log.Printf("WARNING: Unexpected unfinished SQL query: %s. Missing a semicolon?\n", bufferRemaining)
}
if upSections == 0 && downSections == 0 {
log.Fatalf(`ERROR: no Up/Down annotations found, so no statements were executed.
See https://bitbucket.org/liamstask/goose/overview for details.`)
}
return
}
// Run a migration specified in raw SQL.
//
// Sections of the script can be annotated with a special comment,
// starting with "-- +goose" to specify whether the section should
// be applied during an Up or Down migration
//
// All statements following an Up or Down directive are grouped together
// until another direction directive is found.
func runSQLMigration(conf *DBConf, db *sql.DB, scriptFile string, v int64, direction bool) error {
txn, err := db.Begin()
if err != nil {
log.Fatal("db.Begin:", err)
}
f, err := os.Open(scriptFile)
if err != nil {
log.Fatal(err)
}
// find each statement, checking annotations for up/down direction
// and execute each of them in the current transaction.
// Commits the transaction if successfully applied each statement and
// records the version into the version table or returns an error and
// rolls back the transaction.
for _, query := range splitSQLStatements(f, direction) {
if _, err = txn.Exec(query); err != nil {
txn.Rollback()
log.Fatalf("FAIL %s (%v), quitting migration.", filepath.Base(scriptFile), err)
return err
}
}
if err = FinalizeMigration(conf, txn, direction, v); err != nil {
log.Fatalf("error finalizing migration %s, quitting. (%v)", filepath.Base(scriptFile), err)
}
return nil
}

@ -1,147 +0,0 @@
package goose
import (
"strings"
"testing"
)
func TestSemicolons(t *testing.T) {
type testData struct {
line string
result bool
}
tests := []testData{
{
line: "END;",
result: true,
},
{
line: "END; -- comment",
result: true,
},
{
line: "END ; -- comment",
result: true,
},
{
line: "END -- comment",
result: false,
},
{
line: "END -- comment ;",
result: false,
},
{
line: "END \" ; \" -- comment",
result: false,
},
}
for _, test := range tests {
r := endsWithSemicolon(test.line)
if r != test.result {
t.Errorf("incorrect semicolon. got %v, want %v", r, test.result)
}
}
}
func TestSplitStatements(t *testing.T) {
type testData struct {
sql string
direction bool
count int
}
tests := []testData{
{
sql: functxt,
direction: true,
count: 2,
},
{
sql: functxt,
direction: false,
count: 2,
},
{
sql: multitxt,
direction: true,
count: 2,
},
{
sql: multitxt,
direction: false,
count: 2,
},
}
for _, test := range tests {
stmts := splitSQLStatements(strings.NewReader(test.sql), test.direction)
if len(stmts) != test.count {
t.Errorf("incorrect number of stmts. got %v, want %v", len(stmts), test.count)
}
}
}
var functxt = `-- +goose Up
CREATE TABLE IF NOT EXISTS histories (
id BIGSERIAL PRIMARY KEY,
current_value varchar(2000) NOT NULL,
created_at timestamp with time zone NOT NULL
);
-- +goose StatementBegin
CREATE OR REPLACE FUNCTION histories_partition_creation( DATE, DATE )
returns void AS $$
DECLARE
create_query text;
BEGIN
FOR create_query IN SELECT
'CREATE TABLE IF NOT EXISTS histories_'
|| TO_CHAR( d, 'YYYY_MM' )
|| ' ( CHECK( created_at >= timestamp '''
|| TO_CHAR( d, 'YYYY-MM-DD 00:00:00' )
|| ''' AND created_at < timestamp '''
|| TO_CHAR( d + INTERVAL '1 month', 'YYYY-MM-DD 00:00:00' )
|| ''' ) ) inherits ( histories );'
FROM generate_series( $1, $2, '1 month' ) AS d
LOOP
EXECUTE create_query;
END LOOP; -- LOOP END
END; -- FUNCTION END
$$
language plpgsql;
-- +goose StatementEnd
-- +goose Down
drop function histories_partition_creation(DATE, DATE);
drop TABLE histories;
`
// test multiple up/down transitions in a single script
var multitxt = `-- +goose Up
CREATE TABLE post (
id int NOT NULL,
title text,
body text,
PRIMARY KEY(id)
);
-- +goose Down
DROP TABLE post;
-- +goose Up
CREATE TABLE fancier_post (
id int NOT NULL,
title text,
body text,
created_on timestamp without time zone,
PRIMARY KEY(id)
);
-- +goose Down
DROP TABLE fancier_post;
`

@ -1,40 +0,0 @@
package goose
import (
"io"
"os"
"text/template"
)
// common routines
func writeTemplateToFile(path string, t *template.Template, data interface{}) (string, error) {
f, e := os.Create(path)
if e != nil {
return "", e
}
defer f.Close()
e = t.Execute(f, data)
if e != nil {
return "", e
}
return f.Name(), nil
}
func copyFile(dst, src string) (int64, error) {
sf, err := os.Open(src)
if err != nil {
return 0, err
}
defer sf.Close()
df, err := os.Create(dst)
if err != nil {
return 0, err
}
defer df.Close()
return io.Copy(df, sf)
}

@ -0,0 +1,17 @@
language: go
go:
- 1.5
services:
- postgres
before_install:
- go install -a -race std
env:
- DB=sqlite3
- DB=postgres
script:
- go test -v -race ./... -test.database=$DB

@ -0,0 +1,23 @@
Copyright (c) 2015, Remind101, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

@ -0,0 +1,57 @@
# Migrate
[![Build Status](https://travis-ci.org/remind101/migrate.svg?branch=master)](https://travis-ci.org/remind101/migrate)
Migrate is a Go library for doing migrations. It's stupidly simple and gets out of your way.
## Features
* It's only dependency is `database/sql`.
* It supports any type of migration you want to run (e.g. raw sql, or Go code).
* It doesn't provide a command. It's designed to be embedded in projects and used exclusively as a library.
## Usage
```go
migrations := []migrate.Migration{
{
ID: 1,
Up: func(tx *sql.Tx) error {
_, err := tx.Exec("CREATE TABLE people (id int)")
return err
},
Down: func(tx *sql.Tx) error {
_, err := tx.Exec("DROP TABLE people")
return err
},
},
{
ID: 2,
// For simple sql migrations, you can use the migrate.Queries
// helper.
Up: migrate.Queries([]string{
"ALTER TABLE people ADD COLUMN first_name text",
}),
Down: func(tx *sql.Tx) error {
// It's not possible to remove a column with
// sqlite.
_, err := tx.Exec("SELECT 1 FROM people")
return err
},
},
}
db, _ := sql.Open("sqlite3", ":memory:")
_ = migrate.Exec(db, migrate.Up, migrations...)
```
### Locking
All migrations are run in a transaction, but if you attempt to run a single long running migration concurrently, you could run into a deadlock. For Postgres connections, `migrate` can use [pg_advisory_lock](http://www.postgresql.org/docs/9.1/static/explicit-locking.html) to ensure that only 1 migration is run at a time.
To use this, simply instantiate a `Migrator` instance using `migrate.NewPostgresMigrator`:
```go
migrator := NewPostgresMigrator(db)
_ = migrator.Exec(migrate.Up, migrations...)
```

@ -0,0 +1,41 @@
package migrate_test
import (
"database/sql"
_ "github.com/mattn/go-sqlite3"
"github.com/remind101/migrate"
)
func Example() {
migrations := []migrate.Migration{
{
ID: 1,
Up: func(tx *sql.Tx) error {
_, err := tx.Exec("CREATE TABLE people (id int)")
return err
},
Down: func(tx *sql.Tx) error {
_, err := tx.Exec("DROP TABLE people")
return err
},
},
{
ID: 2,
// For simple sql migrations, you can use the migrate.Queries
// helper.
Up: migrate.Queries([]string{
"ALTER TABLE people ADD COLUMN first_name text",
}),
Down: func(tx *sql.Tx) error {
// It's not possible to remove a column with
// sqlite.
_, err := tx.Exec("SELECT 1 FROM people")
return err
},
},
}
db, _ := sql.Open("sqlite3", ":memory:")
_ = migrate.Exec(db, migrate.Up, migrations...)
}

@ -0,0 +1,293 @@
// Package migrate provides a dead simple Go package for performing sql
// migrations using database/sql.
package migrate
import (
"database/sql"
"fmt"
"hash/crc32"
"sort"
"sync"
)
type MigrationDirection int
const (
Up MigrationDirection = iota
Down
)
type TransactionMode int
const (
// In this mode, each migration is run in it's own isolated transaction.
// If a migration fails, only that migration will be rolled back.
IndividualTransactions TransactionMode = iota
// In this mode, all migrations are run inside a single transaction. If
// one migration fails, all migrations are rolled back.
SingleTransaction
)
// MigrationError is an error that gets returned when an individual migration
// fails.
type MigrationError struct {
Migration
// The underlying error.
Err error
}
// Error implements the error interface.
func (e *MigrationError) Error() string {
return fmt.Sprintf("migration %d failed: %v", e.ID, e.Err)
}
// The default table to store what migrations have been run.
const DefaultTable = "schema_migrations"
// Migration represents a sql migration that can be migrated up or down.
type Migration struct {
// ID is a unique, numeric, identifier for this migration.
ID int
// Up is a function that gets called when this migration should go up.
Up func(tx *sql.Tx) error
// Down is a function that gets called when this migration should go
// down.
Down func(tx *sql.Tx) error
}
// byID implements the sort.Interface interface for sorting migrations by
// ID.
type byID []Migration
func (m byID) Len() int { return len(m) }
func (m byID) Less(i, j int) bool { return m[i].ID < m[j].ID }
func (m byID) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
// Migrator performs migrations.
type Migrator struct {
// Table is the table to store what migrations have been run. The zero
// value is DefaultTable.
Table string
// Locker is a sync.Locker to use to ensure that only 1 process is
// running migrations.
sync.Locker
// The TransactionMode to use. The zero value is IndividualTransactions,
// which runs each migration in it's own transaction.
TransactionMode TransactionMode
db *sql.DB
}
// postgresLocker implements the sync.Locker interface using pg_advisory_lock.
type postgresLocker struct {
key uint32
db *sql.DB
}
// NewPostgresLocker returns a new sync.Locker that obtains locks with
// pg_advisory_lock.
func newPostgresLocker(db *sql.DB) sync.Locker {
key := crc32.ChecksumIEEE([]byte("migrations"))
return &postgresLocker{
key: key,
db: db,
}
}
// Lock obtains the advisory lock.
func (l *postgresLocker) Lock() {
l.do("lock")
}
// Unlock removes the advisory Lock
func (l *postgresLocker) Unlock() {
l.do("unlock")
}
func (l *postgresLocker) do(m string) {
_, err := l.db.Exec(fmt.Sprintf("SELECT pg_advisory_%s(%d)", m, l.key))
if err != nil {
panic(fmt.Sprintf("migrate: %v", err))
}
}
// NewMigrator returns a new Migrator instance that will use the sql.DB to
// perform the migrations.
func NewMigrator(db *sql.DB) *Migrator {
return &Migrator{
db: db,
Locker: new(sync.Mutex),
}
}
// NewPostgresMigrator returns a new Migrator instance that uses the underlying
// sql.DB connection to a postgres database to perform migrations. It will use
// Postgres's advisory locks to ensure that only 1 migration is run at a time.
func NewPostgresMigrator(db *sql.DB) *Migrator {
m := NewMigrator(db)
m.Locker = newPostgresLocker(db)
return m
}
// Exec runs the migrations in the given direction.
func (m *Migrator) Exec(dir MigrationDirection, migrations ...Migration) error {
m.Lock()
defer m.Unlock()
_, err := m.db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version integer primary key not null)", m.table()))
if err != nil {
return err
}
var tx *sql.Tx
if m.TransactionMode == SingleTransaction {
tx, err = m.db.Begin()
if err != nil {
return err
}
}
for _, migration := range sortMigrations(dir, migrations) {
if m.TransactionMode == IndividualTransactions {
tx, err = m.db.Begin()
if err != nil {
return err
}
}
if err := m.runMigration(tx, dir, migration); err != nil {
tx.Rollback()
return err
}
if m.TransactionMode == IndividualTransactions {
if err := tx.Commit(); err != nil {
return err
}
}
}
if m.TransactionMode == SingleTransaction {
if err := tx.Commit(); err != nil {
return err
}
}
return nil
}
// runMigration runs the given Migration in the given direction using the given
// transaction. This function does not commit or rollback the transaction,
// that's the responsibility of the consumer dependending on whether an error
// gets returned.
func (m *Migrator) runMigration(tx *sql.Tx, dir MigrationDirection, migration Migration) error {
shouldMigrate, err := m.shouldMigrate(tx, migration.ID, dir)
if err != nil {
return err
}
if !shouldMigrate {
return nil
}
var migrate func(tx *sql.Tx) error
switch dir {
case Up:
migrate = migration.Up
default:
migrate = migration.Down
}
if err := migrate(tx); err != nil {
return &MigrationError{Migration: migration, Err: err}
}
var query string
switch dir {
case Up:
// Yes. This is a sql injection vulnerability. This gets around
// the different bindings for sqlite3/postgres.
//
// If you're running migrations from user input, you're doing
// something wrong.
query = fmt.Sprintf("INSERT INTO %s (version) VALUES (%d)", m.table(), migration.ID)
default:
query = fmt.Sprintf("DELETE FROM %s WHERE version = %d", m.table(), migration.ID)
}
_, err = tx.Exec(query)
return err
}
func (m *Migrator) shouldMigrate(tx *sql.Tx, id int, dir MigrationDirection) (bool, error) {
// Check if this migration has already ran
var _id int
err := tx.QueryRow(fmt.Sprintf("SELECT version FROM %s WHERE version = %d", m.table(), id)).Scan(&_id)
if err != nil && err != sql.ErrNoRows {
return false, err
}
switch dir {
case Up:
// If the migration doesn't exist, then we need to run it.
return err == sql.ErrNoRows, nil
default:
// If the migration exists, then we need to remove it.
return err != sql.ErrNoRows, nil
}
}
// table returns the name of the table to use to track the migrations.
func (m *Migrator) table() string {
if m.Table == "" {
return DefaultTable
}
return m.Table
}
// Exec is a convenience method that runs the migrations against the default
// table.
func Exec(db *sql.DB, dir MigrationDirection, migrations ...Migration) error {
return NewMigrator(db).Exec(dir, migrations...)
}
// Queries returns a func(tx *sql.Tx) error function that performs the given sql
// queries in multiple Exec calls.
func Queries(queries []string) func(*sql.Tx) error {
return func(tx *sql.Tx) error {
for _, query := range queries {
if _, err := tx.Exec(query); err != nil {
return err
}
}
return nil
}
}
// sortMigrations sorts the migrations by id.
//
// When the direction is "Up", the migrations will be sorted by ID ascending.
// When the direction is "Down", the migrations will be sorted by ID descending.
func sortMigrations(dir MigrationDirection, migrations []Migration) []Migration {
var m byID
for _, migration := range migrations {
m = append(m, migration)
}
switch dir {
case Up:
sort.Sort(byID(m))
default:
sort.Sort(sort.Reverse(byID(m)))
}
return m
}

@ -0,0 +1,343 @@
package migrate_test
import (
"database/sql"
"errors"
"flag"
"fmt"
"os"
"os/exec"
"strings"
"testing"
_ "github.com/lib/pq"
_ "github.com/mattn/go-sqlite3"
"github.com/remind101/migrate"
"github.com/stretchr/testify/assert"
)
const (
Sqlite = "sqlite3"
Postgres = "postgres"
)
// A flag to determine what database to run the suite against.
var database = flag.String("test.database", Sqlite, "The name of the database to run against. (sqlite3, postgres).")
var testMigrations = []migrate.Migration{
{
ID: 1,
Up: func(tx *sql.Tx) error {
_, err := tx.Exec("CREATE TABLE people (id int)")
return err
},
Down: func(tx *sql.Tx) error {
_, err := tx.Exec("DROP TABLE people")
return err
},
},
{
ID: 2,
// For simple sql migrations, you can use the migrate.Queries
// helper.
Up: migrate.Queries([]string{
"ALTER TABLE people ADD COLUMN first_name text",
}),
Down: func(tx *sql.Tx) error {
// It's not possible to remove a column with
// sqlite.
_, err := tx.Exec("SELECT 1 FROM people")
return err
},
},
}
func TestMigrate(t *testing.T) {
db := newDB(t)
defer db.Close()
migrations := testMigrations[:]
err := migrate.Exec(db, migrate.Up, migrations...)
assert.NoError(t, err)
assert.Equal(t, []int{1, 2}, appliedMigrations(t, db))
assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
err = migrate.Exec(db, migrate.Down, migrations...)
assert.NoError(t, err)
assert.Equal(t, []int{}, appliedMigrations(t, db))
assertSchema(t, ``, db)
}
func TestMigrate_Individual(t *testing.T) {
db := newDB(t)
defer db.Close()
err := migrate.Exec(db, migrate.Up, testMigrations[0])
assert.NoError(t, err)
assert.Equal(t, []int{1}, appliedMigrations(t, db))
assertSchema(t, `
people
CREATE TABLE people (id int)
`, db)
err = migrate.Exec(db, migrate.Up, testMigrations[1])
assert.NoError(t, err)
assert.Equal(t, []int{1, 2}, appliedMigrations(t, db))
assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
}
func TestMigrate_AlreadyRan(t *testing.T) {
db := newDB(t)
defer db.Close()
migration := testMigrations[0]
err := migrate.Exec(db, migrate.Up, migration)
assert.NoError(t, err)
assert.Equal(t, []int{1}, appliedMigrations(t, db))
assertSchema(t, `
people
CREATE TABLE people (id int)
`, db)
err = migrate.Exec(db, migrate.Up, migration)
assert.NoError(t, err)
assert.Equal(t, []int{1}, appliedMigrations(t, db))
assertSchema(t, `
people
CREATE TABLE people (id int)
`, db)
}
func TestMigrate_SingleTransactionMode_Rollback(t *testing.T) {
db := newDB(t)
defer db.Close()
migrator := migrate.NewMigrator(db)
migrator.TransactionMode = migrate.SingleTransaction
migrations := []migrate.Migration{
testMigrations[0],
testMigrations[1],
migrate.Migration{
ID: 3,
Up: func(tx *sql.Tx) error {
return errors.New("Rollback")
},
},
}
err := migrator.Exec(migrate.Up, migrations...)
assert.Error(t, err)
assert.Equal(t, []int{}, appliedMigrations(t, db))
assertSchema(t, ``, db)
}
func TestMigrate_SingleTransactionMode_Commit(t *testing.T) {
db := newDB(t)
defer db.Close()
migrator := migrate.NewMigrator(db)
migrator.TransactionMode = migrate.SingleTransaction
err := migrator.Exec(migrate.Up, testMigrations...)
assert.NoError(t, err)
assert.Equal(t, []int{1, 2}, appliedMigrations(t, db))
assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
}
func TestMigrate_Order(t *testing.T) {
db := newDB(t)
defer db.Close()
migrations := []migrate.Migration{
testMigrations[1],
testMigrations[0],
}
err := migrate.Exec(db, migrate.Up, migrations...)
assert.NoError(t, err)
assert.Equal(t, []int{1, 2}, appliedMigrations(t, db))
assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
}
func TestMigrate_Rollback(t *testing.T) {
db := newDB(t)
defer db.Close()
migration := migrate.Migration{
ID: 1,
Up: func(tx *sql.Tx) error {
// This should completely ok
if _, err := tx.Exec("CREATE TABLE people (id int)"); err != nil {
return err
}
// This should throw an error
if _, err := tx.Exec("ALTER TABLE foo ADD COLUMN first_name text"); err != nil {
return err
}
return nil
},
}
err := migrate.Exec(db, migrate.Up, migration)
assert.Error(t, err)
assert.Equal(t, []int{}, appliedMigrations(t, db))
// If the transaction wasn't rolled back, we'd see a people table.
assertSchema(t, ``, db)
assert.IsType(t, &migrate.MigrationError{}, err)
}
func TestMigrate_Locking(t *testing.T) {
db := newDB(t)
defer db.Close()
migrator := migrate.NewMigrator(db)
if *database == Postgres {
migrator = migrate.NewPostgresMigrator(db)
}
err := migrator.Exec(migrate.Up, testMigrations...)
assert.NoError(t, err)
assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
assert.Equal(t, []int{1, 2}, appliedMigrations(t, db))
var called int
// Generates a migration that sends on the given channel when it starts.
migration := migrate.Migration{
ID: 3,
Up: func(tx *sql.Tx) error {
called++
_, err := tx.Exec(`INSERT INTO people (id, first_name) VALUES (1, 'Eric')`)
return err
},
}
m1 := make(chan error)
m2 := make(chan error)
// Start two migrations in parallel.
go func() {
m1 <- migrator.Exec(migrate.Up, migration)
}()
go func() {
m2 <- migrator.Exec(migrate.Up, migration)
}()
assert.Nil(t, <-m1)
assert.Nil(t, <-m2)
assert.Equal(t, 1, called)
assertSchema(t, `
people
CREATE TABLE people (id int, first_name text)
`, db)
assert.Equal(t, []int{1, 2, 3}, appliedMigrations(t, db))
}
func assertSchema(t testing.TB, expectedSchema string, db *sql.DB) {
if *database == Sqlite {
schema, err := sqliteSchema(db)
if err != nil {
t.Fatal(err)
}
assert.Equal(t, strings.TrimSpace(expectedSchema), schema)
}
}
func sqliteSchema(db *sql.DB) (string, error) {
var tables []string
rows, err := db.Query(`SELECT name, sql FROM sqlite_master
WHERE type='table'
ORDER BY name;`)
if err != nil {
return "", err
}
defer rows.Close()
for rows.Next() {
var name, sql string
if err := rows.Scan(&name, &sql); err != nil {
return "", err
}
if name == migrate.DefaultTable {
continue
}
tables = append(tables, fmt.Sprintf("%s\n%s", name, sql))
}
return strings.Join(tables, "\n\n"), nil
}
func appliedMigrations(t testing.TB, db *sql.DB) []int {
rows, err := db.Query("SELECT version FROM " + migrate.DefaultTable)
if err != nil {
t.Fatal(err)
}
defer rows.Close()
ids := []int{}
for rows.Next() {
var id int
if err := rows.Scan(&id); err != nil {
t.Fatal(err)
}
ids = append(ids, id)
}
return ids
}
// factory methods to open a database connection to a type of database.
var databases = map[string]func() (*sql.DB, error){
Postgres: func() (*sql.DB, error) {
name := "migrate_test"
command := func(name string, arg ...string) *exec.Cmd {
cmd := exec.Command(name, arg...)
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
return cmd
}
command("dropdb", name).Run()
if err := command("createdb", name).Run(); err != nil {
return nil, err
}
return sql.Open("postgres", fmt.Sprintf("postgres://localhost/%s?sslmode=disable", name))
},
Sqlite: func() (*sql.DB, error) {
os.Remove("migrate_test.db")
return sql.Open("sqlite3", "migrate_test.db?cache=shared&mode=wrc")
},
}
func newDB(t testing.TB) *sql.DB {
open, ok := databases[*database]
if !ok {
t.Fatal(fmt.Sprintf("Unknown database: %s", *database))
}
db, err := open()
if err != nil {
t.Fatal(err)
}
return db
}
Loading…
Cancel
Save