parent
53e62577bc
commit
e7b960c05b
4
clair.go
4
clair.go
@ -26,7 +26,7 @@ import (
|
||||
"github.com/coreos/clair/api"
|
||||
"github.com/coreos/clair/api/context"
|
||||
"github.com/coreos/clair/config"
|
||||
"github.com/coreos/clair/database/pgsql"
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/notifier"
|
||||
"github.com/coreos/clair/updater"
|
||||
"github.com/coreos/clair/utils"
|
||||
@ -42,7 +42,7 @@ func Boot(config *config.Config) {
|
||||
st := utils.NewStopper()
|
||||
|
||||
// Open database
|
||||
db, err := pgsql.Open(config.Database)
|
||||
db, err := database.Open(config.Database)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -20,11 +20,11 @@ import (
|
||||
"runtime/pprof"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/pkg/capnslog"
|
||||
|
||||
"github.com/coreos/clair"
|
||||
"github.com/coreos/clair/config"
|
||||
|
||||
"github.com/coreos/pkg/capnslog"
|
||||
|
||||
// Register components
|
||||
_ "github.com/coreos/clair/notifier/notifiers"
|
||||
|
||||
@ -43,6 +43,8 @@ import (
|
||||
_ "github.com/coreos/clair/worker/detectors/namespace/lsbrelease"
|
||||
_ "github.com/coreos/clair/worker/detectors/namespace/osrelease"
|
||||
_ "github.com/coreos/clair/worker/detectors/namespace/redhatrelease"
|
||||
|
||||
_ "github.com/coreos/clair/database/pgsql"
|
||||
)
|
||||
|
||||
var log = capnslog.NewPackageLogger("github.com/coreos/clair/cmd/clair", "main")
|
||||
|
@ -15,13 +15,16 @@
|
||||
# The values specified here are the default values that Clair uses if no configuration file is specified or if the keys are not defined.
|
||||
clair:
|
||||
database:
|
||||
# Database driver
|
||||
type: pgsql
|
||||
options:
|
||||
# PostgreSQL Connection string
|
||||
# http://www.postgresql.org/docs/9.4/static/libpq-connect.html
|
||||
source:
|
||||
|
||||
# Number of elements kept in the cache
|
||||
# Values unlikely to change (e.g. namespaces) are cached in order to save prevent needless roundtrips to the database.
|
||||
cacheSize: 16384
|
||||
cachesize: 16384
|
||||
|
||||
api:
|
||||
# API server port
|
||||
|
@ -27,6 +27,14 @@ import (
|
||||
// ErrDatasourceNotLoaded is returned when the datasource variable in the configuration file is not loaded properly
|
||||
var ErrDatasourceNotLoaded = errors.New("could not load configuration: no database source specified")
|
||||
|
||||
// RegistrableComponentConfig is a configuration block that can be used to
|
||||
// determine which registrable component should be initialized and pass
|
||||
// custom configuration to it.
|
||||
type RegistrableComponentConfig struct {
|
||||
Type string
|
||||
Options map[string]interface{}
|
||||
}
|
||||
|
||||
// File represents a YAML configuration file that namespaces all Clair
|
||||
// configuration under the top-level "clair" key.
|
||||
type File struct {
|
||||
@ -35,19 +43,12 @@ type File struct {
|
||||
|
||||
// Config is the global configuration for an instance of Clair.
|
||||
type Config struct {
|
||||
Database *DatabaseConfig
|
||||
Database RegistrableComponentConfig
|
||||
Updater *UpdaterConfig
|
||||
Notifier *NotifierConfig
|
||||
API *APIConfig
|
||||
}
|
||||
|
||||
// DatabaseConfig is the configuration used to specify how Clair connects
|
||||
// to a database.
|
||||
type DatabaseConfig struct {
|
||||
Source string
|
||||
CacheSize int
|
||||
}
|
||||
|
||||
// UpdaterConfig is the configuration for the Updater service.
|
||||
type UpdaterConfig struct {
|
||||
Interval time.Duration
|
||||
@ -72,8 +73,8 @@ type APIConfig struct {
|
||||
// DefaultConfig is a configuration that can be used as a fallback value.
|
||||
func DefaultConfig() Config {
|
||||
return Config{
|
||||
Database: &DatabaseConfig{
|
||||
CacheSize: 16384,
|
||||
Database: RegistrableComponentConfig{
|
||||
Type: "pgsql",
|
||||
},
|
||||
Updater: &UpdaterConfig{
|
||||
Interval: 1 * time.Hour,
|
||||
@ -116,12 +117,8 @@ func Load(path string) (config *Config, err error) {
|
||||
}
|
||||
config = &cfgFile.Clair
|
||||
|
||||
if config.Database.Source == "" {
|
||||
err = ErrDatasourceNotLoaded
|
||||
return
|
||||
}
|
||||
|
||||
// Generate a pagination key if none is provided.
|
||||
// TODO(Quentin-M): Move to the API code.
|
||||
if config.API.PaginationKey == "" {
|
||||
var key fernet.Key
|
||||
if err = key.Generate(); err != nil {
|
||||
|
@ -1,81 +0,0 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const wrongConfig = `
|
||||
dummyKey:
|
||||
wrong:true
|
||||
`
|
||||
|
||||
const goodConfig = `
|
||||
clair:
|
||||
database:
|
||||
source: postgresql://postgres:root@postgres:5432?sslmode=disable
|
||||
cacheSize: 16384
|
||||
api:
|
||||
port: 6060
|
||||
healthport: 6061
|
||||
timeout: 900s
|
||||
paginationKey:
|
||||
servername:
|
||||
cafile:
|
||||
keyfile:
|
||||
certfile:
|
||||
updater:
|
||||
interval: 2h
|
||||
notifier:
|
||||
attempts: 3
|
||||
renotifyInterval: 2h
|
||||
http:
|
||||
endpoint:
|
||||
servername:
|
||||
cafile:
|
||||
keyfile:
|
||||
certfile:
|
||||
proxy:
|
||||
`
|
||||
|
||||
func TestLoadWrongConfiguration(t *testing.T) {
|
||||
tmpfile, err := ioutil.TempFile("", "clair-config")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
defer os.Remove(tmpfile.Name()) // clean up
|
||||
if _, err := tmpfile.Write([]byte(wrongConfig)); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = Load(tmpfile.Name())
|
||||
|
||||
assert.EqualError(t, err, ErrDatasourceNotLoaded.Error())
|
||||
}
|
||||
|
||||
func TestLoad(t *testing.T) {
|
||||
tmpfile, err := ioutil.TempFile("", "clair-config")
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
defer os.Remove(tmpfile.Name()) // clean up
|
||||
|
||||
if _, err := tmpfile.Write([]byte(goodConfig)); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if err := tmpfile.Close(); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = Load(tmpfile.Name())
|
||||
assert.NoError(t, err)
|
||||
}
|
@ -17,7 +17,10 @@ package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/coreos/clair/config"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -28,11 +31,37 @@ var (
|
||||
// ErrInconsistent is an error that occurs when a database consistency check
|
||||
// fails (ie. when an entity which is supposed to be unique is detected twice)
|
||||
ErrInconsistent = errors.New("database: inconsistent database")
|
||||
|
||||
// ErrCantOpen is an error that occurs when the database could not be opened
|
||||
ErrCantOpen = errors.New("database: could not open database")
|
||||
)
|
||||
|
||||
var drivers = make(map[string]Driver)
|
||||
|
||||
// Driver is a function that opens a Datastore specified by its database driver type and specific
|
||||
// configuration.
|
||||
type Driver func(config.RegistrableComponentConfig) (Datastore, error)
|
||||
|
||||
// Register makes a Constructor available by the provided name.
|
||||
//
|
||||
// If this function is called twice with the same name or if the Constructor is
|
||||
// nil, it panics.
|
||||
func Register(name string, driver Driver) {
|
||||
if driver == nil {
|
||||
panic("database: could not register nil Driver")
|
||||
}
|
||||
if _, dup := drivers[name]; dup {
|
||||
panic("database: could not register duplicate Driver: " + name)
|
||||
}
|
||||
drivers[name] = driver
|
||||
}
|
||||
|
||||
// Open opens a Datastore specified by a configuration.
|
||||
func Open(cfg config.RegistrableComponentConfig) (Datastore, error) {
|
||||
driver, ok := drivers[cfg.Type]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("database: unknown Driver %q (forgotten configuration or import?)", cfg.Type)
|
||||
}
|
||||
return driver(cfg)
|
||||
}
|
||||
|
||||
// Datastore is the interface that describes a database backend implementation.
|
||||
type Datastore interface {
|
||||
// # Namespace
|
||||
|
@ -23,11 +23,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/utils"
|
||||
"github.com/coreos/clair/utils/types"
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -36,7 +37,7 @@ const (
|
||||
)
|
||||
|
||||
func TestRaceAffects(t *testing.T) {
|
||||
datastore, err := OpenForTest("RaceAffects", false)
|
||||
datastore, err := openDatabaseForTest("RaceAffects", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
@ -17,13 +17,14 @@ package pgsql
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/utils/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestInsertFeature(t *testing.T) {
|
||||
datastore, err := OpenForTest("InsertFeature", false)
|
||||
datastore, err := openDatabaseForTest("InsertFeature", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
@ -21,7 +21,7 @@ import (
|
||||
)
|
||||
|
||||
func TestKeyValue(t *testing.T) {
|
||||
datastore, err := OpenForTest("KeyValue", false)
|
||||
datastore, err := openDatabaseForTest("KeyValue", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
@ -41,9 +41,7 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo
|
||||
var namespaceName sql.NullString
|
||||
|
||||
t := time.Now()
|
||||
err := pgSQL.QueryRow(searchLayer, name).
|
||||
Scan(&layer.ID, &layer.Name, &layer.EngineVersion, &parentID, &parentName, &namespaceID,
|
||||
&namespaceName)
|
||||
err := pgSQL.QueryRow(searchLayer, name).Scan(&layer.ID, &layer.Name, &layer.EngineVersion, &parentID, &parentName, &namespaceID, &namespaceName)
|
||||
observeQueryTime("FindLayer", "searchLayer", t)
|
||||
|
||||
if err != nil {
|
||||
|
@ -18,14 +18,15 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
cerrors "github.com/coreos/clair/utils/errors"
|
||||
"github.com/coreos/clair/utils/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFindLayer(t *testing.T) {
|
||||
datastore, err := OpenForTest("FindLayer", true)
|
||||
datastore, err := openDatabaseForTest("FindLayer", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@ -102,7 +103,7 @@ func TestFindLayer(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInsertLayer(t *testing.T) {
|
||||
datastore, err := OpenForTest("InsertLayer", false)
|
||||
datastore, err := openDatabaseForTest("InsertLayer", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
@ -22,7 +22,7 @@ import (
|
||||
)
|
||||
|
||||
func TestLock(t *testing.T) {
|
||||
datastore, err := OpenForTest("InsertNamespace", false)
|
||||
datastore, err := openDatabaseForTest("InsertNamespace", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
@ -18,12 +18,13 @@ import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
)
|
||||
|
||||
func TestInsertNamespace(t *testing.T) {
|
||||
datastore, err := OpenForTest("InsertNamespace", false)
|
||||
datastore, err := openDatabaseForTest("InsertNamespace", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@ -44,7 +45,7 @@ func TestInsertNamespace(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestListNamespace(t *testing.T) {
|
||||
datastore, err := OpenForTest("ListNamespaces", true)
|
||||
datastore, err := openDatabaseForTest("ListNamespaces", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
@ -4,14 +4,15 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
cerrors "github.com/coreos/clair/utils/errors"
|
||||
"github.com/coreos/clair/utils/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNotification(t *testing.T) {
|
||||
datastore, err := OpenForTest("Notification", false)
|
||||
datastore, err := openDatabaseForTest("Notification", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
@ -19,22 +19,23 @@ import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"net/url"
|
||||
"path"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"bitbucket.org/liamstask/goose/lib/goose"
|
||||
"github.com/coreos/pkg/capnslog"
|
||||
"github.com/hashicorp/golang-lru"
|
||||
"github.com/lib/pq"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/coreos/clair/config"
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/utils"
|
||||
cerrors "github.com/coreos/clair/utils/errors"
|
||||
"github.com/coreos/pkg/capnslog"
|
||||
"github.com/hashicorp/golang-lru"
|
||||
"github.com/lib/pq"
|
||||
"github.com/pborman/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -72,6 +73,8 @@ func init() {
|
||||
prometheus.MustRegister(promCacheQueriesTotal)
|
||||
prometheus.MustRegister(promQueryDurationMilliseconds)
|
||||
prometheus.MustRegister(promConcurrentLockVAFV)
|
||||
|
||||
database.Register("pgsql", openDatabase)
|
||||
}
|
||||
|
||||
type Queryer interface {
|
||||
@ -82,45 +85,136 @@ type Queryer interface {
|
||||
type pgSQL struct {
|
||||
*sql.DB
|
||||
cache *lru.ARCCache
|
||||
config Config
|
||||
}
|
||||
|
||||
// Close closes the database and destroys if ManageDatabaseLifecycle has been specified in
|
||||
// the configuration.
|
||||
func (pgSQL *pgSQL) Close() {
|
||||
if pgSQL.DB != nil {
|
||||
pgSQL.DB.Close()
|
||||
}
|
||||
|
||||
if pgSQL.config.ManageDatabaseLifecycle {
|
||||
dbName, pgSourceURL, _ := parseConnectionString(pgSQL.config.Source)
|
||||
dropDatabase(pgSourceURL, dbName)
|
||||
}
|
||||
}
|
||||
|
||||
// Ping verifies that the database is accessible.
|
||||
func (pgSQL *pgSQL) Ping() bool {
|
||||
return pgSQL.DB.Ping() == nil
|
||||
}
|
||||
|
||||
// Open creates a Datastore backed by a PostgreSQL database.
|
||||
//
|
||||
// It will run immediately every necessary migration on the database.
|
||||
func Open(config *config.DatabaseConfig) (database.Datastore, error) {
|
||||
// Run migrations.
|
||||
if err := migrate(config.Source); err != nil {
|
||||
log.Error(err)
|
||||
return nil, database.ErrCantOpen
|
||||
// Config is the configuration that is used by openDatabase.
|
||||
type Config struct {
|
||||
Source string
|
||||
CacheSize int
|
||||
|
||||
ManageDatabaseLifecycle bool
|
||||
FixturePath string
|
||||
}
|
||||
|
||||
// openDatabase opens a PostgresSQL-backed Datastore using the given configuration.
|
||||
// It immediately every necessary migrations. If ManageDatabaseLifecycle is specified,
|
||||
// the database will be created first. If FixturePath is specified, every SQL queries that are
|
||||
// present insides will be executed.
|
||||
func openDatabase(registrableComponentConfig config.RegistrableComponentConfig) (database.Datastore, error) {
|
||||
var pg pgSQL
|
||||
var err error
|
||||
|
||||
// Parse configuration.
|
||||
pg.config = Config{
|
||||
CacheSize: 16384,
|
||||
}
|
||||
bytes, err := yaml.Marshal(registrableComponentConfig.Options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
|
||||
}
|
||||
err = yaml.Unmarshal(bytes, &pg.config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
|
||||
}
|
||||
|
||||
dbName, pgSourceURL, err := parseConnectionString(pg.config.Source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create database.
|
||||
if pg.config.ManageDatabaseLifecycle {
|
||||
log.Info("pgsql: creating database")
|
||||
if err := createDatabase(pgSourceURL, dbName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Open database.
|
||||
db, err := sql.Open("postgres", config.Source)
|
||||
pg.DB, err = sql.Open("postgres", pg.config.Source)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
return nil, database.ErrCantOpen
|
||||
pg.Close()
|
||||
return nil, fmt.Errorf("pgsql: could not open database: %v", err)
|
||||
}
|
||||
|
||||
// Verify database state.
|
||||
if err := pg.DB.Ping(); err != nil {
|
||||
pg.Close()
|
||||
return nil, fmt.Errorf("pgsql: could not open database: %v", err)
|
||||
}
|
||||
|
||||
// Run migrations.
|
||||
if err := migrate(pg.config.Source); err != nil {
|
||||
pg.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Load fixture data.
|
||||
if pg.config.FixturePath != "" {
|
||||
log.Info("pgsql: loading fixtures")
|
||||
|
||||
d, err := ioutil.ReadFile(pg.config.FixturePath)
|
||||
if err != nil {
|
||||
pg.Close()
|
||||
return nil, fmt.Errorf("pgsql: could not open fixture file: %v", err)
|
||||
}
|
||||
|
||||
_, err = pg.DB.Exec(string(d))
|
||||
if err != nil {
|
||||
pg.Close()
|
||||
return nil, fmt.Errorf("pgsql: an error occured while importing fixtures: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize cache.
|
||||
// TODO(Quentin-M): Benchmark with a simple LRU Cache.
|
||||
var cache *lru.ARCCache
|
||||
if config.CacheSize > 0 {
|
||||
cache, _ = lru.NewARC(config.CacheSize)
|
||||
if pg.config.CacheSize > 0 {
|
||||
pg.cache, _ = lru.NewARC(pg.config.CacheSize)
|
||||
}
|
||||
|
||||
return &pgSQL{DB: db, cache: cache}, nil
|
||||
return &pg, nil
|
||||
}
|
||||
|
||||
func parseConnectionString(source string) (dbName string, pgSourceURL string, err error) {
|
||||
if source == "" {
|
||||
return "", "", cerrors.NewBadRequestError("pgsql: no database connection string specified")
|
||||
}
|
||||
|
||||
sourceURL, err := url.Parse(source)
|
||||
if err != nil {
|
||||
return "", "", cerrors.NewBadRequestError("pgsql: database connection string is not a valid URL")
|
||||
}
|
||||
|
||||
dbName = strings.TrimPrefix(sourceURL.Path, "/")
|
||||
|
||||
pgSource := *sourceURL
|
||||
pgSource.Path = "/postgres"
|
||||
pgSourceURL = pgSource.String()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// migrate runs all available migrations on a pgSQL database.
|
||||
func migrate(dataSource string) error {
|
||||
func migrate(source string) error {
|
||||
log.Info("running database migrations")
|
||||
|
||||
_, filename, _, _ := runtime.Caller(1)
|
||||
@ -129,7 +223,7 @@ func migrate(dataSource string) error {
|
||||
MigrationsDir: migrationDir,
|
||||
Driver: goose.DBDriver{
|
||||
Name: "postgres",
|
||||
OpenStr: dataSource,
|
||||
OpenStr: source,
|
||||
Import: "github.com/lib/pq",
|
||||
Dialect: &goose.PostgresDialect{},
|
||||
},
|
||||
@ -138,13 +232,13 @@ func migrate(dataSource string) error {
|
||||
// Determine the most recent revision available from the migrations folder.
|
||||
target, err := goose.GetMostRecentDBVersion(conf.MigrationsDir)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("pgsql: could not get most recent migration: %v", err)
|
||||
}
|
||||
|
||||
// Run migrations
|
||||
// Run migrations.
|
||||
err = goose.RunMigrations(conf, conf.MigrationsDir, target)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("pgsql: an error occured while running migrations: %v", err)
|
||||
}
|
||||
|
||||
log.Info("database migration ran successfully")
|
||||
@ -152,109 +246,51 @@ func migrate(dataSource string) error {
|
||||
}
|
||||
|
||||
// createDatabase creates a new database.
|
||||
// The dataSource parameter should not contain a dbname.
|
||||
func createDatabase(dataSource, databaseName string) error {
|
||||
// The source parameter should not contain a dbname.
|
||||
func createDatabase(source, dbName string) error {
|
||||
// Open database.
|
||||
db, err := sql.Open("postgres", dataSource)
|
||||
db, err := sql.Open("postgres", source)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not open database (CreateDatabase): %v", err)
|
||||
return fmt.Errorf("pgsql: could not open 'postgres' database for creation: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create database.
|
||||
_, err = db.Exec("CREATE DATABASE " + databaseName)
|
||||
_, err = db.Exec("CREATE DATABASE " + dbName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not create database: %v", err)
|
||||
return fmt.Errorf("pgsql: could not create database: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// dropDatabase drops an existing database.
|
||||
// The dataSource parameter should not contain a dbname.
|
||||
func dropDatabase(dataSource, databaseName string) error {
|
||||
// The source parameter should not contain a dbname.
|
||||
func dropDatabase(source, dbName string) error {
|
||||
// Open database.
|
||||
db, err := sql.Open("postgres", dataSource)
|
||||
db, err := sql.Open("postgres", source)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not open database (DropDatabase): %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Kill any opened connection.
|
||||
if _, err := db.Exec(`
|
||||
if _, err = db.Exec(`
|
||||
SELECT pg_terminate_backend(pg_stat_activity.pid)
|
||||
FROM pg_stat_activity
|
||||
WHERE pg_stat_activity.datname = $1
|
||||
AND pid <> pg_backend_pid()`, databaseName); err != nil {
|
||||
AND pid <> pg_backend_pid()`, dbName); err != nil {
|
||||
return fmt.Errorf("could not drop database: %v", err)
|
||||
}
|
||||
|
||||
// Drop database.
|
||||
if _, err = db.Exec("DROP DATABASE " + databaseName); err != nil {
|
||||
if _, err = db.Exec("DROP DATABASE " + dbName); err != nil {
|
||||
return fmt.Errorf("could not drop database: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// pgSQLTest wraps pgSQL for testing purposes.
|
||||
// Its Close() method drops the database.
|
||||
type pgSQLTest struct {
|
||||
*pgSQL
|
||||
dataSourceDefaultDatabase string
|
||||
dbName string
|
||||
}
|
||||
|
||||
// OpenForTest creates a test Datastore backed by a new PostgreSQL database.
|
||||
// It creates a new unique and prefixed ("test_") database.
|
||||
// Using Close() will drop the database.
|
||||
func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) {
|
||||
// Define the PostgreSQL connection strings.
|
||||
dataSource := "host=127.0.0.1 sslmode=disable user=postgres dbname="
|
||||
if dataSourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); dataSourceEnv != "" {
|
||||
dataSource = dataSourceEnv + " dbname="
|
||||
}
|
||||
dbName := "test_" + strings.ToLower(name) + "_" + strings.Replace(uuid.New(), "-", "_", -1)
|
||||
dataSourceDefaultDatabase := dataSource + "postgres"
|
||||
dataSourceTestDatabase := dataSource + dbName
|
||||
|
||||
// Create database.
|
||||
if err := createDatabase(dataSourceDefaultDatabase, dbName); err != nil {
|
||||
log.Error(err)
|
||||
return nil, database.ErrCantOpen
|
||||
}
|
||||
|
||||
// Open database.
|
||||
db, err := Open(&config.DatabaseConfig{Source: dataSourceTestDatabase, CacheSize: 0})
|
||||
if err != nil {
|
||||
dropDatabase(dataSourceDefaultDatabase, dbName)
|
||||
log.Error(err)
|
||||
return nil, database.ErrCantOpen
|
||||
}
|
||||
|
||||
// Load test data if specified.
|
||||
if withTestData {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
d, _ := ioutil.ReadFile(path.Join(path.Dir(filename)) + "/testdata/data.sql")
|
||||
_, err = db.(*pgSQL).Exec(string(d))
|
||||
if err != nil {
|
||||
dropDatabase(dataSourceDefaultDatabase, dbName)
|
||||
log.Error(err)
|
||||
return nil, database.ErrCantOpen
|
||||
}
|
||||
}
|
||||
|
||||
return &pgSQLTest{
|
||||
pgSQL: db.(*pgSQL),
|
||||
dataSourceDefaultDatabase: dataSourceDefaultDatabase,
|
||||
dbName: dbName}, nil
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQLTest) Close() {
|
||||
pgSQL.DB.Close()
|
||||
dropDatabase(pgSQL.dataSourceDefaultDatabase, pgSQL.dbName)
|
||||
}
|
||||
|
||||
// handleError logs an error with an extra description and masks the error if it's an SQL one.
|
||||
// This ensures we never return plain SQL errors and leak anything.
|
||||
func handleError(desc string, err error) error {
|
||||
|
45
database/pgsql/pgsql_test.go
Normal file
45
database/pgsql/pgsql_test.go
Normal file
@ -0,0 +1,45 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/coreos/clair/config"
|
||||
"github.com/pborman/uuid"
|
||||
)
|
||||
|
||||
func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) {
|
||||
ds, err := openDatabase(generateTestConfig(testName, loadFixture))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
datastore := ds.(*pgSQL)
|
||||
return datastore, nil
|
||||
}
|
||||
|
||||
func generateTestConfig(testName string, loadFixture bool) config.RegistrableComponentConfig {
|
||||
dbName := "test_" + strings.ToLower(testName) + "_" + strings.Replace(uuid.New(), "-", "_", -1)
|
||||
|
||||
var fixturePath string
|
||||
if loadFixture {
|
||||
_, filename, _, _ := runtime.Caller(0)
|
||||
fixturePath = path.Join(path.Dir(filename)) + "/testdata/data.sql"
|
||||
}
|
||||
|
||||
source := fmt.Sprintf("postgresql://postgres@127.0.0.1:5432/%s?sslmode=disable", dbName)
|
||||
if sourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); sourceEnv != "" {
|
||||
source = fmt.Sprintf(sourceEnv, dbName)
|
||||
}
|
||||
|
||||
return config.RegistrableComponentConfig{
|
||||
Options: map[string]interface{}{
|
||||
"source": source,
|
||||
"cachesize": 0,
|
||||
"managedatabaselifecycle": true,
|
||||
"fixturepath": fixturePath,
|
||||
},
|
||||
}
|
||||
}
|
@ -18,14 +18,15 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
cerrors "github.com/coreos/clair/utils/errors"
|
||||
"github.com/coreos/clair/utils/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFindVulnerability(t *testing.T) {
|
||||
datastore, err := OpenForTest("FindVulnerability", true)
|
||||
datastore, err := openDatabaseForTest("FindVulnerability", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@ -75,7 +76,7 @@ func TestFindVulnerability(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDeleteVulnerability(t *testing.T) {
|
||||
datastore, err := OpenForTest("InsertVulnerability", true)
|
||||
datastore, err := openDatabaseForTest("InsertVulnerability", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
@ -97,7 +98,7 @@ func TestDeleteVulnerability(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestInsertVulnerability(t *testing.T) {
|
||||
datastore, err := OpenForTest("InsertVulnerability", false)
|
||||
datastore, err := openDatabaseForTest("InsertVulnerability", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
Loading…
Reference in New Issue
Block a user