Merge pull request #165 from Quentin-M/db_registration

Allow specifying datastore driver by config, relocate upgrade detection, mock datastore
This commit is contained in:
Quentin Machu 2016-05-20 12:20:26 -05:00
commit a03459d02e
26 changed files with 635 additions and 368 deletions

View File

@ -26,7 +26,7 @@ import (
"github.com/coreos/clair/api" "github.com/coreos/clair/api"
"github.com/coreos/clair/api/context" "github.com/coreos/clair/api/context"
"github.com/coreos/clair/config" "github.com/coreos/clair/config"
"github.com/coreos/clair/database/pgsql" "github.com/coreos/clair/database"
"github.com/coreos/clair/notifier" "github.com/coreos/clair/notifier"
"github.com/coreos/clair/updater" "github.com/coreos/clair/updater"
"github.com/coreos/clair/utils" "github.com/coreos/clair/utils"
@ -42,7 +42,7 @@ func Boot(config *config.Config) {
st := utils.NewStopper() st := utils.NewStopper()
// Open database // Open database
db, err := pgsql.Open(config.Database) db, err := database.Open(config.Database)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }

View File

@ -20,11 +20,11 @@ import (
"runtime/pprof" "runtime/pprof"
"strings" "strings"
"github.com/coreos/pkg/capnslog"
"github.com/coreos/clair" "github.com/coreos/clair"
"github.com/coreos/clair/config" "github.com/coreos/clair/config"
"github.com/coreos/pkg/capnslog"
// Register components // Register components
_ "github.com/coreos/clair/notifier/notifiers" _ "github.com/coreos/clair/notifier/notifiers"
@ -43,6 +43,8 @@ import (
_ "github.com/coreos/clair/worker/detectors/namespace/lsbrelease" _ "github.com/coreos/clair/worker/detectors/namespace/lsbrelease"
_ "github.com/coreos/clair/worker/detectors/namespace/osrelease" _ "github.com/coreos/clair/worker/detectors/namespace/osrelease"
_ "github.com/coreos/clair/worker/detectors/namespace/redhatrelease" _ "github.com/coreos/clair/worker/detectors/namespace/redhatrelease"
_ "github.com/coreos/clair/database/pgsql"
) )
var log = capnslog.NewPackageLogger("github.com/coreos/clair/cmd/clair", "main") var log = capnslog.NewPackageLogger("github.com/coreos/clair/cmd/clair", "main")

View File

@ -15,13 +15,16 @@
# The values specified here are the default values that Clair uses if no configuration file is specified or if the keys are not defined. # The values specified here are the default values that Clair uses if no configuration file is specified or if the keys are not defined.
clair: clair:
database: database:
# PostgreSQL Connection string # Database driver
# http://www.postgresql.org/docs/9.4/static/libpq-connect.html type: pgsql
source: options:
# PostgreSQL Connection string
# http://www.postgresql.org/docs/9.4/static/libpq-connect.html
source:
# Number of elements kept in the cache # Number of elements kept in the cache
# Values unlikely to change (e.g. namespaces) are cached in order to save prevent needless roundtrips to the database. # Values unlikely to change (e.g. namespaces) are cached in order to save prevent needless roundtrips to the database.
cacheSize: 16384 cachesize: 16384
api: api:
# API server port # API server port
@ -37,7 +40,7 @@ clair:
# 32-bit URL-safe base64 key used to encrypt pagination tokens # 32-bit URL-safe base64 key used to encrypt pagination tokens
# If one is not provided, it will be generated. # If one is not provided, it will be generated.
# Multiple clair instances in the same cluster need the same value. # Multiple clair instances in the same cluster need the same value.
paginationKey: paginationkey:
# Optional PKI configuration # Optional PKI configuration
# If you want to easily generate client certificates and CAs, try the following projects: # If you want to easily generate client certificates and CAs, try the following projects:
@ -58,7 +61,7 @@ clair:
attempts: 3 attempts: 3
# Duration before a failed notification is retried # Duration before a failed notification is retried
renotifyInterval: 2h renotifyinterval: 2h
http: http:
# Optional endpoint that will receive notifications via POST requests # Optional endpoint that will receive notifications via POST requests

View File

@ -27,6 +27,14 @@ import (
// ErrDatasourceNotLoaded is returned when the datasource variable in the configuration file is not loaded properly // ErrDatasourceNotLoaded is returned when the datasource variable in the configuration file is not loaded properly
var ErrDatasourceNotLoaded = errors.New("could not load configuration: no database source specified") var ErrDatasourceNotLoaded = errors.New("could not load configuration: no database source specified")
// RegistrableComponentConfig is a configuration block that can be used to
// determine which registrable component should be initialized and pass
// custom configuration to it.
type RegistrableComponentConfig struct {
Type string
Options map[string]interface{}
}
// File represents a YAML configuration file that namespaces all Clair // File represents a YAML configuration file that namespaces all Clair
// configuration under the top-level "clair" key. // configuration under the top-level "clair" key.
type File struct { type File struct {
@ -35,19 +43,12 @@ type File struct {
// Config is the global configuration for an instance of Clair. // Config is the global configuration for an instance of Clair.
type Config struct { type Config struct {
Database *DatabaseConfig Database RegistrableComponentConfig
Updater *UpdaterConfig Updater *UpdaterConfig
Notifier *NotifierConfig Notifier *NotifierConfig
API *APIConfig API *APIConfig
} }
// DatabaseConfig is the configuration used to specify how Clair connects
// to a database.
type DatabaseConfig struct {
Source string
CacheSize int
}
// UpdaterConfig is the configuration for the Updater service. // UpdaterConfig is the configuration for the Updater service.
type UpdaterConfig struct { type UpdaterConfig struct {
Interval time.Duration Interval time.Duration
@ -72,8 +73,8 @@ type APIConfig struct {
// DefaultConfig is a configuration that can be used as a fallback value. // DefaultConfig is a configuration that can be used as a fallback value.
func DefaultConfig() Config { func DefaultConfig() Config {
return Config{ return Config{
Database: &DatabaseConfig{ Database: RegistrableComponentConfig{
CacheSize: 16384, Type: "pgsql",
}, },
Updater: &UpdaterConfig{ Updater: &UpdaterConfig{
Interval: 1 * time.Hour, Interval: 1 * time.Hour,
@ -116,11 +117,6 @@ func Load(path string) (config *Config, err error) {
} }
config = &cfgFile.Clair config = &cfgFile.Clair
if config.Database.Source == "" {
err = ErrDatasourceNotLoaded
return
}
// Generate a pagination key if none is provided. // Generate a pagination key if none is provided.
if config.API.PaginationKey == "" { if config.API.PaginationKey == "" {
var key fernet.Key var key fernet.Key

View File

@ -1,81 +0,0 @@
package config
import (
"io/ioutil"
"log"
"os"
"testing"
"github.com/stretchr/testify/assert"
)
const wrongConfig = `
dummyKey:
wrong:true
`
const goodConfig = `
clair:
database:
source: postgresql://postgres:root@postgres:5432?sslmode=disable
cacheSize: 16384
api:
port: 6060
healthport: 6061
timeout: 900s
paginationKey:
servername:
cafile:
keyfile:
certfile:
updater:
interval: 2h
notifier:
attempts: 3
renotifyInterval: 2h
http:
endpoint:
servername:
cafile:
keyfile:
certfile:
proxy:
`
func TestLoadWrongConfiguration(t *testing.T) {
tmpfile, err := ioutil.TempFile("", "clair-config")
if err != nil {
log.Fatal(err)
}
defer os.Remove(tmpfile.Name()) // clean up
if _, err := tmpfile.Write([]byte(wrongConfig)); err != nil {
log.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
log.Fatal(err)
}
_, err = Load(tmpfile.Name())
assert.EqualError(t, err, ErrDatasourceNotLoaded.Error())
}
func TestLoad(t *testing.T) {
tmpfile, err := ioutil.TempFile("", "clair-config")
if err != nil {
log.Fatal(err)
}
defer os.Remove(tmpfile.Name()) // clean up
if _, err := tmpfile.Write([]byte(goodConfig)); err != nil {
log.Fatal(err)
}
if err := tmpfile.Close(); err != nil {
log.Fatal(err)
}
_, err = Load(tmpfile.Name())
assert.NoError(t, err)
}

View File

@ -17,7 +17,10 @@ package database
import ( import (
"errors" "errors"
"fmt"
"time" "time"
"github.com/coreos/clair/config"
) )
var ( var (
@ -28,11 +31,37 @@ var (
// ErrInconsistent is an error that occurs when a database consistency check // ErrInconsistent is an error that occurs when a database consistency check
// fails (ie. when an entity which is supposed to be unique is detected twice) // fails (ie. when an entity which is supposed to be unique is detected twice)
ErrInconsistent = errors.New("database: inconsistent database") ErrInconsistent = errors.New("database: inconsistent database")
// ErrCantOpen is an error that occurs when the database could not be opened
ErrCantOpen = errors.New("database: could not open database")
) )
var drivers = make(map[string]Driver)
// Driver is a function that opens a Datastore specified by its database driver type and specific
// configuration.
type Driver func(config.RegistrableComponentConfig) (Datastore, error)
// Register makes a Constructor available by the provided name.
//
// If this function is called twice with the same name or if the Constructor is
// nil, it panics.
func Register(name string, driver Driver) {
if driver == nil {
panic("database: could not register nil Driver")
}
if _, dup := drivers[name]; dup {
panic("database: could not register duplicate Driver: " + name)
}
drivers[name] = driver
}
// Open opens a Datastore specified by a configuration.
func Open(cfg config.RegistrableComponentConfig) (Datastore, error) {
driver, ok := drivers[cfg.Type]
if !ok {
return nil, fmt.Errorf("database: unknown Driver %q (forgotten configuration or import?)", cfg.Type)
}
return driver(cfg)
}
// Datastore is the interface that describes a database backend implementation. // Datastore is the interface that describes a database backend implementation.
type Datastore interface { type Datastore interface {
// # Namespace // # Namespace

191
database/mock.go Normal file
View File

@ -0,0 +1,191 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package database
import "time"
// MockDatastore implements Datastore and enables overriding each available method.
// The default behavior of each method is to simply panic.
type MockDatastore struct {
FctListNamespaces func() ([]Namespace, error)
FctInsertLayer func(Layer) error
FctFindLayer func(name string, withFeatures, withVulnerabilities bool) (Layer, error)
FctDeleteLayer func(name string) error
FctListVulnerabilities func(namespaceName string, limit int, page int) ([]Vulnerability, int, error)
FctInsertVulnerabilities func(vulnerabilities []Vulnerability, createNotification bool) error
FctFindVulnerability func(namespaceName, name string) (Vulnerability, error)
FctDeleteVulnerability func(namespaceName, name string) error
FctInsertVulnerabilityFixes func(vulnerabilityNamespace, vulnerabilityName string, fixes []FeatureVersion) error
FctDeleteVulnerabilityFix func(vulnerabilityNamespace, vulnerabilityName, featureName string) error
FctGetAvailableNotification func(renotifyInterval time.Duration) (VulnerabilityNotification, error)
FctGetNotification func(name string, limit int, page VulnerabilityNotificationPageNumber) (VulnerabilityNotification, VulnerabilityNotificationPageNumber, error)
FctSetNotificationNotified func(name string) error
FctDeleteNotification func(name string) error
FctInsertKeyValue func(key, value string) error
FctGetKeyValue func(key string) (string, error)
FctLock func(name string, owner string, duration time.Duration, renew bool) (bool, time.Time)
FctUnlock func(name, owner string)
FctFindLock func(name string) (string, time.Time, error)
FctPing func() bool
FctClose func()
}
func (mds *MockDatastore) ListNamespaces() ([]Namespace, error) {
if mds.FctListNamespaces != nil {
return mds.FctListNamespaces()
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) InsertLayer(layer Layer) error {
if mds.FctInsertLayer != nil {
return mds.FctInsertLayer(layer)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) FindLayer(name string, withFeatures, withVulnerabilities bool) (Layer, error) {
if mds.FctFindLayer != nil {
return mds.FctFindLayer(name, withFeatures, withVulnerabilities)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) DeleteLayer(name string) error {
if mds.FctDeleteLayer != nil {
return mds.FctDeleteLayer(name)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) ListVulnerabilities(namespaceName string, limit int, page int) ([]Vulnerability, int, error) {
if mds.FctListVulnerabilities != nil {
return mds.FctListVulnerabilities(namespaceName, limit, page)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) InsertVulnerabilities(vulnerabilities []Vulnerability, createNotification bool) error {
if mds.FctInsertVulnerabilities != nil {
return mds.FctInsertVulnerabilities(vulnerabilities, createNotification)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) FindVulnerability(namespaceName, name string) (Vulnerability, error) {
if mds.FctFindVulnerability != nil {
return mds.FctFindVulnerability(namespaceName, name)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) DeleteVulnerability(namespaceName, name string) error {
if mds.FctDeleteVulnerability != nil {
return mds.FctDeleteVulnerability(namespaceName, name)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []FeatureVersion) error {
if mds.FctInsertVulnerabilityFixes != nil {
return mds.FctInsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName, fixes)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error {
if mds.FctDeleteVulnerabilityFix != nil {
return mds.FctDeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) GetAvailableNotification(renotifyInterval time.Duration) (VulnerabilityNotification, error) {
if mds.FctGetAvailableNotification != nil {
return mds.FctGetAvailableNotification(renotifyInterval)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) GetNotification(name string, limit int, page VulnerabilityNotificationPageNumber) (VulnerabilityNotification, VulnerabilityNotificationPageNumber, error) {
if mds.FctGetNotification != nil {
return mds.FctGetNotification(name, limit, page)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) SetNotificationNotified(name string) error {
if mds.FctSetNotificationNotified != nil {
return mds.FctSetNotificationNotified(name)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) DeleteNotification(name string) error {
if mds.FctDeleteNotification != nil {
return mds.FctDeleteNotification(name)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) InsertKeyValue(key, value string) error {
if mds.FctInsertKeyValue != nil {
return mds.FctInsertKeyValue(key, value)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) GetKeyValue(key string) (string, error) {
if mds.FctGetKeyValue != nil {
return mds.FctGetKeyValue(key)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time) {
if mds.FctLock != nil {
return mds.FctLock(name, owner, duration, renew)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) Unlock(name, owner string) {
if mds.FctUnlock != nil {
mds.FctUnlock(name, owner)
return
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) FindLock(name string) (string, time.Time, error) {
if mds.FctFindLock != nil {
return mds.FctFindLock(name)
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) Ping() bool {
if mds.FctPing != nil {
return mds.FctPing()
}
panic("required mock function not implemented")
}
func (mds *MockDatastore) Close() {
if mds.FctClose != nil {
mds.FctClose()
return
}
panic("required mock function not implemented")
}

View File

@ -23,11 +23,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/pborman/uuid"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/utils" "github.com/coreos/clair/utils"
"github.com/coreos/clair/utils/types" "github.com/coreos/clair/utils/types"
"github.com/pborman/uuid"
"github.com/stretchr/testify/assert"
) )
const ( const (
@ -36,7 +37,7 @@ const (
) )
func TestRaceAffects(t *testing.T) { func TestRaceAffects(t *testing.T) {
datastore, err := OpenForTest("RaceAffects", false) datastore, err := openDatabaseForTest("RaceAffects", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -17,13 +17,14 @@ package pgsql
import ( import (
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/utils/types" "github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
) )
func TestInsertFeature(t *testing.T) { func TestInsertFeature(t *testing.T) {
datastore, err := OpenForTest("InsertFeature", false) datastore, err := openDatabaseForTest("InsertFeature", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -21,7 +21,7 @@ import (
) )
func TestKeyValue(t *testing.T) { func TestKeyValue(t *testing.T) {
datastore, err := OpenForTest("KeyValue", false) datastore, err := openDatabaseForTest("KeyValue", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -41,9 +41,7 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo
var namespaceName sql.NullString var namespaceName sql.NullString
t := time.Now() t := time.Now()
err := pgSQL.QueryRow(searchLayer, name). err := pgSQL.QueryRow(searchLayer, name).Scan(&layer.ID, &layer.Name, &layer.EngineVersion, &parentID, &parentName, &namespaceID, &namespaceName)
Scan(&layer.ID, &layer.Name, &layer.EngineVersion, &parentID, &parentName, &namespaceID,
&namespaceName)
observeQueryTime("FindLayer", "searchLayer", t) observeQueryTime("FindLayer", "searchLayer", t)
if err != nil { if err != nil {
@ -335,7 +333,7 @@ func (pgSQL *pgSQL) updateDiffFeatureVersions(tx *sql.Tx, layer, existingLayer *
addNV := utils.CompareStringLists(layerFeaturesNV, parentLayerFeaturesNV) addNV := utils.CompareStringLists(layerFeaturesNV, parentLayerFeaturesNV)
delNV := utils.CompareStringLists(parentLayerFeaturesNV, layerFeaturesNV) delNV := utils.CompareStringLists(parentLayerFeaturesNV, layerFeaturesNV)
// Fill the structures containing the added and deleted FeatureVersions // Fill the structures containing the added and deleted FeatureVersions.
for _, nv := range addNV { for _, nv := range addNV {
add = append(add, *layerFeaturesMapNV[nv]) add = append(add, *layerFeaturesMapNV[nv])
} }
@ -377,7 +375,7 @@ func createNV(features []database.FeatureVersion) (map[string]*database.FeatureV
for i := 0; i < len(features); i++ { for i := 0; i < len(features); i++ {
featureVersion := &features[i] featureVersion := &features[i]
nv := featureVersion.Feature.Name + ":" + featureVersion.Version.String() nv := featureVersion.Feature.Namespace.Name + ":" + featureVersion.Feature.Name + ":" + featureVersion.Version.String()
mapNV[nv] = featureVersion mapNV[nv] = featureVersion
sliceNV = append(sliceNV, nv) sliceNV = append(sliceNV, nv)
} }

View File

@ -18,14 +18,15 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors" cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types" "github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
) )
func TestFindLayer(t *testing.T) { func TestFindLayer(t *testing.T) {
datastore, err := OpenForTest("FindLayer", true) datastore, err := openDatabaseForTest("FindLayer", true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -102,7 +103,7 @@ func TestFindLayer(t *testing.T) {
} }
func TestInsertLayer(t *testing.T) { func TestInsertLayer(t *testing.T) {
datastore, err := OpenForTest("InsertLayer", false) datastore, err := openDatabaseForTest("InsertLayer", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -169,7 +170,7 @@ func testInsertLayerTree(t *testing.T, datastore database.Datastore) {
Namespace: database.Namespace{Name: "TestInsertLayerNamespace3"}, Namespace: database.Namespace{Name: "TestInsertLayerNamespace3"},
Name: "TestInsertLayerFeature3", Name: "TestInsertLayerFeature3",
}, },
Version: types.NewVersionUnsafe("0.57"), Version: types.NewVersionUnsafe("0.56"),
} }
f6 := database.FeatureVersion{ f6 := database.FeatureVersion{
Feature: database.Feature{ Feature: database.Feature{

View File

@ -22,7 +22,7 @@ import (
) )
func TestLock(t *testing.T) { func TestLock(t *testing.T) {
datastore, err := OpenForTest("InsertNamespace", false) datastore, err := openDatabaseForTest("InsertNamespace", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -18,12 +18,13 @@ import (
"fmt" "fmt"
"testing" "testing"
"github.com/coreos/clair/database"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coreos/clair/database"
) )
func TestInsertNamespace(t *testing.T) { func TestInsertNamespace(t *testing.T) {
datastore, err := OpenForTest("InsertNamespace", false) datastore, err := openDatabaseForTest("InsertNamespace", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -44,7 +45,7 @@ func TestInsertNamespace(t *testing.T) {
} }
func TestListNamespace(t *testing.T) { func TestListNamespace(t *testing.T) {
datastore, err := OpenForTest("ListNamespaces", true) datastore, err := openDatabaseForTest("ListNamespaces", true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -1,3 +1,17 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql package pgsql
import ( import (

View File

@ -1,17 +1,32 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql package pgsql
import ( import (
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors" cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types" "github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
) )
func TestNotification(t *testing.T) { func TestNotification(t *testing.T) {
datastore, err := OpenForTest("Notification", false) datastore, err := openDatabaseForTest("Notification", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -19,22 +19,23 @@ import (
"database/sql" "database/sql"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "net/url"
"path" "path/filepath"
"runtime" "runtime"
"strings" "strings"
"time" "time"
"bitbucket.org/liamstask/goose/lib/goose" "bitbucket.org/liamstask/goose/lib/goose"
"github.com/coreos/pkg/capnslog"
"github.com/hashicorp/golang-lru"
"github.com/lib/pq"
"github.com/prometheus/client_golang/prometheus"
"gopkg.in/yaml.v2"
"github.com/coreos/clair/config" "github.com/coreos/clair/config"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/utils" "github.com/coreos/clair/utils"
cerrors "github.com/coreos/clair/utils/errors" cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/pkg/capnslog"
"github.com/hashicorp/golang-lru"
"github.com/lib/pq"
"github.com/pborman/uuid"
"github.com/prometheus/client_golang/prometheus"
) )
var ( var (
@ -72,6 +73,8 @@ func init() {
prometheus.MustRegister(promCacheQueriesTotal) prometheus.MustRegister(promCacheQueriesTotal)
prometheus.MustRegister(promQueryDurationMilliseconds) prometheus.MustRegister(promQueryDurationMilliseconds)
prometheus.MustRegister(promConcurrentLockVAFV) prometheus.MustRegister(promConcurrentLockVAFV)
database.Register("pgsql", openDatabase)
} }
type Queryer interface { type Queryer interface {
@ -81,55 +84,146 @@ type Queryer interface {
type pgSQL struct { type pgSQL struct {
*sql.DB *sql.DB
cache *lru.ARCCache cache *lru.ARCCache
config Config
} }
// Close closes the database and destroys if ManageDatabaseLifecycle has been specified in
// the configuration.
func (pgSQL *pgSQL) Close() { func (pgSQL *pgSQL) Close() {
pgSQL.DB.Close() if pgSQL.DB != nil {
pgSQL.DB.Close()
}
if pgSQL.config.ManageDatabaseLifecycle {
dbName, pgSourceURL, _ := parseConnectionString(pgSQL.config.Source)
dropDatabase(pgSourceURL, dbName)
}
} }
// Ping verifies that the database is accessible.
func (pgSQL *pgSQL) Ping() bool { func (pgSQL *pgSQL) Ping() bool {
return pgSQL.DB.Ping() == nil return pgSQL.DB.Ping() == nil
} }
// Open creates a Datastore backed by a PostgreSQL database. // Config is the configuration that is used by openDatabase.
// type Config struct {
// It will run immediately every necessary migration on the database. Source string
func Open(config *config.DatabaseConfig) (database.Datastore, error) { CacheSize int
// Run migrations.
if err := migrate(config.Source); err != nil { ManageDatabaseLifecycle bool
log.Error(err) FixturePath string
return nil, database.ErrCantOpen }
// openDatabase opens a PostgresSQL-backed Datastore using the given configuration.
// It immediately every necessary migrations. If ManageDatabaseLifecycle is specified,
// the database will be created first. If FixturePath is specified, every SQL queries that are
// present insides will be executed.
func openDatabase(registrableComponentConfig config.RegistrableComponentConfig) (database.Datastore, error) {
var pg pgSQL
var err error
// Parse configuration.
pg.config = Config{
CacheSize: 16384,
}
bytes, err := yaml.Marshal(registrableComponentConfig.Options)
if err != nil {
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
}
err = yaml.Unmarshal(bytes, &pg.config)
if err != nil {
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
}
dbName, pgSourceURL, err := parseConnectionString(pg.config.Source)
if err != nil {
return nil, err
}
// Create database.
if pg.config.ManageDatabaseLifecycle {
log.Info("pgsql: creating database")
if err := createDatabase(pgSourceURL, dbName); err != nil {
return nil, err
}
} }
// Open database. // Open database.
db, err := sql.Open("postgres", config.Source) pg.DB, err = sql.Open("postgres", pg.config.Source)
if err != nil { if err != nil {
log.Error(err) pg.Close()
return nil, database.ErrCantOpen return nil, fmt.Errorf("pgsql: could not open database: %v", err)
}
// Verify database state.
if err := pg.DB.Ping(); err != nil {
pg.Close()
return nil, fmt.Errorf("pgsql: could not open database: %v", err)
}
// Run migrations.
if err := migrate(pg.config.Source); err != nil {
pg.Close()
return nil, err
}
// Load fixture data.
if pg.config.FixturePath != "" {
log.Info("pgsql: loading fixtures")
d, err := ioutil.ReadFile(pg.config.FixturePath)
if err != nil {
pg.Close()
return nil, fmt.Errorf("pgsql: could not open fixture file: %v", err)
}
_, err = pg.DB.Exec(string(d))
if err != nil {
pg.Close()
return nil, fmt.Errorf("pgsql: an error occured while importing fixtures: %v", err)
}
} }
// Initialize cache. // Initialize cache.
// TODO(Quentin-M): Benchmark with a simple LRU Cache. // TODO(Quentin-M): Benchmark with a simple LRU Cache.
var cache *lru.ARCCache if pg.config.CacheSize > 0 {
if config.CacheSize > 0 { pg.cache, _ = lru.NewARC(pg.config.CacheSize)
cache, _ = lru.NewARC(config.CacheSize)
} }
return &pgSQL{DB: db, cache: cache}, nil return &pg, nil
}
func parseConnectionString(source string) (dbName string, pgSourceURL string, err error) {
if source == "" {
return "", "", cerrors.NewBadRequestError("pgsql: no database connection string specified")
}
sourceURL, err := url.Parse(source)
if err != nil {
return "", "", cerrors.NewBadRequestError("pgsql: database connection string is not a valid URL")
}
dbName = strings.TrimPrefix(sourceURL.Path, "/")
pgSource := *sourceURL
pgSource.Path = "/postgres"
pgSourceURL = pgSource.String()
return
} }
// migrate runs all available migrations on a pgSQL database. // migrate runs all available migrations on a pgSQL database.
func migrate(dataSource string) error { func migrate(source string) error {
log.Info("running database migrations") log.Info("running database migrations")
_, filename, _, _ := runtime.Caller(1) _, filename, _, _ := runtime.Caller(1)
migrationDir := path.Join(path.Dir(filename), "/migrations/") migrationDir := filepath.Join(filepath.Dir(filename), "/migrations/")
conf := &goose.DBConf{ conf := &goose.DBConf{
MigrationsDir: migrationDir, MigrationsDir: migrationDir,
Driver: goose.DBDriver{ Driver: goose.DBDriver{
Name: "postgres", Name: "postgres",
OpenStr: dataSource, OpenStr: source,
Import: "github.com/lib/pq", Import: "github.com/lib/pq",
Dialect: &goose.PostgresDialect{}, Dialect: &goose.PostgresDialect{},
}, },
@ -138,13 +232,13 @@ func migrate(dataSource string) error {
// Determine the most recent revision available from the migrations folder. // Determine the most recent revision available from the migrations folder.
target, err := goose.GetMostRecentDBVersion(conf.MigrationsDir) target, err := goose.GetMostRecentDBVersion(conf.MigrationsDir)
if err != nil { if err != nil {
return err return fmt.Errorf("pgsql: could not get most recent migration: %v", err)
} }
// Run migrations // Run migrations.
err = goose.RunMigrations(conf, conf.MigrationsDir, target) err = goose.RunMigrations(conf, conf.MigrationsDir, target)
if err != nil { if err != nil {
return err return fmt.Errorf("pgsql: an error occured while running migrations: %v", err)
} }
log.Info("database migration ran successfully") log.Info("database migration ran successfully")
@ -152,109 +246,51 @@ func migrate(dataSource string) error {
} }
// createDatabase creates a new database. // createDatabase creates a new database.
// The dataSource parameter should not contain a dbname. // The source parameter should not contain a dbname.
func createDatabase(dataSource, databaseName string) error { func createDatabase(source, dbName string) error {
// Open database. // Open database.
db, err := sql.Open("postgres", dataSource) db, err := sql.Open("postgres", source)
if err != nil { if err != nil {
return fmt.Errorf("could not open database (CreateDatabase): %v", err) return fmt.Errorf("pgsql: could not open 'postgres' database for creation: %v", err)
} }
defer db.Close() defer db.Close()
// Create database. // Create database.
_, err = db.Exec("CREATE DATABASE " + databaseName) _, err = db.Exec("CREATE DATABASE " + dbName)
if err != nil { if err != nil {
return fmt.Errorf("could not create database: %v", err) return fmt.Errorf("pgsql: could not create database: %v", err)
} }
return nil return nil
} }
// dropDatabase drops an existing database. // dropDatabase drops an existing database.
// The dataSource parameter should not contain a dbname. // The source parameter should not contain a dbname.
func dropDatabase(dataSource, databaseName string) error { func dropDatabase(source, dbName string) error {
// Open database. // Open database.
db, err := sql.Open("postgres", dataSource) db, err := sql.Open("postgres", source)
if err != nil { if err != nil {
return fmt.Errorf("could not open database (DropDatabase): %v", err) return fmt.Errorf("could not open database (DropDatabase): %v", err)
} }
defer db.Close() defer db.Close()
// Kill any opened connection. // Kill any opened connection.
if _, err := db.Exec(` if _, err = db.Exec(`
SELECT pg_terminate_backend(pg_stat_activity.pid) SELECT pg_terminate_backend(pg_stat_activity.pid)
FROM pg_stat_activity FROM pg_stat_activity
WHERE pg_stat_activity.datname = $1 WHERE pg_stat_activity.datname = $1
AND pid <> pg_backend_pid()`, databaseName); err != nil { AND pid <> pg_backend_pid()`, dbName); err != nil {
return fmt.Errorf("could not drop database: %v", err) return fmt.Errorf("could not drop database: %v", err)
} }
// Drop database. // Drop database.
if _, err = db.Exec("DROP DATABASE " + databaseName); err != nil { if _, err = db.Exec("DROP DATABASE " + dbName); err != nil {
return fmt.Errorf("could not drop database: %v", err) return fmt.Errorf("could not drop database: %v", err)
} }
return nil return nil
} }
// pgSQLTest wraps pgSQL for testing purposes.
// Its Close() method drops the database.
type pgSQLTest struct {
*pgSQL
dataSourceDefaultDatabase string
dbName string
}
// OpenForTest creates a test Datastore backed by a new PostgreSQL database.
// It creates a new unique and prefixed ("test_") database.
// Using Close() will drop the database.
func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) {
// Define the PostgreSQL connection strings.
dataSource := "host=127.0.0.1 sslmode=disable user=postgres dbname="
if dataSourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); dataSourceEnv != "" {
dataSource = dataSourceEnv + " dbname="
}
dbName := "test_" + strings.ToLower(name) + "_" + strings.Replace(uuid.New(), "-", "_", -1)
dataSourceDefaultDatabase := dataSource + "postgres"
dataSourceTestDatabase := dataSource + dbName
// Create database.
if err := createDatabase(dataSourceDefaultDatabase, dbName); err != nil {
log.Error(err)
return nil, database.ErrCantOpen
}
// Open database.
db, err := Open(&config.DatabaseConfig{Source: dataSourceTestDatabase, CacheSize: 0})
if err != nil {
dropDatabase(dataSourceDefaultDatabase, dbName)
log.Error(err)
return nil, database.ErrCantOpen
}
// Load test data if specified.
if withTestData {
_, filename, _, _ := runtime.Caller(0)
d, _ := ioutil.ReadFile(path.Join(path.Dir(filename)) + "/testdata/data.sql")
_, err = db.(*pgSQL).Exec(string(d))
if err != nil {
dropDatabase(dataSourceDefaultDatabase, dbName)
log.Error(err)
return nil, database.ErrCantOpen
}
}
return &pgSQLTest{
pgSQL: db.(*pgSQL),
dataSourceDefaultDatabase: dataSourceDefaultDatabase,
dbName: dbName}, nil
}
func (pgSQL *pgSQLTest) Close() {
pgSQL.DB.Close()
dropDatabase(pgSQL.dataSourceDefaultDatabase, pgSQL.dbName)
}
// handleError logs an error with an extra description and masks the error if it's an SQL one. // handleError logs an error with an extra description and masks the error if it's an SQL one.
// This ensures we never return plain SQL errors and leak anything. // This ensures we never return plain SQL errors and leak anything.
func handleError(desc string, err error) error { func handleError(desc string, err error) error {

View File

@ -0,0 +1,59 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"fmt"
"os"
"path/filepath"
"runtime"
"strings"
"github.com/coreos/clair/config"
"github.com/pborman/uuid"
)
func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) {
ds, err := openDatabase(generateTestConfig(testName, loadFixture))
if err != nil {
return nil, err
}
datastore := ds.(*pgSQL)
return datastore, nil
}
func generateTestConfig(testName string, loadFixture bool) config.RegistrableComponentConfig {
dbName := "test_" + strings.ToLower(testName) + "_" + strings.Replace(uuid.New(), "-", "_", -1)
var fixturePath string
if loadFixture {
_, filename, _, _ := runtime.Caller(0)
fixturePath = filepath.Join(filepath.Dir(filename)) + "/testdata/data.sql"
}
source := fmt.Sprintf("postgresql://postgres@127.0.0.1:5432/%s?sslmode=disable", dbName)
if sourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); sourceEnv != "" {
source = fmt.Sprintf(sourceEnv, dbName)
}
return config.RegistrableComponentConfig{
Options: map[string]interface{}{
"source": source,
"cachesize": 0,
"managedatabaselifecycle": true,
"fixturepath": fixturePath,
},
}
}

View File

@ -18,14 +18,15 @@ import (
"reflect" "reflect"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors" cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types" "github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
) )
func TestFindVulnerability(t *testing.T) { func TestFindVulnerability(t *testing.T) {
datastore, err := OpenForTest("FindVulnerability", true) datastore, err := openDatabaseForTest("FindVulnerability", true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -75,7 +76,7 @@ func TestFindVulnerability(t *testing.T) {
} }
func TestDeleteVulnerability(t *testing.T) { func TestDeleteVulnerability(t *testing.T) {
datastore, err := OpenForTest("InsertVulnerability", true) datastore, err := openDatabaseForTest("InsertVulnerability", true)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return
@ -97,7 +98,7 @@ func TestDeleteVulnerability(t *testing.T) {
} }
func TestInsertVulnerability(t *testing.T) { func TestInsertVulnerability(t *testing.T) {
datastore, err := OpenForTest("InsertVulnerability", false) datastore, err := openDatabaseForTest("InsertVulnerability", false)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return return

View File

@ -16,7 +16,7 @@ package debian
import ( import (
"os" "os"
"path" "path/filepath"
"runtime" "runtime"
"testing" "testing"
@ -29,7 +29,7 @@ func TestDebianParser(t *testing.T) {
_, filename, _, _ := runtime.Caller(0) _, filename, _, _ := runtime.Caller(0)
// Test parsing testdata/fetcher_debian_test.json // Test parsing testdata/fetcher_debian_test.json
testFile, _ := os.Open(path.Join(path.Dir(filename)) + "/testdata/fetcher_debian_test.json") testFile, _ := os.Open(filepath.Join(filepath.Dir(filename)) + "/testdata/fetcher_debian_test.json")
response, err := buildResponse(testFile, "") response, err := buildResponse(testFile, "")
if assert.Nil(t, err) && assert.Len(t, response.Vulnerabilities, 3) { if assert.Nil(t, err) && assert.Len(t, response.Vulnerabilities, 3) {
for _, vulnerability := range response.Vulnerabilities { for _, vulnerability := range response.Vulnerabilities {

View File

@ -16,7 +16,7 @@ package rhel
import ( import (
"os" "os"
"path" "path/filepath"
"runtime" "runtime"
"testing" "testing"
@ -27,7 +27,7 @@ import (
func TestRHELParser(t *testing.T) { func TestRHELParser(t *testing.T) {
_, filename, _, _ := runtime.Caller(0) _, filename, _, _ := runtime.Caller(0)
path := path.Join(path.Dir(filename)) path := filepath.Join(filepath.Dir(filename))
// Test parsing testdata/fetcher_rhel_test.1.xml // Test parsing testdata/fetcher_rhel_test.1.xml
testFile, _ := os.Open(path + "/testdata/fetcher_rhel_test.1.xml") testFile, _ := os.Open(path + "/testdata/fetcher_rhel_test.1.xml")

View File

@ -16,7 +16,7 @@ package ubuntu
import ( import (
"os" "os"
"path" "path/filepath"
"runtime" "runtime"
"testing" "testing"
@ -27,7 +27,7 @@ import (
func TestUbuntuParser(t *testing.T) { func TestUbuntuParser(t *testing.T) {
_, filename, _, _ := runtime.Caller(0) _, filename, _, _ := runtime.Caller(0)
path := path.Join(path.Dir(filename)) path := filepath.Join(filepath.Dir(filename))
// Test parsing testdata/fetcher_ // Test parsing testdata/fetcher_
testData, _ := os.Open(path + "/testdata/fetcher_ubuntu_test.txt") testData, _ := os.Open(path + "/testdata/fetcher_ubuntu_test.txt")

View File

@ -17,7 +17,7 @@ package utils
import ( import (
"bytes" "bytes"
"os" "os"
"path" "path/filepath"
"runtime" "runtime"
"testing" "testing"
@ -65,10 +65,10 @@ func TestString(t *testing.T) {
func TestTar(t *testing.T) { func TestTar(t *testing.T) {
var err error var err error
var data map[string][]byte var data map[string][]byte
_, filepath, _, _ := runtime.Caller(0) _, path, _, _ := runtime.Caller(0)
testDataDir := "/testdata" testDataDir := "/testdata"
for _, filename := range []string{"utils_test.tar.gz", "utils_test.tar.bz2", "utils_test.tar.xz", "utils_test.tar"} { for _, filename := range []string{"utils_test.tar.gz", "utils_test.tar.bz2", "utils_test.tar.xz", "utils_test.tar"} {
testArchivePath := path.Join(path.Dir(filepath), testDataDir, filename) testArchivePath := filepath.Join(filepath.Dir(path), testDataDir, filename)
// Extract non compressed data // Extract non compressed data
data, err = SelectivelyExtractArchive(bytes.NewReader([]byte("that string does not represent a tar or tar-gzip file")), "", []string{}, 0) data, err = SelectivelyExtractArchive(bytes.NewReader([]byte("that string does not represent a tar or tar-gzip file")), "", []string{}, 0)

View File

@ -16,7 +16,7 @@ package feature
import ( import (
"io/ioutil" "io/ioutil"
"path" "path/filepath"
"runtime" "runtime"
"testing" "testing"
@ -32,7 +32,7 @@ type FeatureVersionTest struct {
func LoadFileForTest(name string) []byte { func LoadFileForTest(name string) []byte {
_, filename, _, _ := runtime.Caller(0) _, filename, _, _ := runtime.Caller(0)
d, _ := ioutil.ReadFile(path.Join(path.Dir(filename)) + "/" + name) d, _ := ioutil.ReadFile(filepath.Join(filepath.Dir(filename)) + "/" + name)
return d return d
} }

View File

@ -113,7 +113,7 @@ func Process(datastore database.Datastore, imageFormat, name, parentName, path s
} }
// detectContent downloads a layer's archive and extracts its Namespace and Features. // detectContent downloads a layer's archive and extracts its Namespace and Features.
func detectContent(imageFormat, name, path string, headers map[string]string, parent *database.Layer) (namespace *database.Namespace, features []database.FeatureVersion, err error) { func detectContent(imageFormat, name, path string, headers map[string]string, parent *database.Layer) (namespace *database.Namespace, featureVersions []database.FeatureVersion, err error) {
data, err := detectors.DetectData(imageFormat, path, headers, append(detectors.GetRequiredFilesFeatures(), detectors.GetRequiredFilesNamespace()...), maxFileSize) data, err := detectors.DetectData(imageFormat, path, headers, append(detectors.GetRequiredFilesFeatures(), detectors.GetRequiredFilesNamespace()...), maxFileSize)
if err != nil { if err != nil {
log.Errorf("layer %s: failed to extract data from %s: %s", name, utils.CleanURL(path), err) log.Errorf("layer %s: failed to extract data from %s: %s", name, utils.CleanURL(path), err)
@ -121,41 +121,33 @@ func detectContent(imageFormat, name, path string, headers map[string]string, pa
} }
// Detect namespace. // Detect namespace.
namespace, err = detectNamespace(data, parent) namespace = detectNamespace(name, data, parent)
if err != nil {
return
}
if namespace != nil {
log.Debugf("layer %s: Namespace is %s.", name, namespace.Name)
} else {
log.Debugf("layer %s: OS is unknown.", name)
}
// Detect features. // Detect features.
features, err = detectFeatures(name, data, namespace) featureVersions, err = detectFeatureVersions(name, data, namespace, parent)
if err != nil { if err != nil {
return return
} }
if len(featureVersions) > 0 {
// If there are no feature detected, use parent's features if possible. log.Debugf("layer %s: detected %d features", name, len(featureVersions))
// TODO(Quentin-M): We eventually want to give the choice to each detectors to use none/some
// parent's Features. It would be useful for detectors that can't find their entire result using
// one Layer.
if len(features) == 0 && parent != nil {
features = parent.Features
} }
log.Debugf("layer %s: detected %d features", name, len(features))
return return
} }
func detectNamespace(data map[string][]byte, parent *database.Layer) (namespace *database.Namespace, err error) { func detectNamespace(name string, data map[string][]byte, parent *database.Layer) (namespace *database.Namespace) {
// Use registered detectors to get the Namespace.
namespace = detectors.DetectNamespace(data) namespace = detectors.DetectNamespace(data)
if namespace != nil {
log.Debugf("layer %s: detected namespace %q", name, namespace.Name)
return
}
// Attempt to detect the OS from the parent layer. // Use the parent's Namespace.
if namespace == nil && parent != nil { if parent != nil {
namespace = parent.Namespace namespace = parent.Namespace
if err != nil { if namespace != nil {
log.Debugf("layer %s: detected namespace %q (from parent)", name, namespace.Name)
return return
} }
} }
@ -163,8 +155,8 @@ func detectNamespace(data map[string][]byte, parent *database.Layer) (namespace
return return
} }
func detectFeatures(name string, data map[string][]byte, namespace *database.Namespace) (features []database.FeatureVersion, err error) { func detectFeatureVersions(name string, data map[string][]byte, namespace *database.Namespace, parent *database.Layer) (features []database.FeatureVersion, err error) {
// TODO(Quentin-M): We need to pass the parent image DetectFeatures because it's possible that // TODO(Quentin-M): We need to pass the parent image to DetectFeatures because it's possible that
// some detectors would need it in order to produce the entire feature list (if they can only // some detectors would need it in order to produce the entire feature list (if they can only
// detect a diff). Also, we should probably pass the detected namespace so detectors could // detect a diff). Also, we should probably pass the detected namespace so detectors could
// make their own decision. // make their own decision.
@ -173,19 +165,46 @@ func detectFeatures(name string, data map[string][]byte, namespace *database.Nam
return return
} }
// Ensure that every feature has a Namespace associated, otherwise associate the detected // If there are no FeatureVersions, use parent's FeatureVersions if possible.
// namespace. If there is no detected namespace, we'll throw an error. // TODO(Quentin-M): We eventually want to give the choice to each detectors to use none/some of
for i := 0; i < len(features); i++ { // their parent's FeatureVersions. It would be useful for detectors that can't find their entire
if features[i].Feature.Namespace.Name == "" { // result using one Layer.
if namespace != nil { if len(features) == 0 && parent != nil {
features[i].Feature.Namespace = *namespace features = parent.Features
} else { return
log.Warningf("layer %s: Layer's namespace is unknown but non-namespaced features have been detected", name) }
err = ErrUnsupported
return // Build a map of the namespaces for each FeatureVersion in our parent layer.
} parentFeatureNamespaces := make(map[string]database.Namespace)
if parent != nil {
for _, parentFeature := range parent.Features {
parentFeatureNamespaces[parentFeature.Feature.Name+":"+parentFeature.Version.String()] = parentFeature.Feature.Namespace
} }
} }
// Ensure that each FeatureVersion has an associated Namespace.
for i, feature := range features {
if feature.Feature.Namespace.Name != "" {
// There is a Namespace associated.
continue
}
if parentFeatureNamespace, ok := parentFeatureNamespaces[feature.Feature.Name+":"+feature.Version.String()]; ok {
// The FeatureVersion is present in the parent layer; associate with their Namespace.
features[i].Feature.Namespace = parentFeatureNamespace
continue
}
if namespace != nil {
// The Namespace has been detected in this layer; associate it.
features[i].Feature.Namespace = *namespace
continue
}
log.Warningf("layer %s: Layer's namespace is unknown but non-namespaced features have been detected", name)
err = ErrUnsupported
return
}
return return
} }

View File

@ -15,15 +15,16 @@
package worker package worker
import ( import (
"path" "path/filepath"
"runtime" "runtime"
"testing" "testing"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql"
"github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types"
// Register the required detectors. // Register the required detectors.
_ "github.com/coreos/clair/worker/detectors/data/docker" _ "github.com/coreos/clair/worker/detectors/data/docker"
_ "github.com/coreos/clair/worker/detectors/feature/dpkg" _ "github.com/coreos/clair/worker/detectors/feature/dpkg"
@ -31,101 +32,81 @@ import (
_ "github.com/coreos/clair/worker/detectors/namespace/osrelease" _ "github.com/coreos/clair/worker/detectors/namespace/osrelease"
) )
func TestProcessWithDistUpgrade(t *testing.T) { type mockDatastore struct {
// TODO(Quentin-M): This should not be bound to a single database implementation. database.MockDatastore
datastore, err := pgsql.OpenForTest("ProcessWithDistUpgrade", false) layers map[string]database.Layer
if err != nil { }
t.Error(err)
return func newMockDatastore() *mockDatastore {
return &mockDatastore{
layers: make(map[string]database.Layer),
} }
defer datastore.Close() }
func TestProcessWithDistUpgrade(t *testing.T) {
_, f, _, _ := runtime.Caller(0) _, f, _, _ := runtime.Caller(0)
path := path.Join(path.Dir(f)) + "/testdata/DistUpgrade/" testDataPath := filepath.Join(filepath.Dir(f)) + "/testdata/DistUpgrade/"
// Create a mock datastore.
datastore := newMockDatastore()
datastore.FctInsertLayer = func(layer database.Layer) error {
datastore.layers[layer.Name] = layer
return nil
}
datastore.FctFindLayer = func(name string, withFeatures, withVulnerabilities bool) (database.Layer, error) {
if layer, exists := datastore.layers[name]; exists {
return layer, nil
}
return database.Layer{}, cerrors.ErrNotFound
}
// Create the list of FeatureVersions that should not been upgraded from one layer to another.
nonUpgradedFeatureVersions := []database.FeatureVersion{
{Feature: database.Feature{Name: "libtext-wrapi18n-perl"}, Version: types.NewVersionUnsafe("0.06-7")},
{Feature: database.Feature{Name: "libtext-charwidth-perl"}, Version: types.NewVersionUnsafe("0.04-7")},
{Feature: database.Feature{Name: "libtext-iconv-perl"}, Version: types.NewVersionUnsafe("1.7-5")},
{Feature: database.Feature{Name: "mawk"}, Version: types.NewVersionUnsafe("1.3.3-17")},
{Feature: database.Feature{Name: "insserv"}, Version: types.NewVersionUnsafe("1.14.0-5")},
{Feature: database.Feature{Name: "db"}, Version: types.NewVersionUnsafe("5.1.29-5")},
{Feature: database.Feature{Name: "ustr"}, Version: types.NewVersionUnsafe("1.0.4-3")},
{Feature: database.Feature{Name: "xz-utils"}, Version: types.NewVersionUnsafe("5.1.1alpha+20120614-2")},
}
// Process test layers.
//
// blank.tar: MAINTAINER Quentin MACHU <quentin.machu.fr> // blank.tar: MAINTAINER Quentin MACHU <quentin.machu.fr>
// wheezy.tar: FROM debian:wheezy // wheezy.tar: FROM debian:wheezy
// jessie.tar: RUN sed -i "s/precise/trusty/" /etc/apt/sources.list && apt-get update && // jessie.tar: RUN sed -i "s/precise/trusty/" /etc/apt/sources.list && apt-get update &&
// apt-get -y dist-upgrade // apt-get -y dist-upgrade
assert.Nil(t, Process(datastore, "Docker", "blank", "", path+"blank.tar.gz", nil)) assert.Nil(t, Process(datastore, "Docker", "blank", "", testDataPath+"blank.tar.gz", nil))
assert.Nil(t, Process(datastore, "Docker", "wheezy", "blank", path+"wheezy.tar.gz", nil)) assert.Nil(t, Process(datastore, "Docker", "wheezy", "blank", testDataPath+"wheezy.tar.gz", nil))
assert.Nil(t, Process(datastore, "Docker", "jessie", "wheezy", path+"jessie.tar.gz", nil)) assert.Nil(t, Process(datastore, "Docker", "jessie", "wheezy", testDataPath+"jessie.tar.gz", nil))
wheezy, err := datastore.FindLayer("wheezy", true, false) // Ensure that the 'wheezy' layer has the expected namespace and features.
if assert.Nil(t, err) { wheezy, ok := datastore.layers["wheezy"]
if assert.True(t, ok, "layer 'wheezy' not processed") {
assert.Equal(t, "debian:7", wheezy.Namespace.Name) assert.Equal(t, "debian:7", wheezy.Namespace.Name)
assert.Len(t, wheezy.Features, 52) assert.Len(t, wheezy.Features, 52)
jessie, err := datastore.FindLayer("jessie", true, false) for _, nufv := range nonUpgradedFeatureVersions {
if assert.Nil(t, err) { nufv.Feature.Namespace.Name = "debian:7"
assert.Equal(t, "debian:8", jessie.Namespace.Name) assert.Contains(t, wheezy.Features, nufv)
assert.Len(t, jessie.Features, 74) }
}
// These FeatureVersions haven't been upgraded. // Ensure that the 'wheezy' layer has the expected namespace and non-upgraded features.
nonUpgradedFeatureVersions := []database.FeatureVersion{ jessie, ok := datastore.layers["jessie"]
{ if assert.True(t, ok, "layer 'jessie' not processed") {
Feature: database.Feature{Name: "libtext-wrapi18n-perl"}, assert.Equal(t, "debian:8", jessie.Namespace.Name)
Version: types.NewVersionUnsafe("0.06-7"), assert.Len(t, jessie.Features, 74)
},
{
Feature: database.Feature{Name: "libtext-charwidth-perl"},
Version: types.NewVersionUnsafe("0.04-7"),
},
{
Feature: database.Feature{Name: "libtext-iconv-perl"},
Version: types.NewVersionUnsafe("1.7-5"),
},
{
Feature: database.Feature{Name: "mawk"},
Version: types.NewVersionUnsafe("1.3.3-17"),
},
{
Feature: database.Feature{Name: "insserv"},
Version: types.NewVersionUnsafe("1.14.0-5"),
},
{
Feature: database.Feature{Name: "db"},
Version: types.NewVersionUnsafe("5.1.29-5"),
},
{
Feature: database.Feature{Name: "ustr"},
Version: types.NewVersionUnsafe("1.0.4-3"),
},
{
Feature: database.Feature{Name: "xz-utils"},
Version: types.NewVersionUnsafe("5.1.1alpha+20120614-2"),
},
}
for _, nufv := range nonUpgradedFeatureVersions { for _, nufv := range nonUpgradedFeatureVersions {
nufv.Feature.Namespace.Name = "debian:7" nufv.Feature.Namespace.Name = "debian:7"
assert.Contains(t, jessie.Features, nufv)
found := false }
for _, fv := range jessie.Features { for _, nufv := range nonUpgradedFeatureVersions {
if fv.Feature.Name == nufv.Feature.Name && nufv.Feature.Namespace.Name = "debian:8"
fv.Feature.Namespace.Name == nufv.Feature.Namespace.Name && assert.NotContains(t, jessie.Features, nufv)
fv.Version == nufv.Version {
found = true
break
}
}
assert.Equal(t, true, found, "Jessie layer doesn't have %#v but it should.", nufv)
}
for _, nufv := range nonUpgradedFeatureVersions {
nufv.Feature.Namespace.Name = "debian:8"
found := false
for _, fv := range jessie.Features {
if fv.Feature.Name == nufv.Feature.Name &&
fv.Feature.Namespace.Name == nufv.Feature.Namespace.Name &&
fv.Version == nufv.Version {
found = true
break
}
}
assert.Equal(t, false, found, "Jessie layer has %#v but it shouldn't.", nufv)
}
} }
} }
} }