clair/database/pgsql/testutil/testdb.go

208 lines
4.5 KiB
Go
Raw Permalink Normal View History

// Copyright 2019 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package testutil
import (
"database/sql"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"strings"
"testing"
"time"
"github.com/coreos/clair/database/pgsql/migrations"
"github.com/coreos/clair/pkg/pagination"
"github.com/remind101/migrate"
)
var TestPaginationKey = pagination.Must(pagination.NewKey())
var userDBCount = `SELECT count(datname) FROM pg_database WHERE datistemplate = FALSE AND datname != 'postgres';`
func CreateAndConnectTestDB(t *testing.T, testName string) (*sql.DB, func()) {
uri := "postgres@127.0.0.1:5432"
connectionTemplate := "postgresql://%s?sslmode=disable"
if envURI := os.Getenv("CLAIR_TEST_PGSQL"); envURI != "" {
uri = envURI
}
db, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri))
if err != nil {
panic(err)
}
testName = strings.ToLower(testName)
dbName := fmt.Sprintf("test_%s_%s", testName, time.Now().UTC().Format("2006_01_02_15_04_05"))
t.Logf("creating temporary database name = %s", dbName)
_, err = db.Exec("CREATE DATABASE " + dbName)
if err != nil {
panic(err)
}
testDB, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri+"/"+dbName))
if err != nil {
panic(err)
}
return testDB, func() {
cleanupTestDB(t, dbName, db, testDB)
}
}
func cleanupTestDB(t *testing.T, name string, db, testDB *sql.DB) {
t.Logf("cleaning up temporary database %s", name)
if db == nil {
panic("db is none")
}
if testDB == nil {
panic("testDB is none")
}
defer db.Close()
if err := testDB.Close(); err != nil {
panic(err)
}
// Kill any opened connection.
if _, err := db.Exec(`
SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity
WHERE pg_stat_activity.datname = $1
AND pid <> pg_backend_pid()`, name); err != nil {
panic(err)
}
if _, err := db.Exec(`DROP DATABASE ` + name); err != nil {
panic(err)
}
// ensure the database is cleaned up
var count int
if err := db.QueryRow(userDBCount).Scan(&count); err != nil {
panic(err)
}
}
func CreateTestDB(t *testing.T, testName string) (*sql.DB, func()) {
connection, cleanup := CreateAndConnectTestDB(t, testName)
err := migrate.NewPostgresMigrator(connection).Exec(migrate.Up, migrations.Migrations...)
if err != nil {
panic(err)
}
return connection, cleanup
}
func CreateTestDBWithFixture(t *testing.T, testName string) (*sql.DB, func()) {
connection, cleanup := CreateTestDB(t, testName)
session, err := connection.Begin()
if err != nil {
panic(err)
}
defer session.Rollback()
loadFixtures(session)
if err = session.Commit(); err != nil {
panic(err)
}
return connection, cleanup
}
func CreateTestTx(t *testing.T, testName string) (*sql.Tx, func()) {
connection, cleanup := CreateTestDB(t, testName)
session, err := connection.Begin()
if session == nil {
panic("session is none")
}
if err != nil {
panic(err)
}
return session, func() {
session.Rollback()
cleanup()
}
}
func CreateTestTxWithFixtures(t *testing.T, testName string) (*sql.Tx, func()) {
tx, cleanup := CreateTestTx(t, testName)
defer func() {
// ensure to cleanup when loadFixtures failed
if r := recover(); r != nil {
cleanup()
}
}()
loadFixtures(tx)
return tx, cleanup
}
func loadFixtures(tx *sql.Tx) {
_, filename, _, _ := runtime.Caller(0)
fixturePath := filepath.Join(filepath.Dir(filename), "data.sql")
d, err := ioutil.ReadFile(fixturePath)
if err != nil {
panic(err)
}
_, err = tx.Exec(string(d))
if err != nil {
panic(err)
}
}
func OpenSessionForTest(t *testing.T, name string, loadFixture bool) (*sql.DB, *sql.Tx) {
var db *sql.DB
if loadFixture {
db, _ = CreateTestDB(t, name)
} else {
db, _ = CreateTestDBWithFixture(t, name)
}
tx, err := db.Begin()
if err != nil {
panic(err)
}
return db, tx
}
func RestartTransaction(db *sql.DB, tx *sql.Tx, commit bool) *sql.Tx {
if !commit {
if err := tx.Rollback(); err != nil {
panic(err)
}
} else {
if err := tx.Commit(); err != nil {
panic(err)
}
}
tx, err := db.Begin()
if err != nil {
panic(err)
}
return tx
}