database: Allow specifying datastore driver by config

Fixes #145
This commit is contained in:
Quentin Machu 2016-05-02 18:33:03 -04:00
parent 53e62577bc
commit e7b960c05b
17 changed files with 264 additions and 229 deletions

View File

@ -26,7 +26,7 @@ import (
"github.com/coreos/clair/api" "github.com/coreos/clair/api"
"github.com/coreos/clair/api/context" "github.com/coreos/clair/api/context"
"github.com/coreos/clair/config" "github.com/coreos/clair/config"
"github.com/coreos/clair/database/pgsql" "github.com/coreos/clair/database"
"github.com/coreos/clair/notifier" "github.com/coreos/clair/notifier"
"github.com/coreos/clair/updater" "github.com/coreos/clair/updater"
"github.com/coreos/clair/utils" "github.com/coreos/clair/utils"
@ -42,7 +42,7 @@ func Boot(config *config.Config) {
st := utils.NewStopper() st := utils.NewStopper()
// Open database // Open database
db, err := pgsql.Open(config.Database) db, err := database.Open(config.Database)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -20,11 +20,11 @@ import (
"runtime/pprof" "runtime/pprof"
"strings" "strings"
"github.com/coreos/pkg/capnslog"
"github.com/coreos/clair" "github.com/coreos/clair"
"github.com/coreos/clair/config" "github.com/coreos/clair/config"
"github.com/coreos/pkg/capnslog"
// Register components // Register components
_ "github.com/coreos/clair/notifier/notifiers" _ "github.com/coreos/clair/notifier/notifiers"
@ -43,6 +43,8 @@ import (
_ "github.com/coreos/clair/worker/detectors/namespace/lsbrelease" _ "github.com/coreos/clair/worker/detectors/namespace/lsbrelease"
_ "github.com/coreos/clair/worker/detectors/namespace/osrelease" _ "github.com/coreos/clair/worker/detectors/namespace/osrelease"
_ "github.com/coreos/clair/worker/detectors/namespace/redhatrelease" _ "github.com/coreos/clair/worker/detectors/namespace/redhatrelease"
_ "github.com/coreos/clair/database/pgsql"
) )
var log = capnslog.NewPackageLogger("github.com/coreos/clair/cmd/clair", "main") var log = capnslog.NewPackageLogger("github.com/coreos/clair/cmd/clair", "main")

View File

@ -15,13 +15,16 @@
# The values specified here are the default values that Clair uses if no configuration file is specified or if the keys are not defined. # The values specified here are the default values that Clair uses if no configuration file is specified or if the keys are not defined.
clair: clair:
database: database:
# Database driver
type: pgsql
options:
# PostgreSQL Connection string # PostgreSQL Connection string
# http://www.postgresql.org/docs/9.4/static/libpq-connect.html # http://www.postgresql.org/docs/9.4/static/libpq-connect.html
source: source:
# Number of elements kept in the cache # Number of elements kept in the cache
# Values unlikely to change (e.g. namespaces) are cached in order to save prevent needless roundtrips to the database. # Values unlikely to change (e.g. namespaces) are cached in order to save prevent needless roundtrips to the database.
cacheSize: 16384 cachesize: 16384
api: api:
# API server port # API server port

View File

@ -27,6 +27,14 @@ import (
// ErrDatasourceNotLoaded is returned when the datasource variable in the configuration file is not loaded properly // ErrDatasourceNotLoaded is returned when the datasource variable in the configuration file is not loaded properly
var ErrDatasourceNotLoaded = errors.New("could not load configuration: no database source specified") var ErrDatasourceNotLoaded = errors.New("could not load configuration: no database source specified")
// RegistrableComponentConfig is a configuration block that can be used to
// determine which registrable component should be initialized and pass
// custom configuration to it.
type RegistrableComponentConfig struct {
Type string
Options map[string]interface{}
}
// File represents a YAML configuration file that namespaces all Clair // File represents a YAML configuration file that namespaces all Clair
// configuration under the top-level "clair" key. // configuration under the top-level "clair" key.
type File struct { type File struct {
@ -35,19 +43,12 @@ type File struct {
// Config is the global configuration for an instance of Clair. // Config is the global configuration for an instance of Clair.
type Config struct { type Config struct {
Database *DatabaseConfig Database RegistrableComponentConfig
Updater *UpdaterConfig Updater *UpdaterConfig
Notifier *NotifierConfig Notifier *NotifierConfig
API *APIConfig API *APIConfig
} }
// DatabaseConfig is the configuration used to specify how Clair connects
// to a database.
type DatabaseConfig struct {
Source string
CacheSize int
}
// UpdaterConfig is the configuration for the Updater service. // UpdaterConfig is the configuration for the Updater service.
type UpdaterConfig struct { type UpdaterConfig struct {
Interval time.Duration Interval time.Duration
@ -72,8 +73,8 @@ type APIConfig struct {
// DefaultConfig is a configuration that can be used as a fallback value. // DefaultConfig is a configuration that can be used as a fallback value.
func DefaultConfig() Config { func DefaultConfig() Config {
return Config{ return Config{
Database: &DatabaseConfig{ Database: RegistrableComponentConfig{
CacheSize: 16384, Type: "pgsql",
}, },
Updater: &UpdaterConfig{ Updater: &UpdaterConfig{
Interval: 1 * time.Hour, Interval: 1 * time.Hour,
@ -116,12 +117,8 @@ func Load(path string) (config *Config, err error) {
} }
config = &cfgFile.Clair config = &cfgFile.Clair
if config.Database.Source == "" {
err = ErrDatasourceNotLoaded
return
}
// Generate a pagination key if none is provided. // Generate a pagination key if none is provided.
// TODO(Quentin-M): Move to the API code.
if config.API.PaginationKey == "" { if config.API.PaginationKey == "" {
var key fernet.Key var key fernet.Key
if err = key.Generate(); err != nil { if err = key.Generate(); err != nil {

View File

@ -1,81 +0,0 @@
package config
import (
"io/ioutil"
"log"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
const wrongConfig = `
dummyKey:
wrong:true
`
const goodConfig = `
clair:
database:
source: postgresql://postgres:root@postgres:5432?sslmode=disable
cacheSize: 16384
api:
port: 6060
healthport: 6061
timeout: 900s
paginationKey:
servername:
cafile:
keyfile:
certfile:
updater:
interval: 2h
notifier:
attempts: 3
renotifyInterval: 2h
http:
endpoint:
servername:
cafile:
keyfile:
certfile:
proxy:
`
func TestLoadWrongConfiguration(t *testing.T) {
tmpfile, err := ioutil.TempFile("", "clair-config")
if err != nil {
log.Fatal(err)
}
defer os.Remove(tmpfile.Name()) // clean up
if _, err := tmpfile.Write([]byte(wrongConfig)); err != nil {
log.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
log.Fatal(err)
}
_, err = Load(tmpfile.Name())
assert.EqualError(t, err, ErrDatasourceNotLoaded.Error())
}
func TestLoad(t *testing.T) {
tmpfile, err := ioutil.TempFile("", "clair-config")
if err != nil {
log.Fatal(err)
}
defer os.Remove(tmpfile.Name()) // clean up
if _, err := tmpfile.Write([]byte(goodConfig)); err != nil {
log.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
log.Fatal(err)
}
_, err = Load(tmpfile.Name())
assert.NoError(t, err)
}

View File

@ -17,7 +17,10 @@ package database
import ( import (
"errors" "errors"
"fmt"
"time" "time"
"github.com/coreos/clair/config"
) )
var ( var (
@ -28,11 +31,37 @@ var (
// ErrInconsistent is an error that occurs when a database consistency check // ErrInconsistent is an error that occurs when a database consistency check
// fails (ie. when an entity which is supposed to be unique is detected twice) // fails (ie. when an entity which is supposed to be unique is detected twice)
ErrInconsistent = errors.New("database: inconsistent database") ErrInconsistent = errors.New("database: inconsistent database")
// ErrCantOpen is an error that occurs when the database could not be opened
ErrCantOpen = errors.New("database: could not open database")
) )
var drivers = make(map[string]Driver)
// Driver is a function that opens a Datastore specified by its database driver type and specific
// configuration.
type Driver func(config.RegistrableComponentConfig) (Datastore, error)
// Register makes a Constructor available by the provided name.
//
// If this function is called twice with the same name or if the Constructor is
// nil, it panics.
func Register(name string, driver Driver) {
if driver == nil {
panic("database: could not register nil Driver")
}
if _, dup := drivers[name]; dup {
panic("database: could not register duplicate Driver: " + name)
}
drivers[name] = driver
}
// Open opens a Datastore specified by a configuration.
func Open(cfg config.RegistrableComponentConfig) (Datastore, error) {
driver, ok := drivers[cfg.Type]
if !ok {
return nil, fmt.Errorf("database: unknown Driver %q (forgotten configuration or import?)", cfg.Type)
}
return driver(cfg)
}
// Datastore is the interface that describes a database backend implementation. // Datastore is the interface that describes a database backend implementation.
type Datastore interface { type Datastore interface {
// # Namespace // # Namespace

View File

@ -23,11 +23,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/pborman/uuid"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/utils" "github.com/coreos/clair/utils"
"github.com/coreos/clair/utils/types" "github.com/coreos/clair/utils/types"
"github.com/pborman/uuid"
"github.com/stretchr/testify/assert"
) )
const ( const (
@ -36,7 +37,7 @@ const (
) )
func TestRaceAffects(t *testing.T) { func TestRaceAffects(t *testing.T) {
datastore, err := OpenForTest("RaceAffects", false) datastore, err := openDatabaseForTest("RaceAffects", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -17,13 +17,14 @@ package pgsql
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/utils/types" "github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
) )
func TestInsertFeature(t *testing.T) { func TestInsertFeature(t *testing.T) {
datastore, err := OpenForTest("InsertFeature", false) datastore, err := openDatabaseForTest("InsertFeature", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -21,7 +21,7 @@ import (
) )
func TestKeyValue(t *testing.T) { func TestKeyValue(t *testing.T) {
datastore, err := OpenForTest("KeyValue", false) datastore, err := openDatabaseForTest("KeyValue", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -41,9 +41,7 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo
var namespaceName sql.NullString var namespaceName sql.NullString
t := time.Now() t := time.Now()
err := pgSQL.QueryRow(searchLayer, name). err := pgSQL.QueryRow(searchLayer, name).Scan(&layer.ID, &layer.Name, &layer.EngineVersion, &parentID, &parentName, &namespaceID, &namespaceName)
Scan(&layer.ID, &layer.Name, &layer.EngineVersion, &parentID, &parentName, &namespaceID,
&namespaceName)
observeQueryTime("FindLayer", "searchLayer", t) observeQueryTime("FindLayer", "searchLayer", t)
if err != nil { if err != nil {

View File

@ -18,14 +18,15 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors" cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types" "github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
) )
func TestFindLayer(t *testing.T) { func TestFindLayer(t *testing.T) {
datastore, err := OpenForTest("FindLayer", true) datastore, err := openDatabaseForTest("FindLayer", true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -102,7 +103,7 @@ func TestFindLayer(t *testing.T) {
} }
func TestInsertLayer(t *testing.T) { func TestInsertLayer(t *testing.T) {
datastore, err := OpenForTest("InsertLayer", false) datastore, err := openDatabaseForTest("InsertLayer", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -22,7 +22,7 @@ import (
) )
func TestLock(t *testing.T) { func TestLock(t *testing.T) {
datastore, err := OpenForTest("InsertNamespace", false) datastore, err := openDatabaseForTest("InsertNamespace", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -18,12 +18,13 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/coreos/clair/database"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coreos/clair/database"
) )
func TestInsertNamespace(t *testing.T) { func TestInsertNamespace(t *testing.T) {
datastore, err := OpenForTest("InsertNamespace", false) datastore, err := openDatabaseForTest("InsertNamespace", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -44,7 +45,7 @@ func TestInsertNamespace(t *testing.T) {
} }
func TestListNamespace(t *testing.T) { func TestListNamespace(t *testing.T) {
datastore, err := OpenForTest("ListNamespaces", true) datastore, err := openDatabaseForTest("ListNamespaces", true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -4,14 +4,15 @@ import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors" cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types" "github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
) )
func TestNotification(t *testing.T) { func TestNotification(t *testing.T) {
datastore, err := OpenForTest("Notification", false) datastore, err := openDatabaseForTest("Notification", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -19,22 +19,23 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "net/url"
"path" "path"
"runtime" "runtime"
"strings" "strings"
"time" "time"
"bitbucket.org/liamstask/goose/lib/goose" "bitbucket.org/liamstask/goose/lib/goose"
"github.com/coreos/pkg/capnslog"
"github.com/hashicorp/golang-lru"
"github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus"
"gopkg.in/yaml.v2"
"github.com/coreos/clair/config" "github.com/coreos/clair/config"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/utils" "github.com/coreos/clair/utils"
cerrors "github.com/coreos/clair/utils/errors" cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/pkg/capnslog"
"github.com/hashicorp/golang-lru"
"github.com/lib/pq"
"github.com/pborman/uuid"
"github.com/prometheus/client_golang/prometheus"
) )
var ( var (
@ -72,6 +73,8 @@ func init() {
prometheus.MustRegister(promCacheQueriesTotal) prometheus.MustRegister(promCacheQueriesTotal)
prometheus.MustRegister(promQueryDurationMilliseconds) prometheus.MustRegister(promQueryDurationMilliseconds)
prometheus.MustRegister(promConcurrentLockVAFV) prometheus.MustRegister(promConcurrentLockVAFV)
database.Register("pgsql", openDatabase)
} }
type Queryer interface { type Queryer interface {
@ -82,45 +85,136 @@ type Queryer interface {
type pgSQL struct { type pgSQL struct {
*sql.DB *sql.DB
cache *lru.ARCCache cache *lru.ARCCache
config Config
} }
// Close closes the database and destroys if ManageDatabaseLifecycle has been specified in
// the configuration.
func (pgSQL *pgSQL) Close() { func (pgSQL *pgSQL) Close() {
if pgSQL.DB != nil {
pgSQL.DB.Close() pgSQL.DB.Close()
}
if pgSQL.config.ManageDatabaseLifecycle {
dbName, pgSourceURL, _ := parseConnectionString(pgSQL.config.Source)
dropDatabase(pgSourceURL, dbName)
}
} }
// Ping verifies that the database is accessible.
func (pgSQL *pgSQL) Ping() bool { func (pgSQL *pgSQL) Ping() bool {
return pgSQL.DB.Ping() == nil return pgSQL.DB.Ping() == nil
} }
// Open creates a Datastore backed by a PostgreSQL database. // Config is the configuration that is used by openDatabase.
// type Config struct {
// It will run immediately every necessary migration on the database. Source string
func Open(config *config.DatabaseConfig) (database.Datastore, error) { CacheSize int
// Run migrations.
if err := migrate(config.Source); err != nil { ManageDatabaseLifecycle bool
log.Error(err) FixturePath string
return nil, database.ErrCantOpen }
// openDatabase opens a PostgresSQL-backed Datastore using the given configuration.
// It immediately every necessary migrations. If ManageDatabaseLifecycle is specified,
// the database will be created first. If FixturePath is specified, every SQL queries that are
// present insides will be executed.
func openDatabase(registrableComponentConfig config.RegistrableComponentConfig) (database.Datastore, error) {
var pg pgSQL
var err error
// Parse configuration.
pg.config = Config{
CacheSize: 16384,
}
bytes, err := yaml.Marshal(registrableComponentConfig.Options)
if err != nil {
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
}
err = yaml.Unmarshal(bytes, &pg.config)
if err != nil {
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
}
dbName, pgSourceURL, err := parseConnectionString(pg.config.Source)
if err != nil {
return nil, err
}
// Create database.
if pg.config.ManageDatabaseLifecycle {
log.Info("pgsql: creating database")
if err := createDatabase(pgSourceURL, dbName); err != nil {
return nil, err
}
} }
// Open database. // Open database.
db, err := sql.Open("postgres", config.Source) pg.DB, err = sql.Open("postgres", pg.config.Source)
if err != nil { if err != nil {
log.Error(err) pg.Close()
return nil, database.ErrCantOpen return nil, fmt.Errorf("pgsql: could not open database: %v", err)
}
// Verify database state.
if err := pg.DB.Ping(); err != nil {
pg.Close()
return nil, fmt.Errorf("pgsql: could not open database: %v", err)
}
// Run migrations.
if err := migrate(pg.config.Source); err != nil {
pg.Close()
return nil, err
}
// Load fixture data.
if pg.config.FixturePath != "" {
log.Info("pgsql: loading fixtures")
d, err := ioutil.ReadFile(pg.config.FixturePath)
if err != nil {
pg.Close()
return nil, fmt.Errorf("pgsql: could not open fixture file: %v", err)
}
_, err = pg.DB.Exec(string(d))
if err != nil {
pg.Close()
return nil, fmt.Errorf("pgsql: an error occured while importing fixtures: %v", err)
}
} }
// Initialize cache. // Initialize cache.
// TODO(Quentin-M): Benchmark with a simple LRU Cache. // TODO(Quentin-M): Benchmark with a simple LRU Cache.
var cache *lru.ARCCache if pg.config.CacheSize > 0 {
if config.CacheSize > 0 { pg.cache, _ = lru.NewARC(pg.config.CacheSize)
cache, _ = lru.NewARC(config.CacheSize)
} }
return &pgSQL{DB: db, cache: cache}, nil return &pg, nil
}
func parseConnectionString(source string) (dbName string, pgSourceURL string, err error) {
if source == "" {
return "", "", cerrors.NewBadRequestError("pgsql: no database connection string specified")
}
sourceURL, err := url.Parse(source)
if err != nil {
return "", "", cerrors.NewBadRequestError("pgsql: database connection string is not a valid URL")
}
dbName = strings.TrimPrefix(sourceURL.Path, "/")
pgSource := *sourceURL
pgSource.Path = "/postgres"
pgSourceURL = pgSource.String()
return
} }
// migrate runs all available migrations on a pgSQL database. // migrate runs all available migrations on a pgSQL database.
func migrate(dataSource string) error { func migrate(source string) error {
log.Info("running database migrations") log.Info("running database migrations")
_, filename, _, _ := runtime.Caller(1) _, filename, _, _ := runtime.Caller(1)
@ -129,7 +223,7 @@ func migrate(dataSource string) error {
MigrationsDir: migrationDir, MigrationsDir: migrationDir,
Driver: goose.DBDriver{ Driver: goose.DBDriver{
Name: "postgres", Name: "postgres",
OpenStr: dataSource, OpenStr: source,
Import: "github.com/lib/pq", Import: "github.com/lib/pq",
Dialect: &goose.PostgresDialect{}, Dialect: &goose.PostgresDialect{},
}, },
@ -138,13 +232,13 @@ func migrate(dataSource string) error {
// Determine the most recent revision available from the migrations folder. // Determine the most recent revision available from the migrations folder.
target, err := goose.GetMostRecentDBVersion(conf.MigrationsDir) target, err := goose.GetMostRecentDBVersion(conf.MigrationsDir)
if err != nil { if err != nil {
return err return fmt.Errorf("pgsql: could not get most recent migration: %v", err)
} }
// Run migrations // Run migrations.
err = goose.RunMigrations(conf, conf.MigrationsDir, target) err = goose.RunMigrations(conf, conf.MigrationsDir, target)
if err != nil { if err != nil {
return err return fmt.Errorf("pgsql: an error occured while running migrations: %v", err)
} }
log.Info("database migration ran successfully") log.Info("database migration ran successfully")
@ -152,109 +246,51 @@ func migrate(dataSource string) error {
} }
// createDatabase creates a new database. // createDatabase creates a new database.
// The dataSource parameter should not contain a dbname. // The source parameter should not contain a dbname.
func createDatabase(dataSource, databaseName string) error { func createDatabase(source, dbName string) error {
// Open database. // Open database.
db, err := sql.Open("postgres", dataSource) db, err := sql.Open("postgres", source)
if err != nil { if err != nil {
return fmt.Errorf("could not open database (CreateDatabase): %v", err) return fmt.Errorf("pgsql: could not open 'postgres' database for creation: %v", err)
} }
defer db.Close() defer db.Close()
// Create database. // Create database.
_, err = db.Exec("CREATE DATABASE " + databaseName) _, err = db.Exec("CREATE DATABASE " + dbName)
if err != nil { if err != nil {
return fmt.Errorf("could not create database: %v", err) return fmt.Errorf("pgsql: could not create database: %v", err)
} }
return nil return nil
} }
// dropDatabase drops an existing database. // dropDatabase drops an existing database.
// The dataSource parameter should not contain a dbname. // The source parameter should not contain a dbname.
func dropDatabase(dataSource, databaseName string) error { func dropDatabase(source, dbName string) error {
// Open database. // Open database.
db, err := sql.Open("postgres", dataSource) db, err := sql.Open("postgres", source)
if err != nil { if err != nil {
return fmt.Errorf("could not open database (DropDatabase): %v", err) return fmt.Errorf("could not open database (DropDatabase): %v", err)
} }
defer db.Close() defer db.Close()
// Kill any opened connection. // Kill any opened connection.
if _, err := db.Exec(` if _, err = db.Exec(`
SELECT pg_terminate_backend(pg_stat_activity.pid) SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity FROM pg_stat_activity
WHERE pg_stat_activity.datname = $1 WHERE pg_stat_activity.datname = $1
AND pid <> pg_backend_pid()`, databaseName); err != nil { AND pid <> pg_backend_pid()`, dbName); err != nil {
return fmt.Errorf("could not drop database: %v", err) return fmt.Errorf("could not drop database: %v", err)
} }
// Drop database. // Drop database.
if _, err = db.Exec("DROP DATABASE " + databaseName); err != nil { if _, err = db.Exec("DROP DATABASE " + dbName); err != nil {
return fmt.Errorf("could not drop database: %v", err) return fmt.Errorf("could not drop database: %v", err)
} }
return nil return nil
} }
// pgSQLTest wraps pgSQL for testing purposes.
// Its Close() method drops the database.
type pgSQLTest struct {
*pgSQL
dataSourceDefaultDatabase string
dbName string
}
// OpenForTest creates a test Datastore backed by a new PostgreSQL database.
// It creates a new unique and prefixed ("test_") database.
// Using Close() will drop the database.
func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) {
// Define the PostgreSQL connection strings.
dataSource := "host=127.0.0.1 sslmode=disable user=postgres dbname="
if dataSourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); dataSourceEnv != "" {
dataSource = dataSourceEnv + " dbname="
}
dbName := "test_" + strings.ToLower(name) + "_" + strings.Replace(uuid.New(), "-", "_", -1)
dataSourceDefaultDatabase := dataSource + "postgres"
dataSourceTestDatabase := dataSource + dbName
// Create database.
if err := createDatabase(dataSourceDefaultDatabase, dbName); err != nil {
log.Error(err)
return nil, database.ErrCantOpen
}
// Open database.
db, err := Open(&config.DatabaseConfig{Source: dataSourceTestDatabase, CacheSize: 0})
if err != nil {
dropDatabase(dataSourceDefaultDatabase, dbName)
log.Error(err)
return nil, database.ErrCantOpen
}
// Load test data if specified.
if withTestData {
_, filename, _, _ := runtime.Caller(0)
d, _ := ioutil.ReadFile(path.Join(path.Dir(filename)) + "/testdata/data.sql")
_, err = db.(*pgSQL).Exec(string(d))
if err != nil {
dropDatabase(dataSourceDefaultDatabase, dbName)
log.Error(err)
return nil, database.ErrCantOpen
}
}
return &pgSQLTest{
pgSQL: db.(*pgSQL),
dataSourceDefaultDatabase: dataSourceDefaultDatabase,
dbName: dbName}, nil
}
func (pgSQL *pgSQLTest) Close() {
pgSQL.DB.Close()
dropDatabase(pgSQL.dataSourceDefaultDatabase, pgSQL.dbName)
}
// handleError logs an error with an extra description and masks the error if it's an SQL one. // handleError logs an error with an extra description and masks the error if it's an SQL one.
// This ensures we never return plain SQL errors and leak anything. // This ensures we never return plain SQL errors and leak anything.
func handleError(desc string, err error) error { func handleError(desc string, err error) error {

View File

@ -0,0 +1,45 @@
package pgsql
import (
"fmt"
"os"
"path"
"runtime"
"strings"
"github.com/coreos/clair/config"
"github.com/pborman/uuid"
)
func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) {
ds, err := openDatabase(generateTestConfig(testName, loadFixture))
if err != nil {
return nil, err
}
datastore := ds.(*pgSQL)
return datastore, nil
}
func generateTestConfig(testName string, loadFixture bool) config.RegistrableComponentConfig {
dbName := "test_" + strings.ToLower(testName) + "_" + strings.Replace(uuid.New(), "-", "_", -1)
var fixturePath string
if loadFixture {
_, filename, _, _ := runtime.Caller(0)
fixturePath = path.Join(path.Dir(filename)) + "/testdata/data.sql"
}
source := fmt.Sprintf("postgresql://postgres@127.0.0.1:5432/%s?sslmode=disable", dbName)
if sourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); sourceEnv != "" {
source = fmt.Sprintf(sourceEnv, dbName)
}
return config.RegistrableComponentConfig{
Options: map[string]interface{}{
"source": source,
"cachesize": 0,
"managedatabaselifecycle": true,
"fixturepath": fixturePath,
},
}
}

View File

@ -18,14 +18,15 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors" cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types" "github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
) )
func TestFindVulnerability(t *testing.T) { func TestFindVulnerability(t *testing.T) {
datastore, err := OpenForTest("FindVulnerability", true) datastore, err := openDatabaseForTest("FindVulnerability", true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -75,7 +76,7 @@ func TestFindVulnerability(t *testing.T) {
} }
func TestDeleteVulnerability(t *testing.T) { func TestDeleteVulnerability(t *testing.T) {
datastore, err := OpenForTest("InsertVulnerability", true) datastore, err := openDatabaseForTest("InsertVulnerability", true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -97,7 +98,7 @@ func TestDeleteVulnerability(t *testing.T) {
} }
func TestInsertVulnerability(t *testing.T) { func TestInsertVulnerability(t *testing.T) {
datastore, err := OpenForTest("InsertVulnerability", false) datastore, err := openDatabaseForTest("InsertVulnerability", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return