From 8bebea3643e294bb11a1766ec450b1e518b0003b Mon Sep 17 00:00:00 2001 From: Sida Chen Date: Wed, 6 Mar 2019 16:22:29 -0500 Subject: [PATCH] pgsql: Split testutil.go into multiple files - assertion.go assertions used in pgsql tests - data.go contains go representation of data.sql - testdb.go contains test db/tx creation logic - testutil.go contains misc other things --- database/pgsql/testutil.go | 424 ------------------ database/pgsql/testutil/assertion.go | 98 ++++ database/pgsql/testutil/data.go | 171 +++++++ .../pgsql/{testdata => testutil}/data.sql | 0 database/pgsql/testutil/testdb.go | 207 +++++++++ database/pgsql/testutil/testutil.go | 94 ++++ 6 files changed, 570 insertions(+), 424 deletions(-) delete mode 100644 database/pgsql/testutil.go create mode 100644 database/pgsql/testutil/assertion.go create mode 100644 database/pgsql/testutil/data.go rename database/pgsql/{testdata => testutil}/data.sql (100%) create mode 100644 database/pgsql/testutil/testdb.go create mode 100644 database/pgsql/testutil/testutil.go diff --git a/database/pgsql/testutil.go b/database/pgsql/testutil.go deleted file mode 100644 index b1bfabf7..00000000 --- a/database/pgsql/testutil.go +++ /dev/null @@ -1,424 +0,0 @@ -// Copyright 2018 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgsql - -import ( - "database/sql" - "fmt" - "io/ioutil" - "math/rand" - "os" - "path/filepath" - "runtime" - "strings" - "testing" - "time" - - "github.com/remind101/migrate" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "github.com/coreos/clair/database" - "github.com/coreos/clair/database/pgsql/migrations" - "github.com/coreos/clair/pkg/pagination" -) - -// int keys must be the consistent with the database ID. -var ( - realFeatures = map[int]database.Feature{ - 1: {"ourchat", "0.5", "dpkg", "source"}, - 2: {"openssl", "1.0", "dpkg", "source"}, - 3: {"openssl", "2.0", "dpkg", "source"}, - 4: {"fake", "2.0", "rpm", "source"}, - 5: {"mount", "2.31.1-0.4ubuntu3.1", "dpkg", "binary"}, - } - - realNamespaces = map[int]database.Namespace{ - 1: {"debian:7", "dpkg"}, - 2: {"debian:8", "dpkg"}, - 3: {"fake:1.0", "rpm"}, - 4: {"cpe:/o:redhat:enterprise_linux:7::server", "rpm"}, - } - - realNamespacedFeatures = map[int]database.NamespacedFeature{ - 1: {realFeatures[1], realNamespaces[1]}, - 2: {realFeatures[2], realNamespaces[1]}, - 3: {realFeatures[2], realNamespaces[2]}, - 4: {realFeatures[3], realNamespaces[1]}, - } - - realDetectors = map[int]database.Detector{ - 1: database.NewNamespaceDetector("os-release", "1.0"), - 2: database.NewFeatureDetector("dpkg", "1.0"), - 3: database.NewFeatureDetector("rpm", "1.0"), - 4: database.NewNamespaceDetector("apt-sources", "1.0"), - } - - realLayers = map[int]database.Layer{ - 2: { - Hash: "layer-1", - By: []database.Detector{realDetectors[1], realDetectors[2]}, - Features: []database.LayerFeature{ - {realFeatures[1], realDetectors[2], database.Namespace{}}, - {realFeatures[2], realDetectors[2], database.Namespace{}}, - }, - Namespaces: []database.LayerNamespace{ - {realNamespaces[1], realDetectors[1]}, - }, - }, - 6: { - Hash: "layer-4", - By: []database.Detector{realDetectors[1], realDetectors[2], realDetectors[3], realDetectors[4]}, - Features: []database.LayerFeature{ - {realFeatures[4], realDetectors[3], database.Namespace{}}, - {realFeatures[3], realDetectors[2], database.Namespace{}}, - }, - Namespaces: []database.LayerNamespace{ - {realNamespaces[1], realDetectors[1]}, - {realNamespaces[3], realDetectors[4]}, - }, - }, - } - - realAncestries = map[int]database.Ancestry{ - 2: { - Name: "ancestry-2", - By: []database.Detector{realDetectors[2], realDetectors[1]}, - Layers: []database.AncestryLayer{ - { - "layer-0", - []database.AncestryFeature{}, - }, - { - "layer-1", - []database.AncestryFeature{}, - }, - { - "layer-2", - []database.AncestryFeature{ - { - realNamespacedFeatures[1], - realDetectors[2], - realDetectors[1], - }, - }, - }, - { - "layer-3b", - []database.AncestryFeature{ - { - realNamespacedFeatures[3], - realDetectors[2], - realDetectors[1], - }, - }, - }, - }, - }, - } - - realVulnerability = map[int]database.Vulnerability{ - 1: { - Name: "CVE-OPENSSL-1-DEB7", - Namespace: realNamespaces[1], - Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", - Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", - Severity: database.HighSeverity, - }, - 2: { - Name: "CVE-NOPE", - Namespace: realNamespaces[1], - Description: "A vulnerability affecting nothing", - Severity: database.UnknownSeverity, - }, - } - - realNotification = map[int]database.VulnerabilityNotification{ - 1: { - NotificationHook: database.NotificationHook{ - Name: "test", - }, - Old: takeVulnerabilityPointerFromMap(realVulnerability, 2), - New: takeVulnerabilityPointerFromMap(realVulnerability, 1), - }, - } - - fakeFeatures = map[int]database.Feature{ - 1: { - Name: "ourchat", - Version: "0.6", - VersionFormat: "dpkg", - Type: "source", - }, - } - - fakeNamespaces = map[int]database.Namespace{ - 1: {"green hat", "rpm"}, - } - fakeNamespacedFeatures = map[int]database.NamespacedFeature{ - 1: { - Feature: fakeFeatures[0], - Namespace: realNamespaces[0], - }, - } - - fakeDetector = map[int]database.Detector{ - 1: { - Name: "fake", - Version: "1.0", - DType: database.FeatureDetectorType, - }, - 2: { - Name: "fake2", - Version: "2.0", - DType: database.NamespaceDetectorType, - }, - } -) - -func takeVulnerabilityPointerFromMap(m map[int]database.Vulnerability, id int) *database.Vulnerability { - x := m[id] - return &x -} - -func takeAncestryPointerFromMap(m map[int]database.Ancestry, id int) *database.Ancestry { - x := m[id] - return &x -} - -func takeLayerPointerFromMap(m map[int]database.Layer, id int) *database.Layer { - x := m[id] - return &x -} - -func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace { - rows, err := tx.Query("SELECT name, version_format FROM namespace") - if err != nil { - t.FailNow() - } - defer rows.Close() - - namespaces := []database.Namespace{} - for rows.Next() { - var ns database.Namespace - err := rows.Scan(&ns.Name, &ns.VersionFormat) - if err != nil { - t.FailNow() - } - namespaces = append(namespaces, ns) - } - - return namespaces -} - -func assertVulnerabilityNotificationWithVulnerableEqual(t *testing.T, key pagination.Key, expected, actual *database.VulnerabilityNotificationWithVulnerable) bool { - if expected == actual { - return true - } - - if expected == nil || actual == nil { - return assert.Equal(t, expected, actual) - } - - return assert.Equal(t, expected.NotificationHook, actual.NotificationHook) && - AssertPagedVulnerableAncestriesEqual(t, key, expected.Old, actual.Old) && - AssertPagedVulnerableAncestriesEqual(t, key, expected.New, actual.New) -} - -func AssertPagedVulnerableAncestriesEqual(t *testing.T, key pagination.Key, expected, actual *database.PagedVulnerableAncestries) bool { - if expected == actual { - return true - } - - if expected == nil || actual == nil { - return assert.Equal(t, expected, actual) - } - - return database.AssertVulnerabilityEqual(t, &expected.Vulnerability, &actual.Vulnerability) && - assert.Equal(t, expected.Limit, actual.Limit) && - assert.Equal(t, mustUnmarshalToken(key, expected.Current), mustUnmarshalToken(key, actual.Current)) && - assert.Equal(t, mustUnmarshalToken(key, expected.Next), mustUnmarshalToken(key, actual.Next)) && - assert.Equal(t, expected.End, actual.End) && - database.AssertIntStringMapEqual(t, expected.Affected, actual.Affected) -} - -func mustUnmarshalToken(key pagination.Key, token pagination.Token) Page { - if token == pagination.FirstPageToken { - return Page{} - } - - p := Page{} - if err := key.UnmarshalToken(token, &p); err != nil { - panic(err) - } - - return p -} - -func mustMarshalToken(key pagination.Key, v interface{}) pagination.Token { - token, err := key.MarshalToken(v) - if err != nil { - panic(err) - } - - return token -} - -var userDBCount = `SELECT count(datname) FROM pg_database WHERE datistemplate = FALSE AND datname != 'postgres';` - -func createAndConnectTestDB(t *testing.T, testName string) (*sql.DB, func()) { - uri := "postgres@127.0.0.1:5432" - connectionTemplate := "postgresql://%s?sslmode=disable" - if envURI := os.Getenv("CLAIR_TEST_PGSQL"); envURI != "" { - uri = envURI - } - - db, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri)) - if err != nil { - panic(err) - } - - testName = strings.ToLower(testName) - dbName := fmt.Sprintf("test_%s_%s", testName, time.Now().UTC().Format("2006_01_02_15_04_05")) - t.Logf("creating temporary database name = %s", dbName) - _, err = db.Exec("CREATE DATABASE " + dbName) - if err != nil { - panic(err) - } - - testDB, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri+"/"+dbName)) - if err != nil { - panic(err) - } - - return testDB, func() { - t.Logf("cleaning up temporary database %s", dbName) - defer db.Close() - if err := testDB.Close(); err != nil { - panic(err) - } - - if _, err := db.Exec(`DROP DATABASE ` + dbName); err != nil { - panic(err) - } - - // ensure the database is cleaned up - var count int - if err := db.QueryRow(userDBCount).Scan(&count); err != nil { - panic(err) - } - } -} - -func createTestPgSQL(t *testing.T, testName string) (*pgSQL, func()) { - connection, cleanup := createAndConnectTestDB(t, testName) - err := migrate.NewPostgresMigrator(connection).Exec(migrate.Up, migrations.Migrations...) - if err != nil { - require.Nil(t, err, err.Error()) - } - - return &pgSQL{connection, nil, Config{PaginationKey: pagination.Must(pagination.NewKey()).String()}}, cleanup -} - -func createTestPgSQLWithFixtures(t *testing.T, testName string) (*pgSQL, func()) { - connection, cleanup := createTestPgSQL(t, testName) - session, err := connection.Begin() - if err != nil { - panic(err) - } - - defer session.Rollback() - - loadFixtures(session.(*pgSession)) - if err = session.Commit(); err != nil { - panic(err) - } - - return connection, cleanup -} - -func createTestPgSession(t *testing.T, testName string) (*pgSession, func()) { - connection, cleanup := createTestPgSQL(t, testName) - session, err := connection.Begin() - if err != nil { - panic(err) - } - - return session.(*pgSession), func() { - session.Rollback() - cleanup() - } -} - -func createTestPgSessionWithFixtures(t *testing.T, testName string) (*pgSession, func()) { - tx, cleanup := createTestPgSession(t, testName) - defer func() { - // ensure to cleanup when loadFixtures failed - if r := recover(); r != nil { - cleanup() - } - }() - - loadFixtures(tx) - return tx, cleanup -} - -func loadFixtures(tx *pgSession) { - _, filename, _, _ := runtime.Caller(0) - fixturePath := filepath.Join(filepath.Dir(filename)) + "/testdata/data.sql" - d, err := ioutil.ReadFile(fixturePath) - if err != nil { - panic(err) - } - - _, err = tx.Exec(string(d)) - if err != nil { - panic(err) - } -} - -func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool { - return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected) -} - -func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool { - if assert.Len(t, actual, len(expected)) { - has := map[database.AffectedFeature]bool{} - for _, i := range expected { - has[i] = false - } - for _, i := range actual { - if visited, ok := has[i]; !ok { - return false - } else if visited { - return false - } - has[i] = true - } - return true - } - return false -} - -func genRandomNamespaces(t *testing.T, count int) []database.Namespace { - r := make([]database.Namespace, count) - for i := 0; i < count; i++ { - r[i] = database.Namespace{ - Name: fmt.Sprint(rand.Int()), - VersionFormat: "dpkg", - } - } - return r -} diff --git a/database/pgsql/testutil/assertion.go b/database/pgsql/testutil/assertion.go new file mode 100644 index 00000000..b823c773 --- /dev/null +++ b/database/pgsql/testutil/assertion.go @@ -0,0 +1,98 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "testing" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/pkg/pagination" + "github.com/stretchr/testify/assert" +) + +func AssertVulnerabilityNotificationWithVulnerableEqual(t *testing.T, key pagination.Key, expected, actual *database.VulnerabilityNotificationWithVulnerable) bool { + if expected == actual { + return true + } + + if expected == nil || actual == nil { + return assert.Equal(t, expected, actual) + } + + return assert.Equal(t, expected.NotificationHook, actual.NotificationHook) && + AssertPagedVulnerableAncestriesEqual(t, key, expected.Old, actual.Old) && + AssertPagedVulnerableAncestriesEqual(t, key, expected.New, actual.New) +} + +func AssertPagedVulnerableAncestriesEqual(t *testing.T, key pagination.Key, expected, actual *database.PagedVulnerableAncestries) bool { + if expected == actual { + return true + } + + if expected == nil || actual == nil { + return assert.Equal(t, expected, actual) + } + + return database.AssertVulnerabilityEqual(t, &expected.Vulnerability, &actual.Vulnerability) && + assert.Equal(t, expected.Limit, actual.Limit) && + assert.Equal(t, mustUnmarshalToken(key, expected.Current), mustUnmarshalToken(key, actual.Current)) && + assert.Equal(t, mustUnmarshalToken(key, expected.Next), mustUnmarshalToken(key, actual.Next)) && + assert.Equal(t, expected.End, actual.End) && + database.AssertIntStringMapEqual(t, expected.Affected, actual.Affected) +} + +func AssertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool { + return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && AssertAffectedFeaturesEqual(t, expected.Affected, actual.Affected) +} + +func AssertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.AffectedFeature]bool{} + for _, i := range expected { + has[i] = false + } + for _, i := range actual { + if visited, ok := has[i]; !ok { + return false + } else if visited { + return false + } + has[i] = true + } + return true + } + return false +} + +func AssertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.NamespacedFeature]bool{} + for _, nf := range expected { + has[nf] = false + } + + for _, nf := range actual { + has[nf] = true + } + + for nf, visited := range has { + if !assert.True(t, visited, nf.Namespace.Name+":"+nf.Name+" is expected") { + return false + } + } + return true + } + return false +} diff --git a/database/pgsql/testutil/data.go b/database/pgsql/testutil/data.go new file mode 100644 index 00000000..f6d9fe16 --- /dev/null +++ b/database/pgsql/testutil/data.go @@ -0,0 +1,171 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import "github.com/coreos/clair/database" + +// int keys must be the consistent with the database ID. +var ( + RealFeatures = map[int]database.Feature{ + 1: {"ourchat", "0.5", "dpkg", "source"}, + 2: {"openssl", "1.0", "dpkg", "source"}, + 3: {"openssl", "2.0", "dpkg", "source"}, + 4: {"fake", "2.0", "rpm", "source"}, + 5: {"mount", "2.31.1-0.4ubuntu3.1", "dpkg", "binary"}, + } + + RealNamespaces = map[int]database.Namespace{ + 1: {"debian:7", "dpkg"}, + 2: {"debian:8", "dpkg"}, + 3: {"fake:1.0", "rpm"}, + 4: {"cpe:/o:redhat:enterprise_linux:7::server", "rpm"}, + } + + RealNamespacedFeatures = map[int]database.NamespacedFeature{ + 1: {RealFeatures[1], RealNamespaces[1]}, + 2: {RealFeatures[2], RealNamespaces[1]}, + 3: {RealFeatures[2], RealNamespaces[2]}, + 4: {RealFeatures[3], RealNamespaces[1]}, + } + + RealDetectors = map[int]database.Detector{ + 1: database.NewNamespaceDetector("os-release", "1.0"), + 2: database.NewFeatureDetector("dpkg", "1.0"), + 3: database.NewFeatureDetector("rpm", "1.0"), + 4: database.NewNamespaceDetector("apt-sources", "1.0"), + } + + RealLayers = map[int]database.Layer{ + 2: { + Hash: "layer-1", + By: []database.Detector{RealDetectors[1], RealDetectors[2]}, + Features: []database.LayerFeature{ + {RealFeatures[1], RealDetectors[2], database.Namespace{}}, + {RealFeatures[2], RealDetectors[2], database.Namespace{}}, + }, + Namespaces: []database.LayerNamespace{ + {RealNamespaces[1], RealDetectors[1]}, + }, + }, + 6: { + Hash: "layer-4", + By: []database.Detector{RealDetectors[1], RealDetectors[2], RealDetectors[3], RealDetectors[4]}, + Features: []database.LayerFeature{ + {RealFeatures[4], RealDetectors[3], database.Namespace{}}, + {RealFeatures[3], RealDetectors[2], database.Namespace{}}, + }, + Namespaces: []database.LayerNamespace{ + {RealNamespaces[1], RealDetectors[1]}, + {RealNamespaces[3], RealDetectors[4]}, + }, + }, + } + + RealAncestries = map[int]database.Ancestry{ + 2: { + Name: "ancestry-2", + By: []database.Detector{RealDetectors[2], RealDetectors[1]}, + Layers: []database.AncestryLayer{ + { + "layer-0", + []database.AncestryFeature{}, + }, + { + "layer-1", + []database.AncestryFeature{}, + }, + { + "layer-2", + []database.AncestryFeature{ + { + RealNamespacedFeatures[1], + RealDetectors[2], + RealDetectors[1], + }, + }, + }, + { + "layer-3b", + []database.AncestryFeature{ + { + RealNamespacedFeatures[3], + RealDetectors[2], + RealDetectors[1], + }, + }, + }, + }, + }, + } + + RealVulnerability = map[int]database.Vulnerability{ + 1: { + Name: "CVE-OPENSSL-1-DEB7", + Namespace: RealNamespaces[1], + Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", + Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", + Severity: database.HighSeverity, + }, + 2: { + Name: "CVE-NOPE", + Namespace: RealNamespaces[1], + Description: "A vulnerability affecting nothing", + Severity: database.UnknownSeverity, + }, + } + + RealNotification = map[int]database.VulnerabilityNotification{ + 1: { + NotificationHook: database.NotificationHook{ + Name: "test", + }, + Old: takeVulnerabilityPointerFromMap(RealVulnerability, 2), + New: takeVulnerabilityPointerFromMap(RealVulnerability, 1), + }, + } + + FakeFeatures = map[int]database.Feature{ + 1: { + Name: "ourchat", + Version: "0.6", + VersionFormat: "dpkg", + Type: "source", + }, + } + + FakeNamespaces = map[int]database.Namespace{ + 1: {"green hat", "rpm"}, + } + + FakeNamespacedFeatures = map[int]database.NamespacedFeature{ + 1: { + Feature: FakeFeatures[0], + Namespace: RealNamespaces[0], + }, + } + + FakeDetector = map[int]database.Detector{ + 1: { + Name: "fake", + Version: "1.0", + DType: database.FeatureDetectorType, + }, + 2: { + Name: "fake2", + Version: "2.0", + DType: database.NamespaceDetectorType, + }, + } +) diff --git a/database/pgsql/testdata/data.sql b/database/pgsql/testutil/data.sql similarity index 100% rename from database/pgsql/testdata/data.sql rename to database/pgsql/testutil/data.sql diff --git a/database/pgsql/testutil/testdb.go b/database/pgsql/testutil/testdb.go new file mode 100644 index 00000000..558a3dc3 --- /dev/null +++ b/database/pgsql/testutil/testdb.go @@ -0,0 +1,207 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "database/sql" + "fmt" + "io/ioutil" + "os" + "path/filepath" + "runtime" + "strings" + "testing" + "time" + + "github.com/coreos/clair/database/pgsql/migrations" + "github.com/coreos/clair/pkg/pagination" + "github.com/remind101/migrate" +) + +var TestPaginationKey = pagination.Must(pagination.NewKey()) + +var userDBCount = `SELECT count(datname) FROM pg_database WHERE datistemplate = FALSE AND datname != 'postgres';` + +func CreateAndConnectTestDB(t *testing.T, testName string) (*sql.DB, func()) { + uri := "postgres@127.0.0.1:5432" + connectionTemplate := "postgresql://%s?sslmode=disable" + if envURI := os.Getenv("CLAIR_TEST_PGSQL"); envURI != "" { + uri = envURI + } + + db, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri)) + if err != nil { + panic(err) + } + + testName = strings.ToLower(testName) + dbName := fmt.Sprintf("test_%s_%s", testName, time.Now().UTC().Format("2006_01_02_15_04_05")) + t.Logf("creating temporary database name = %s", dbName) + _, err = db.Exec("CREATE DATABASE " + dbName) + if err != nil { + panic(err) + } + + testDB, err := sql.Open("postgres", fmt.Sprintf(connectionTemplate, uri+"/"+dbName)) + if err != nil { + panic(err) + } + + return testDB, func() { + cleanupTestDB(t, dbName, db, testDB) + } +} + +func cleanupTestDB(t *testing.T, name string, db, testDB *sql.DB) { + t.Logf("cleaning up temporary database %s", name) + if db == nil { + panic("db is none") + } + + if testDB == nil { + panic("testDB is none") + } + + defer db.Close() + + if err := testDB.Close(); err != nil { + panic(err) + } + + // Kill any opened connection. + if _, err := db.Exec(` + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = $1 + AND pid <> pg_backend_pid()`, name); err != nil { + panic(err) + } + + if _, err := db.Exec(`DROP DATABASE ` + name); err != nil { + panic(err) + } + + // ensure the database is cleaned up + var count int + if err := db.QueryRow(userDBCount).Scan(&count); err != nil { + panic(err) + } +} + +func CreateTestDB(t *testing.T, testName string) (*sql.DB, func()) { + connection, cleanup := CreateAndConnectTestDB(t, testName) + err := migrate.NewPostgresMigrator(connection).Exec(migrate.Up, migrations.Migrations...) + if err != nil { + panic(err) + } + + return connection, cleanup +} + +func CreateTestDBWithFixture(t *testing.T, testName string) (*sql.DB, func()) { + connection, cleanup := CreateTestDB(t, testName) + session, err := connection.Begin() + if err != nil { + panic(err) + } + + defer session.Rollback() + + loadFixtures(session) + if err = session.Commit(); err != nil { + panic(err) + } + + return connection, cleanup +} + +func CreateTestTx(t *testing.T, testName string) (*sql.Tx, func()) { + connection, cleanup := CreateTestDB(t, testName) + session, err := connection.Begin() + if session == nil { + panic("session is none") + } + + if err != nil { + panic(err) + } + + return session, func() { + session.Rollback() + cleanup() + } +} + +func CreateTestTxWithFixtures(t *testing.T, testName string) (*sql.Tx, func()) { + tx, cleanup := CreateTestTx(t, testName) + defer func() { + // ensure to cleanup when loadFixtures failed + if r := recover(); r != nil { + cleanup() + } + }() + + loadFixtures(tx) + return tx, cleanup +} + +func loadFixtures(tx *sql.Tx) { + _, filename, _, _ := runtime.Caller(0) + fixturePath := filepath.Join(filepath.Dir(filename), "data.sql") + d, err := ioutil.ReadFile(fixturePath) + if err != nil { + panic(err) + } + + _, err = tx.Exec(string(d)) + if err != nil { + panic(err) + } +} + +func OpenSessionForTest(t *testing.T, name string, loadFixture bool) (*sql.DB, *sql.Tx) { + var db *sql.DB + if loadFixture { + db, _ = CreateTestDB(t, name) + } else { + db, _ = CreateTestDBWithFixture(t, name) + } + + tx, err := db.Begin() + if err != nil { + panic(err) + } + + return db, tx +} + +func RestartTransaction(db *sql.DB, tx *sql.Tx, commit bool) *sql.Tx { + if !commit { + if err := tx.Rollback(); err != nil { + panic(err) + } + } else { + if err := tx.Commit(); err != nil { + panic(err) + } + } + + tx, err := db.Begin() + if err != nil { + panic(err) + } + + return tx +} diff --git a/database/pgsql/testutil/testutil.go b/database/pgsql/testutil/testutil.go new file mode 100644 index 00000000..03c81dd7 --- /dev/null +++ b/database/pgsql/testutil/testutil.go @@ -0,0 +1,94 @@ +// Copyright 2018 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package testutil + +import ( + "database/sql" + "fmt" + "math/rand" + "testing" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/page" + "github.com/coreos/clair/pkg/pagination" +) + +func takeVulnerabilityPointerFromMap(m map[int]database.Vulnerability, id int) *database.Vulnerability { + x := m[id] + return &x +} + +func TakeAncestryPointerFromMap(m map[int]database.Ancestry, id int) *database.Ancestry { + x := m[id] + return &x +} + +func TakeLayerPointerFromMap(m map[int]database.Layer, id int) *database.Layer { + x := m[id] + return &x +} + +func ListNamespaces(t *testing.T, tx *sql.Tx) []database.Namespace { + rows, err := tx.Query("SELECT name, version_format FROM namespace") + if err != nil { + t.FailNow() + } + defer rows.Close() + + namespaces := []database.Namespace{} + for rows.Next() { + var ns database.Namespace + err := rows.Scan(&ns.Name, &ns.VersionFormat) + if err != nil { + t.FailNow() + } + namespaces = append(namespaces, ns) + } + + return namespaces +} + +func mustUnmarshalToken(key pagination.Key, token pagination.Token) page.Page { + if token == pagination.FirstPageToken { + return page.Page{} + } + + p := page.Page{} + if err := key.UnmarshalToken(token, &p); err != nil { + panic(err) + } + + return p +} + +func MustMarshalToken(key pagination.Key, v interface{}) pagination.Token { + token, err := key.MarshalToken(v) + if err != nil { + panic(err) + } + + return token +} + +func GenRandomNamespaces(t *testing.T, count int) []database.Namespace { + r := make([]database.Namespace, count) + for i := 0; i < count; i++ { + r[i] = database.Namespace{ + Name: fmt.Sprint(rand.Int()), + VersionFormat: "dpkg", + } + } + return r +}