diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index ba13ed46..06c93d82 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -19,6 +19,7 @@ import ( "database/sql" "fmt" "io/ioutil" + "os" "path" "runtime" "strings" @@ -174,9 +175,17 @@ func dropDatabase(dataSource, databaseName string) error { } defer db.Close() + // Kill any opened connection. + if _, err := db.Exec(` + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = $1 + AND pid <> pg_backend_pid()`, databaseName); err != nil { + return fmt.Errorf("could not drop database: %v", err) + } + // Drop database. - _, err = db.Exec("DROP DATABASE " + databaseName) - if err != nil { + if _, err = db.Exec("DROP DATABASE " + databaseName); err != nil { return fmt.Errorf("could not drop database: %v", err) } @@ -187,33 +196,33 @@ func dropDatabase(dataSource, databaseName string) error { // Its Close() method drops the database. type pgSQLTest struct { *pgSQL - dataSource string - dbName string -} - -func (pgSQL *pgSQLTest) Close() { - pgSQL.DB.Close() - dropDatabase(pgSQL.dataSource+"dbname=postgres", pgSQL.dbName) + dataSourceDefaultDatabase string + dbName string } // OpenForTest creates a test Datastore backed by a new PostgreSQL database. // It creates a new unique and prefixed ("test_") database. // Using Close() will drop the database. func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) { - dataSource := "host=127.0.0.1 sslmode=disable " + // Define the PostgreSQL connection strings. + dataSource := "host=127.0.0.1 sslmode=disable user=postgres dbname=" + if dataSourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); dataSourceEnv != "" { + dataSource = dataSourceEnv + " dbname=" + } dbName := "test_" + strings.ToLower(name) + "_" + strings.Replace(uuid.New(), "-", "_", -1) + dataSourceDefaultDatabase := dataSource + "postgres" + dataSourceTestDatabase := dataSource + dbName // Create database. - err := createDatabase(dataSource+"dbname=postgres", dbName) - if err != nil { + if err := createDatabase(dataSourceDefaultDatabase, dbName); err != nil { log.Error(err) return nil, database.ErrCantOpen } // Open database. - db, err := Open(&config.DatabaseConfig{Source: dataSource + "dbname=" + dbName, CacheSize: 0}) + db, err := Open(&config.DatabaseConfig{Source: dataSourceTestDatabase, CacheSize: 0}) if err != nil { - dropDatabase(dataSource, dbName) + dropDatabase(dataSourceDefaultDatabase, dbName) log.Error(err) return nil, database.ErrCantOpen } @@ -224,13 +233,21 @@ func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) { d, _ := ioutil.ReadFile(path.Join(path.Dir(filename)) + "/testdata/data.sql") _, err = db.(*pgSQL).Exec(string(d)) if err != nil { - dropDatabase(dataSource+"dbname=postgres", dbName) + dropDatabase(dataSourceDefaultDatabase, dbName) log.Error(err) return nil, database.ErrCantOpen } } - return &pgSQLTest{pgSQL: db.(*pgSQL), dataSource: dataSource, dbName: dbName}, nil + return &pgSQLTest{ + pgSQL: db.(*pgSQL), + dataSourceDefaultDatabase: dataSourceDefaultDatabase, + dbName: dbName}, nil +} + +func (pgSQL *pgSQLTest) Close() { + pgSQL.DB.Close() + dropDatabase(pgSQL.dataSourceDefaultDatabase, pgSQL.dbName) } // handleError logs an error with an extra description and masks the error if it's an SQL one.