pgsql: Add proper tests for database migration
This commit is contained in:
parent
f61675355e
commit
073c685c5b
60
database/pgsql/migrations_test.go
Normal file
60
database/pgsql/migrations_test.go
Normal file
@ -0,0 +1,60 @@
|
||||
// Copyright 2019 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"github.com/remind101/migrate"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coreos/clair/database/pgsql/migrations"
|
||||
)
|
||||
|
||||
var userTableCount = `SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname='public'`
|
||||
|
||||
func TestMigration(t *testing.T) {
|
||||
db, cleanup := createAndConnectTestDB(t, "TestMigration")
|
||||
defer cleanup()
|
||||
|
||||
err := migrate.NewPostgresMigrator(db).Exec(migrate.Up, migrations.Migrations...)
|
||||
if err != nil {
|
||||
require.Nil(t, err, err.Error())
|
||||
}
|
||||
|
||||
err = migrate.NewPostgresMigrator(db).Exec(migrate.Down, migrations.Migrations...)
|
||||
if err != nil {
|
||||
require.Nil(t, err, err.Error())
|
||||
}
|
||||
|
||||
rows, err := db.Query(userTableCount)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
var (
|
||||
tables []string
|
||||
table string
|
||||
)
|
||||
for rows.Next() {
|
||||
if err = rows.Scan(&table); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
require.True(t, len(tables) == 1 && tables[0] == "schema_migrations", "Only `schema_migrations` should be left")
|
||||
}
|
@ -15,21 +15,34 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/remind101/migrate"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/migrations"
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
)
|
||||
|
||||
// int keys must be the consistent with the database ID.
|
||||
var (
|
||||
realFeatures = map[int]database.Feature{
|
||||
1: {"ourchat", "0.5", "ourchat", "0.5", "dpkg"},
|
||||
2: {"openssl", "1.0", "openssl", "1.0", "dpkg"},
|
||||
3: {"openssl", "2.0", "openssl", "2.0", "dpkg"},
|
||||
4: {"fake", "2.0", "fake", "2.0", "rpm"},
|
||||
1: {"ourchat", "0.5", "dpkg", "source"},
|
||||
2: {"openssl", "1.0", "dpkg", "source"},
|
||||
3: {"openssl", "2.0", "dpkg", "source"},
|
||||
4: {"fake", "2.0", "rpm", "source"},
|
||||
5: {"mount", "2.31.1-0.4ubuntu3.1", "dpkg", "binary"},
|
||||
}
|
||||
|
||||
realNamespaces = map[int]database.Namespace{
|
||||
@ -146,6 +159,7 @@ var (
|
||||
Name: "ourchat",
|
||||
Version: "0.6",
|
||||
VersionFormat: "dpkg",
|
||||
Type: "source",
|
||||
},
|
||||
}
|
||||
|
||||
@ -260,3 +274,150 @@ func mustMarshalToken(key pagination.Key, v interface{}) pagination.Token {
|
||||
|
||||
return token
|
||||
}
|
||||
|
||||
var userDBCount = `SELECT count(datname) FROM pg_database WHERE datistemplate = FALSE AND datname != 'postgres';`
|
||||
|
||||
func createAndConnectTestDB(t *testing.T, testName string) (*sql.DB, func()) {
|
||||
uri := "postgres@127.0.0.1:5432"
|
||||
connectionTemplate := "postgresql://%s?sslmode=disable"
|
||||
if envURI := os.Getenv("CLAIR_TEST_PGSQL"); envURI != "" {
|
||||
uri = envURI
|
||||
}
|
||||
|
||||
db, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
testName = strings.ToLower(testName)
|
||||
dbName := fmt.Sprintf("test_%s_%s", testName, time.Now().UTC().Format("2006_01_02_15_04_05"))
|
||||
t.Logf("creating temporary database name = %s", dbName)
|
||||
_, err = db.Exec("CREATE DATABASE " + dbName)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
testDB, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri+"/"+dbName))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return testDB, func() {
|
||||
t.Logf("cleaning up temporary database %s", dbName)
|
||||
defer db.Close()
|
||||
if err := testDB.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec(`DROP DATABASE ` + dbName); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// ensure the database is cleaned up
|
||||
var count int
|
||||
if err := db.QueryRow(userDBCount).Scan(&count); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createTestPgSQL(t *testing.T, testName string) (*pgSQL, func()) {
|
||||
connection, cleanup := createAndConnectTestDB(t, testName)
|
||||
err := migrate.NewPostgresMigrator(connection).Exec(migrate.Up, migrations.Migrations...)
|
||||
if err != nil {
|
||||
require.Nil(t, err, err.Error())
|
||||
}
|
||||
|
||||
return &pgSQL{connection, nil, Config{PaginationKey: pagination.Must(pagination.NewKey()).String()}}, cleanup
|
||||
}
|
||||
|
||||
func createTestPgSQLWithFixtures(t *testing.T, testName string) (*pgSQL, func()) {
|
||||
connection, cleanup := createTestPgSQL(t, testName)
|
||||
session, err := connection.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
defer session.Rollback()
|
||||
|
||||
loadFixtures(session.(*pgSession))
|
||||
if err = session.Commit(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return connection, cleanup
|
||||
}
|
||||
|
||||
func createTestPgSession(t *testing.T, testName string) (*pgSession, func()) {
|
||||
connection, cleanup := createTestPgSQL(t, testName)
|
||||
session, err := connection.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return session.(*pgSession), func() {
|
||||
session.Rollback()
|
||||
cleanup()
|
||||
}
|
||||
}
|
||||
|
||||
func createTestPgSessionWithFixtures(t *testing.T, testName string) (*pgSession, func()) {
|
||||
tx, cleanup := createTestPgSession(t, testName)
|
||||
defer func() {
|
||||
// ensure to cleanup when loadFixtures failed
|
||||
if r := recover(); r != nil {
|
||||
cleanup()
|
||||
}
|
||||
}()
|
||||
|
||||
loadFixtures(tx)
|
||||
return tx, cleanup
|
||||
}
|
||||
|
||||
func loadFixtures(tx *pgSession) {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
fixturePath := filepath.Join(filepath.Dir(filename)) + "/testdata/data.sql"
|
||||
d, err := ioutil.ReadFile(fixturePath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(string(d))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool {
|
||||
return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected)
|
||||
}
|
||||
|
||||
func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool {
|
||||
if assert.Len(t, actual, len(expected)) {
|
||||
has := map[database.AffectedFeature]bool{}
|
||||
for _, i := range expected {
|
||||
has[i] = false
|
||||
}
|
||||
for _, i := range actual {
|
||||
if visited, ok := has[i]; !ok {
|
||||
return false
|
||||
} else if visited {
|
||||
return false
|
||||
}
|
||||
has[i] = true
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func genRandomNamespaces(t *testing.T, count int) []database.Namespace {
|
||||
r := make([]database.Namespace, count)
|
||||
for i := 0; i < count; i++ {
|
||||
r[i] = database.Namespace{
|
||||
Name: fmt.Sprint(rand.Int()),
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user