database: postgres implementation with tests.

This commit is contained in:
Sida Chen 2017-07-26 16:23:54 -07:00
parent fb32dcfa58
commit a5c6400065
30 changed files with 3682 additions and 3076 deletions

261
database/pgsql/ancestry.go Normal file
View File

@ -0,0 +1,261 @@
package pgsql
import (
"database/sql"
"errors"
"fmt"
"strings"
"github.com/lib/pq"
log "github.com/sirupsen/logrus"
"github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/commonerr"
)
func (tx *pgSession) UpsertAncestry(ancestry database.Ancestry, features []database.NamespacedFeature, processedBy database.Processors) error {
if ancestry.Name == "" {
log.Warning("Empty ancestry name is not allowed")
return commonerr.NewBadRequestError("could not insert an ancestry with empty name")
}
if len(ancestry.Layers) == 0 {
log.Warning("Empty ancestry is not allowed")
return commonerr.NewBadRequestError("could not insert an ancestry with 0 layers")
}
err := tx.deleteAncestry(ancestry.Name)
if err != nil {
return err
}
var ancestryID int64
err = tx.QueryRow(insertAncestry, ancestry.Name).Scan(&ancestryID)
if err != nil {
if isErrUniqueViolation(err) {
return handleError("insertAncestry", errors.New("Other Go-routine is processing this ancestry (skip)."))
}
return handleError("insertAncestry", err)
}
err = tx.insertAncestryLayers(ancestryID, ancestry.Layers)
if err != nil {
return err
}
err = tx.insertAncestryFeatures(ancestryID, features)
if err != nil {
return err
}
return tx.persistProcessors(persistAncestryLister,
"persistAncestryLister",
persistAncestryDetector,
"persistAncestryDetector",
ancestryID, processedBy)
}
func (tx *pgSession) FindAncestry(name string) (database.Ancestry, database.Processors, bool, error) {
ancestry := database.Ancestry{Name: name}
processed := database.Processors{}
var ancestryID int64
err := tx.QueryRow(searchAncestry, name).Scan(&ancestryID)
if err != nil {
if err == sql.ErrNoRows {
return ancestry, processed, false, nil
}
return ancestry, processed, false, handleError("searchAncestry", err)
}
ancestry.Layers, err = tx.findAncestryLayers(ancestryID)
if err != nil {
return ancestry, processed, false, err
}
processed.Detectors, err = tx.findProcessors(searchAncestryDetectors, "searchAncestryDetectors", "detector", ancestryID)
if err != nil {
return ancestry, processed, false, err
}
processed.Listers, err = tx.findProcessors(searchAncestryListers, "searchAncestryListers", "lister", ancestryID)
if err != nil {
return ancestry, processed, false, err
}
return ancestry, processed, true, nil
}
func (tx *pgSession) FindAncestryFeatures(name string) (database.AncestryWithFeatures, bool, error) {
var (
awf database.AncestryWithFeatures
ok bool
err error
)
awf.Ancestry, awf.ProcessedBy, ok, err = tx.FindAncestry(name)
if err != nil {
return awf, false, err
}
if !ok {
return awf, false, nil
}
rows, err := tx.Query(searchAncestryFeatures, name)
if err != nil {
return awf, false, handleError("searchAncestryFeatures", err)
}
for rows.Next() {
nf := database.NamespacedFeature{}
err := rows.Scan(&nf.Namespace.Name, &nf.Namespace.VersionFormat, &nf.Feature.Name, &nf.Feature.Version)
if err != nil {
return awf, false, handleError("searchAncestryFeatures", err)
}
nf.Feature.VersionFormat = nf.Namespace.VersionFormat
awf.Features = append(awf.Features, nf)
}
return awf, true, nil
}
func (tx *pgSession) deleteAncestry(name string) error {
result, err := tx.Exec(removeAncestry, name)
if err != nil {
return handleError("removeAncestry", err)
}
_, err = result.RowsAffected()
if err != nil {
return handleError("removeAncestry", err)
}
return nil
}
func (tx *pgSession) findProcessors(query, queryName, processorType string, id int64) ([]string, error) {
rows, err := tx.Query(query, id)
if err != nil {
if err == sql.ErrNoRows {
log.Warning("No " + processorType + " are used")
return nil, nil
}
return nil, handleError(queryName, err)
}
var (
processors []string
processor string
)
for rows.Next() {
err := rows.Scan(&processor)
if err != nil {
return nil, handleError(queryName, err)
}
processors = append(processors, processor)
}
return processors, nil
}
func (tx *pgSession) findAncestryLayers(ancestryID int64) ([]database.Layer, error) {
rows, err := tx.Query(searchAncestryLayer, ancestryID)
if err != nil {
return nil, handleError("searchAncestryLayer", err)
}
layers := []database.Layer{}
for rows.Next() {
var layer database.Layer
err := rows.Scan(&layer.Hash)
if err != nil {
return nil, handleError("searchAncestryLayer", err)
}
layers = append(layers, layer)
}
return layers, nil
}
func (tx *pgSession) insertAncestryLayers(ancestryID int64, layers []database.Layer) error {
layerIDs := map[string]sql.NullInt64{}
for _, l := range layers {
layerIDs[l.Hash] = sql.NullInt64{}
}
layerHashes := []string{}
for hash := range layerIDs {
layerHashes = append(layerHashes, hash)
}
rows, err := tx.Query(searchLayerIDs, pq.Array(layerHashes))
if err != nil {
return handleError("searchLayerIDs", err)
}
for rows.Next() {
var (
layerID sql.NullInt64
layerName string
)
err := rows.Scan(&layerID, &layerName)
if err != nil {
return handleError("searchLayerIDs", err)
}
layerIDs[layerName] = layerID
}
notFound := []string{}
for hash, id := range layerIDs {
if !id.Valid {
notFound = append(notFound, hash)
}
}
if len(notFound) > 0 {
return handleError("searchLayerIDs", fmt.Errorf("Layer %s is not found in database", strings.Join(notFound, ",")))
}
//TODO(Sida): use bulk insert.
stmt, err := tx.Prepare(insertAncestryLayer)
if err != nil {
return handleError("insertAncestryLayer", err)
}
defer stmt.Close()
for index, layer := range layers {
_, err := stmt.Exec(ancestryID, index, layerIDs[layer.Hash].Int64)
if err != nil {
return handleError("insertAncestryLayer", commonerr.CombineErrors(err, stmt.Close()))
}
}
return nil
}
func (tx *pgSession) insertAncestryFeatures(ancestryID int64, features []database.NamespacedFeature) error {
featureIDs, err := tx.findNamespacedFeatureIDs(features)
if err != nil {
return err
}
//TODO(Sida): use bulk insert.
stmtFeatures, err := tx.Prepare(insertAncestryFeature)
if err != nil {
return handleError("insertAncestryFeature", err)
}
defer stmtFeatures.Close()
for _, id := range featureIDs {
if !id.Valid {
return errors.New("requested namespaced feature is not in database")
}
_, err := stmtFeatures.Exec(ancestryID, id)
if err != nil {
return handleError("insertAncestryFeature", err)
}
}
return nil
}

View File

@ -0,0 +1,207 @@
// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"sort"
"testing"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database"
)
func TestUpsertAncestry(t *testing.T) {
store, tx := openSessionForTest(t, "UpsertAncestry", true)
defer closeTest(t, store, tx)
a1 := database.Ancestry{
Name: "a1",
Layers: []database.Layer{
{Hash: "layer-N"},
},
}
a2 := database.Ancestry{}
a3 := database.Ancestry{
Name: "a",
Layers: []database.Layer{
{Hash: "layer-0"},
},
}
a4 := database.Ancestry{
Name: "a",
Layers: []database.Layer{
{Hash: "layer-1"},
},
}
f1 := database.Feature{
Name: "wechat",
Version: "0.5",
VersionFormat: "dpkg",
}
// not in database
f2 := database.Feature{
Name: "wechat",
Version: "0.6",
VersionFormat: "dpkg",
}
n1 := database.Namespace{
Name: "debian:7",
VersionFormat: "dpkg",
}
p := database.Processors{
Listers: []string{"dpkg", "non-existing"},
Detectors: []string{"os-release", "non-existing"},
}
nsf1 := database.NamespacedFeature{
Namespace: n1,
Feature: f1,
}
// not in database
nsf2 := database.NamespacedFeature{
Namespace: n1,
Feature: f2,
}
// invalid case
assert.NotNil(t, tx.UpsertAncestry(a1, nil, database.Processors{}))
assert.NotNil(t, tx.UpsertAncestry(a2, nil, database.Processors{}))
// valid case
assert.Nil(t, tx.UpsertAncestry(a3, nil, database.Processors{}))
// replace invalid case
assert.NotNil(t, tx.UpsertAncestry(a4, []database.NamespacedFeature{nsf1, nsf2}, p))
// replace valid case
assert.Nil(t, tx.UpsertAncestry(a4, []database.NamespacedFeature{nsf1}, p))
// validate
ancestry, ok, err := tx.FindAncestryFeatures("a")
assert.Nil(t, err)
assert.True(t, ok)
assert.Equal(t, a4, ancestry.Ancestry)
}
func assertProcessorsEqual(t *testing.T, expected database.Processors, actual database.Processors) bool {
sort.Strings(expected.Detectors)
sort.Strings(actual.Detectors)
sort.Strings(expected.Listers)
sort.Strings(actual.Listers)
return assert.Equal(t, expected.Detectors, actual.Detectors) && assert.Equal(t, expected.Listers, actual.Listers)
}
func TestFindAncestry(t *testing.T) {
store, tx := openSessionForTest(t, "FindAncestry", true)
defer closeTest(t, store, tx)
// not found
_, _, ok, err := tx.FindAncestry("ancestry-non")
assert.Nil(t, err)
assert.False(t, ok)
expected := database.Ancestry{
Name: "ancestry-1",
Layers: []database.Layer{
{Hash: "layer-0"},
{Hash: "layer-1"},
{Hash: "layer-2"},
{Hash: "layer-3a"},
},
}
expectedProcessors := database.Processors{
Detectors: []string{"os-release"},
Listers: []string{"dpkg"},
}
// found
a, p, ok2, err := tx.FindAncestry("ancestry-1")
if assert.Nil(t, err) && assert.True(t, ok2) {
assertAncestryEqual(t, expected, a)
assertProcessorsEqual(t, expectedProcessors, p)
}
}
func assertAncestryWithFeatureEqual(t *testing.T, expected database.AncestryWithFeatures, actual database.AncestryWithFeatures) bool {
return assertAncestryEqual(t, expected.Ancestry, actual.Ancestry) &&
assertNamespacedFeatureEqual(t, expected.Features, actual.Features) &&
assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy)
}
func assertAncestryEqual(t *testing.T, expected database.Ancestry, actual database.Ancestry) bool {
return assert.Equal(t, expected.Name, actual.Name) && assert.Equal(t, expected.Layers, actual.Layers)
}
func TestFindAncestryFeatures(t *testing.T) {
store, tx := openSessionForTest(t, "FindAncestryFeatures", true)
defer closeTest(t, store, tx)
// invalid
_, ok, err := tx.FindAncestryFeatures("ancestry-non")
if assert.Nil(t, err) {
assert.False(t, ok)
}
expected := database.AncestryWithFeatures{
Ancestry: database.Ancestry{
Name: "ancestry-2",
Layers: []database.Layer{
{Hash: "layer-0"},
{Hash: "layer-1"},
{Hash: "layer-2"},
{Hash: "layer-3b"},
},
},
ProcessedBy: database.Processors{
Detectors: []string{"os-release"},
Listers: []string{"dpkg"},
},
Features: []database.NamespacedFeature{
{
Namespace: database.Namespace{
Name: "debian:7",
VersionFormat: "dpkg",
},
Feature: database.Feature{
Name: "wechat",
Version: "0.5",
VersionFormat: "dpkg",
},
},
{
Namespace: database.Namespace{
Name: "debian:8",
VersionFormat: "dpkg",
},
Feature: database.Feature{
Name: "openssl",
Version: "1.0",
VersionFormat: "dpkg",
},
},
},
}
// valid
ancestry, ok, err := tx.FindAncestryFeatures("ancestry-2")
if assert.Nil(t, err) && assert.True(t, ok) {
assertAncestryEqual(t, expected.Ancestry, ancestry.Ancestry)
assertNamespacedFeatureEqual(t, expected.Features, ancestry.Features)
assertProcessorsEqual(t, expected.ProcessedBy, ancestry.ProcessedBy)
}
}

View File

@ -27,135 +27,200 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/ext/versionfmt"
"github.com/coreos/clair/ext/versionfmt/dpkg" "github.com/coreos/clair/ext/versionfmt/dpkg"
"github.com/coreos/clair/pkg/strutil"
) )
const ( const (
numVulnerabilities = 100 numVulnerabilities = 100
numFeatureVersions = 100 numFeatures = 100
) )
func TestRaceAffects(t *testing.T) { func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database.Datastore) ([]database.NamespacedFeature, []database.VulnerabilityWithAffected) {
datastore, err := openDatabaseForTest("RaceAffects", false) tx, err := store.Begin()
if err != nil { if !assert.Nil(t, err) {
t.Error(err) t.FailNow()
return
} }
defer datastore.Close()
// Insert the Feature on which we'll work. featureName := "TestFeature"
feature := database.Feature{ featureVersionFormat := dpkg.ParserName
Namespace: database.Namespace{ // Insert the namespace on which we'll work.
namespace := database.Namespace{
Name: "TestRaceAffectsFeatureNamespace1", Name: "TestRaceAffectsFeatureNamespace1",
VersionFormat: dpkg.ParserName, VersionFormat: dpkg.ParserName,
},
Name: "TestRaceAffecturesFeature1",
} }
_, err = datastore.insertFeature(feature)
if err != nil { if !assert.Nil(t, tx.PersistNamespaces([]database.Namespace{namespace})) {
t.Error(err) t.FailNow()
return
} }
// Initialize random generator and enforce max procs. // Initialize random generator and enforce max procs.
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
runtime.GOMAXPROCS(runtime.NumCPU()) runtime.GOMAXPROCS(runtime.NumCPU())
// Generate FeatureVersions. // Generate Distinct random features
featureVersions := make([]database.FeatureVersion, numFeatureVersions) features := make([]database.Feature, numFeatures)
for i := 0; i < numFeatureVersions; i++ { nsFeatures := make([]database.NamespacedFeature, numFeatures)
version := rand.Intn(numFeatureVersions) for i := 0; i < numFeatures; i++ {
version := rand.Intn(numFeatures)
featureVersions[i] = database.FeatureVersion{ features[i] = database.Feature{
Feature: feature, Name: featureName,
VersionFormat: featureVersionFormat,
Version: strconv.Itoa(version), Version: strconv.Itoa(version),
} }
nsFeatures[i] = database.NamespacedFeature{
Namespace: namespace,
Feature: features[i],
}
}
// insert features
if !assert.Nil(t, tx.PersistFeatures(features)) {
t.FailNow()
} }
// Generate vulnerabilities. // Generate vulnerabilities.
// They are mapped by fixed version, which will make verification really easy afterwards. vulnerabilities := []database.VulnerabilityWithAffected{}
vulnerabilities := make(map[int][]database.Vulnerability)
for i := 0; i < numVulnerabilities; i++ { for i := 0; i < numVulnerabilities; i++ {
version := rand.Intn(numFeatureVersions) + 1 // any version less than this is vulnerable
version := rand.Intn(numFeatures) + 1
// if _, ok := vulnerabilities[version]; !ok { vulnerability := database.VulnerabilityWithAffected{
// vulnerabilities[version] = make([]database.Vulnerability) Vulnerability: database.Vulnerability{
// }
vulnerability := database.Vulnerability{
Name: uuid.New(), Name: uuid.New(),
Namespace: feature.Namespace, Namespace: namespace,
FixedIn: []database.FeatureVersion{
{
Feature: feature,
Version: strconv.Itoa(version),
},
},
Severity: database.UnknownSeverity, Severity: database.UnknownSeverity,
},
Affected: []database.AffectedFeature{
{
Namespace: namespace,
FeatureName: featureName,
AffectedVersion: strconv.Itoa(version),
FixedInVersion: strconv.Itoa(version),
},
},
} }
vulnerabilities[version] = append(vulnerabilities[version], vulnerability) vulnerabilities = append(vulnerabilities, vulnerability)
} }
tx.Commit()
// Insert featureversions and vulnerabilities in parallel. return nsFeatures, vulnerabilities
}
func TestConcurrency(t *testing.T) {
store, err := openDatabaseForTest("Concurrency", false)
if !assert.Nil(t, err) {
t.FailNow()
}
defer store.Close()
start := time.Now()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(2) wg.Add(100)
for i := 0; i < 100; i++ {
go func() { go func() {
defer wg.Done() defer wg.Done()
for _, vulnerabilitiesM := range vulnerabilities { nsNamespaces := genRandomNamespaces(t, 100)
for _, vulnerability := range vulnerabilitiesM { tx, err := store.Begin()
err = datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true) if !assert.Nil(t, err) {
assert.Nil(t, err) t.FailNow()
}
assert.Nil(t, tx.PersistNamespaces(nsNamespaces))
tx.Commit()
}()
}
wg.Wait()
fmt.Println("total", time.Since(start))
}
func genRandomNamespaces(t *testing.T, count int) []database.Namespace {
r := make([]database.Namespace, count)
for i := 0; i < count; i++ {
r[i] = database.Namespace{
Name: uuid.New(),
VersionFormat: "dpkg",
} }
} }
fmt.Println("finished to insert vulnerabilities") return r
}
func TestCaching(t *testing.T) {
store, err := openDatabaseForTest("Caching", false)
if !assert.Nil(t, err) {
t.FailNow()
}
defer store.Close()
nsFeatures, vulnerabilities := testGenRandomVulnerabilityAndNamespacedFeature(t, store)
fmt.Printf("%d features, %d vulnerabilities are generated", len(nsFeatures), len(vulnerabilities))
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
tx, err := store.Begin()
if !assert.Nil(t, err) {
t.FailNow()
}
assert.Nil(t, tx.PersistNamespacedFeatures(nsFeatures))
fmt.Println("finished to insert namespaced features")
tx.Commit()
}() }()
go func() { go func() {
defer wg.Done() defer wg.Done()
for i := 0; i < len(featureVersions); i++ { tx, err := store.Begin()
featureVersions[i].ID, err = datastore.insertFeatureVersion(featureVersions[i]) if !assert.Nil(t, err) {
assert.Nil(t, err) t.FailNow()
} }
fmt.Println("finished to insert featureVersions")
assert.Nil(t, tx.InsertVulnerabilities(vulnerabilities))
fmt.Println("finished to insert vulnerabilities")
tx.Commit()
}() }()
wg.Wait() wg.Wait()
// Verify consistency now. tx, err := store.Begin()
var actualAffectedNames []string
var expectedAffectedNames []string
for _, featureVersion := range featureVersions {
featureVersionVersion, _ := strconv.Atoi(featureVersion.Version)
// Get actual affects.
rows, err := datastore.Query(searchComplexTestFeatureVersionAffects,
featureVersion.ID)
assert.Nil(t, err)
defer rows.Close()
var vulnName string
for rows.Next() {
err = rows.Scan(&vulnName)
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
continue t.FailNow()
} }
actualAffectedNames = append(actualAffectedNames, vulnName) defer tx.Rollback()
}
if assert.Nil(t, rows.Err()) { // Verify consistency now.
rows.Close() affected, err := tx.FindAffectedNamespacedFeatures(nsFeatures)
if !assert.Nil(t, err) {
t.FailNow()
} }
// Get expected affects. for _, ansf := range affected {
for i := numVulnerabilities; i > featureVersionVersion; i-- { if !assert.True(t, ansf.Valid) {
for _, vulnerability := range vulnerabilities[i] { t.FailNow()
expectedAffectedNames = append(expectedAffectedNames, vulnerability.Name) }
expectedAffectedNames := []string{}
for _, vuln := range vulnerabilities {
if ok, err := versionfmt.InRange(dpkg.ParserName, ansf.Version, vuln.Affected[0].AffectedVersion); err == nil {
if ok {
expectedAffectedNames = append(expectedAffectedNames, vuln.Name)
}
} }
} }
assert.Len(t, compareStringLists(expectedAffectedNames, actualAffectedNames), 0) actualAffectedNames := []string{}
assert.Len(t, compareStringLists(actualAffectedNames, expectedAffectedNames), 0) for _, s := range ansf.AffectedBy {
actualAffectedNames = append(actualAffectedNames, s.Name)
}
assert.Len(t, strutil.CompareStringLists(expectedAffectedNames, actualAffectedNames), 0)
assert.Len(t, strutil.CompareStringLists(actualAffectedNames, expectedAffectedNames), 0)
} }
} }

View File

@ -16,230 +16,366 @@ package pgsql
import ( import (
"database/sql" "database/sql"
"strings" "errors"
"time" "sort"
"github.com/lib/pq"
log "github.com/sirupsen/logrus"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/ext/versionfmt"
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
) )
func (pgSQL *pgSQL) insertFeature(feature database.Feature) (int, error) { var (
if feature.Name == "" { errFeatureNotFound = errors.New("Feature not found")
return 0, commonerr.NewBadRequestError("could not find/insert invalid Feature") )
}
// Do cache lookup. type vulnerabilityAffecting struct {
if pgSQL.cache != nil { vulnerabilityID int64
promCacheQueriesTotal.WithLabelValues("feature").Inc() addedByID int64
id, found := pgSQL.cache.Get("feature:" + feature.Namespace.Name + ":" + feature.Name)
if found {
promCacheHitsTotal.WithLabelValues("feature").Inc()
return id.(int), nil
}
}
// We do `defer observeQueryTime` here because we don't want to observe cached features.
defer observeQueryTime("insertFeature", "all", time.Now())
// Find or create Namespace.
namespaceID, err := pgSQL.insertNamespace(feature.Namespace)
if err != nil {
return 0, err
}
// Find or create Feature.
var id int
err = pgSQL.QueryRow(soiFeature, feature.Name, namespaceID).Scan(&id)
if err != nil {
return 0, handleError("soiFeature", err)
}
if pgSQL.cache != nil {
pgSQL.cache.Add("feature:"+feature.Namespace.Name+":"+feature.Name, id)
}
return id, nil
} }
func (pgSQL *pgSQL) insertFeatureVersion(fv database.FeatureVersion) (id int, err error) { func (tx *pgSession) PersistFeatures(features []database.Feature) error {
err = versionfmt.Valid(fv.Feature.Namespace.VersionFormat, fv.Version) if len(features) == 0 {
if err != nil { return nil
return 0, commonerr.NewBadRequestError("could not find/insert invalid FeatureVersion")
} }
// Do cache lookup. // Sorting is needed before inserting into database to prevent deadlock.
cacheIndex := strings.Join([]string{"featureversion", fv.Feature.Namespace.Name, fv.Feature.Name, fv.Version}, ":") sort.Slice(features, func(i, j int) bool {
if pgSQL.cache != nil { return features[i].Name < features[j].Name ||
promCacheQueriesTotal.WithLabelValues("featureversion").Inc() features[i].Version < features[j].Version ||
id, found := pgSQL.cache.Get(cacheIndex) features[i].VersionFormat < features[j].VersionFormat
if found { })
promCacheHitsTotal.WithLabelValues("featureversion").Inc()
return id.(int), nil // TODO(Sida): A better interface for bulk insertion is needed.
keys := make([]interface{}, len(features)*3)
for i, f := range features {
keys[i*3] = f.Name
keys[i*3+1] = f.Version
keys[i*3+2] = f.VersionFormat
if f.Name == "" || f.Version == "" || f.VersionFormat == "" {
return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed")
} }
} }
// We do `defer observeQueryTime` here because we don't want to observe cached featureversions. _, err := tx.Exec(queryPersistFeature(len(features)), keys...)
defer observeQueryTime("insertFeatureVersion", "all", time.Now()) return handleError("queryPersistFeature", err)
// Find or create Feature first.
t := time.Now()
featureID, err := pgSQL.insertFeature(fv.Feature)
observeQueryTime("insertFeatureVersion", "insertFeature", t)
if err != nil {
return 0, err
}
fv.Feature.ID = featureID
// Try to find the FeatureVersion.
//
// In a populated database, the likelihood of the FeatureVersion already being there is high.
// If we can find it here, we then avoid using a transaction and locking the database.
err = pgSQL.QueryRow(searchFeatureVersion, featureID, fv.Version).Scan(&fv.ID)
if err != nil && err != sql.ErrNoRows {
return 0, handleError("searchFeatureVersion", err)
}
if err == nil {
if pgSQL.cache != nil {
pgSQL.cache.Add(cacheIndex, fv.ID)
}
return fv.ID, nil
}
// Begin transaction.
tx, err := pgSQL.Begin()
if err != nil {
tx.Rollback()
return 0, handleError("insertFeatureVersion.Begin()", err)
}
// Lock Vulnerability_Affects_FeatureVersion exclusively.
// We want to prevent InsertVulnerability to modify it.
promConcurrentLockVAFV.Inc()
defer promConcurrentLockVAFV.Dec()
t = time.Now()
_, err = tx.Exec(lockVulnerabilityAffects)
observeQueryTime("insertFeatureVersion", "lock", t)
if err != nil {
tx.Rollback()
return 0, handleError("insertFeatureVersion.lockVulnerabilityAffects", err)
}
// Find or create FeatureVersion.
var created bool
t = time.Now()
err = tx.QueryRow(soiFeatureVersion, featureID, fv.Version).Scan(&created, &fv.ID)
observeQueryTime("insertFeatureVersion", "soiFeatureVersion", t)
if err != nil {
tx.Rollback()
return 0, handleError("soiFeatureVersion", err)
}
if !created {
// The featureVersion already existed, no need to link it to
// vulnerabilities.
tx.Commit()
if pgSQL.cache != nil {
pgSQL.cache.Add(cacheIndex, fv.ID)
}
return fv.ID, nil
}
// Link the new FeatureVersion with every vulnerabilities that affect it, by inserting in
// Vulnerability_Affects_FeatureVersion.
t = time.Now()
err = linkFeatureVersionToVulnerabilities(tx, fv)
observeQueryTime("insertFeatureVersion", "linkFeatureVersionToVulnerabilities", t)
if err != nil {
tx.Rollback()
return 0, err
}
// Commit transaction.
err = tx.Commit()
if err != nil {
return 0, handleError("insertFeatureVersion.Commit()", err)
}
if pgSQL.cache != nil {
pgSQL.cache.Add(cacheIndex, fv.ID)
}
return fv.ID, nil
} }
// TODO(Quentin-M): Batch me type namespacedFeatureWithID struct {
func (pgSQL *pgSQL) insertFeatureVersions(featureVersions []database.FeatureVersion) ([]int, error) { database.NamespacedFeature
IDs := make([]int, 0, len(featureVersions))
for i := 0; i < len(featureVersions); i++ { ID int64
id, err := pgSQL.insertFeatureVersion(featureVersions[i])
if err != nil {
return IDs, err
}
IDs = append(IDs, id)
}
return IDs, nil
} }
type vulnerabilityAffectsFeatureVersion struct { type vulnerabilityCache struct {
vulnerabilityID int nsFeatureID int64
fixedInID int vulnID int64
fixedInVersion string vulnAffectingID int64
} }
func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.FeatureVersion) error { func (tx *pgSession) searchAffectingVulnerabilities(features []database.NamespacedFeature) ([]vulnerabilityCache, error) {
// Select every vulnerability and the fixed version that affect this Feature. if len(features) == 0 {
// TODO(Quentin-M): LIMIT return nil, nil
rows, err := tx.Query(searchVulnerabilityFixedInFeature, featureVersion.Feature.ID)
if err != nil {
return handleError("searchVulnerabilityFixedInFeature", err)
} }
ids, err := tx.findNamespacedFeatureIDs(features)
if err != nil {
return nil, err
}
fMap := map[int64]database.NamespacedFeature{}
for i, f := range features {
if !ids[i].Valid {
return nil, errFeatureNotFound
}
fMap[ids[i].Int64] = f
}
cacheTable := []vulnerabilityCache{}
rows, err := tx.Query(searchPotentialAffectingVulneraibilities, pq.Array(ids))
if err != nil {
return nil, handleError("searchPotentialAffectingVulneraibilities", err)
}
defer rows.Close() defer rows.Close()
var affects []vulnerabilityAffectsFeatureVersion
for rows.Next() { for rows.Next() {
var affect vulnerabilityAffectsFeatureVersion var (
cache vulnerabilityCache
affected string
)
err := rows.Scan(&affect.fixedInID, &affect.vulnerabilityID, &affect.fixedInVersion) err := rows.Scan(&cache.nsFeatureID, &cache.vulnID, &affected, &cache.vulnAffectingID)
if err != nil { if err != nil {
return handleError("searchVulnerabilityFixedInFeature.Scan()", err) return nil, err
} }
cmp, err := versionfmt.Compare(featureVersion.Feature.Namespace.VersionFormat, featureVersion.Version, affect.fixedInVersion) if ok, err := versionfmt.InRange(fMap[cache.nsFeatureID].VersionFormat, fMap[cache.nsFeatureID].Version, affected); err != nil {
return nil, err
} else if ok {
cacheTable = append(cacheTable, cache)
}
}
return cacheTable, nil
}
func (tx *pgSession) CacheAffectedNamespacedFeatures(features []database.NamespacedFeature) error {
if len(features) == 0 {
return nil
}
_, err := tx.Exec(lockVulnerabilityAffects)
if err != nil { if err != nil {
return handleError("lockVulnerabilityAffects", err)
}
cache, err := tx.searchAffectingVulnerabilities(features)
keys := make([]interface{}, len(cache)*3)
for i, c := range cache {
keys[i*3] = c.vulnID
keys[i*3+1] = c.nsFeatureID
keys[i*3+2] = c.vulnAffectingID
}
if len(cache) == 0 {
return nil
}
affected, err := tx.Exec(queryPersistVulnerabilityAffectedNamespacedFeature(len(cache)), keys...)
if err != nil {
return handleError("persistVulnerabilityAffectedNamespacedFeature", err)
}
if count, err := affected.RowsAffected(); err != nil {
log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", count)
}
return nil
}
func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFeature) error {
if len(features) == 0 {
return nil
}
nsIDs := map[database.Namespace]sql.NullInt64{}
fIDs := map[database.Feature]sql.NullInt64{}
for _, f := range features {
nsIDs[f.Namespace] = sql.NullInt64{}
fIDs[f.Feature] = sql.NullInt64{}
}
fToFind := []database.Feature{}
for f := range fIDs {
fToFind = append(fToFind, f)
}
sort.Slice(fToFind, func(i, j int) bool {
return fToFind[i].Name < fToFind[j].Name ||
fToFind[i].Version < fToFind[j].Version ||
fToFind[i].VersionFormat < fToFind[j].VersionFormat
})
if ids, err := tx.findFeatureIDs(fToFind); err == nil {
for i, id := range ids {
if !id.Valid {
return errFeatureNotFound
}
fIDs[fToFind[i]] = id
}
} else {
return err return err
} }
if cmp < 0 {
// The version of the FeatureVersion we are inserting is lower than the fixed version on this
// Vulnerability, thus, this FeatureVersion is affected by it.
affects = append(affects, affect)
}
}
if err = rows.Err(); err != nil {
return handleError("searchVulnerabilityFixedInFeature.Rows()", err)
}
rows.Close()
// Insert into Vulnerability_Affects_FeatureVersion. nsToFind := []database.Namespace{}
for _, affect := range affects { for ns := range nsIDs {
// TODO(Quentin-M): Batch me. nsToFind = append(nsToFind, ns)
_, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, affect.vulnerabilityID,
featureVersion.ID, affect.fixedInID)
if err != nil {
return handleError("insertVulnerabilityAffectsFeatureVersion", err)
} }
if ids, err := tx.findNamespaceIDs(nsToFind); err == nil {
for i, id := range ids {
if !id.Valid {
return errNamespaceNotFound
}
nsIDs[nsToFind[i]] = id
}
} else {
return err
}
keys := make([]interface{}, len(features)*2)
for i, f := range features {
keys[i*2] = fIDs[f.Feature]
keys[i*2+1] = nsIDs[f.Namespace]
}
_, err := tx.Exec(queryPersistNamespacedFeature(len(features)), keys...)
if err != nil {
return err
} }
return nil return nil
} }
// FindAffectedNamespacedFeatures looks up cache table and retrieves all
// vulnerabilities associated with the features.
func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) {
if len(features) == 0 {
return nil, nil
}
returnFeatures := make([]database.NullableAffectedNamespacedFeature, len(features))
// featureMap is used to keep track of duplicated features.
featureMap := map[database.NamespacedFeature][]*database.NullableAffectedNamespacedFeature{}
// initialize return value and generate unique feature request queries.
for i, f := range features {
returnFeatures[i] = database.NullableAffectedNamespacedFeature{
AffectedNamespacedFeature: database.AffectedNamespacedFeature{
NamespacedFeature: f,
},
}
featureMap[f] = append(featureMap[f], &returnFeatures[i])
}
// query unique namespaced features
distinctFeatures := []database.NamespacedFeature{}
for f := range featureMap {
distinctFeatures = append(distinctFeatures, f)
}
nsFeatureIDs, err := tx.findNamespacedFeatureIDs(distinctFeatures)
if err != nil {
return nil, err
}
toQuery := []int64{}
featureIDMap := map[int64][]*database.NullableAffectedNamespacedFeature{}
for i, id := range nsFeatureIDs {
if id.Valid {
toQuery = append(toQuery, id.Int64)
for _, f := range featureMap[distinctFeatures[i]] {
f.Valid = id.Valid
featureIDMap[id.Int64] = append(featureIDMap[id.Int64], f)
}
}
}
rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(toQuery))
if err != nil {
return nil, handleError("searchNamespacedFeaturesVulnerabilities", err)
}
defer rows.Close()
for rows.Next() {
var (
featureID int64
vuln database.VulnerabilityWithFixedIn
)
err := rows.Scan(&featureID,
&vuln.Name,
&vuln.Description,
&vuln.Link,
&vuln.Severity,
&vuln.Metadata,
&vuln.FixedInVersion,
&vuln.Namespace.Name,
&vuln.Namespace.VersionFormat,
)
if err != nil {
return nil, handleError("searchNamespacedFeaturesVulnerabilities", err)
}
for _, f := range featureIDMap[featureID] {
f.AffectedBy = append(f.AffectedBy, vuln)
}
}
return returnFeatures, nil
}
func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) ([]sql.NullInt64, error) {
if len(nfs) == 0 {
return nil, nil
}
nfsMap := map[database.NamespacedFeature]sql.NullInt64{}
keys := make([]interface{}, len(nfs)*4)
for i, nf := range nfs {
keys[i*4] = nfs[i].Name
keys[i*4+1] = nfs[i].Version
keys[i*4+2] = nfs[i].VersionFormat
keys[i*4+3] = nfs[i].Namespace.Name
nfsMap[nf] = sql.NullInt64{}
}
rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...)
if err != nil {
return nil, handleError("searchNamespacedFeature", err)
}
defer rows.Close()
var (
id sql.NullInt64
nf database.NamespacedFeature
)
for rows.Next() {
err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Namespace.Name)
nf.Namespace.VersionFormat = nf.VersionFormat
if err != nil {
return nil, handleError("searchNamespacedFeature", err)
}
nfsMap[nf] = id
}
ids := make([]sql.NullInt64, len(nfs))
for i, nf := range nfs {
ids[i] = nfsMap[nf]
}
return ids, nil
}
func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, error) {
if len(fs) == 0 {
return nil, nil
}
fMap := map[database.Feature]sql.NullInt64{}
keys := make([]interface{}, len(fs)*3)
for i, f := range fs {
keys[i*3] = f.Name
keys[i*3+1] = f.Version
keys[i*3+2] = f.VersionFormat
fMap[f] = sql.NullInt64{}
}
rows, err := tx.Query(querySearchFeatureID(len(fs)), keys...)
if err != nil {
return nil, handleError("querySearchFeatureID", err)
}
defer rows.Close()
var (
id sql.NullInt64
f database.Feature
)
for rows.Next() {
err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat)
if err != nil {
return nil, handleError("querySearchFeatureID", err)
}
fMap[f] = id
}
ids := make([]sql.NullInt64, len(fs))
for i, f := range fs {
ids[i] = fMap[f]
}
return ids, nil
}

View File

@ -20,96 +20,237 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/ext/versionfmt/dpkg"
// register dpkg feature lister for testing
_ "github.com/coreos/clair/ext/featurefmt/dpkg"
) )
func TestInsertFeature(t *testing.T) { func TestPersistFeatures(t *testing.T) {
datastore, err := openDatabaseForTest("InsertFeature", false) datastore, tx := openSessionForTest(t, "PersistFeatures", false)
defer closeTest(t, datastore, tx)
f1 := database.Feature{}
f2 := database.Feature{Name: "n", Version: "v", VersionFormat: "vf"}
// empty
assert.Nil(t, tx.PersistFeatures([]database.Feature{}))
// invalid
assert.NotNil(t, tx.PersistFeatures([]database.Feature{f1}))
// duplicated
assert.Nil(t, tx.PersistFeatures([]database.Feature{f2, f2}))
// existing
assert.Nil(t, tx.PersistFeatures([]database.Feature{f2}))
fs := listFeatures(t, tx)
assert.Len(t, fs, 1)
assert.Equal(t, f2, fs[0])
}
func TestPersistNamespacedFeatures(t *testing.T) {
datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true)
defer closeTest(t, datastore, tx)
// existing features
f1 := database.Feature{
Name: "wechat",
Version: "0.5",
VersionFormat: "dpkg",
}
// non-existing features
f2 := database.Feature{
Name: "fake!",
}
f3 := database.Feature{
Name: "openssl",
Version: "2.0",
VersionFormat: "dpkg",
}
// exising namespace
n1 := database.Namespace{
Name: "debian:7",
VersionFormat: "dpkg",
}
n3 := database.Namespace{
Name: "debian:8",
VersionFormat: "dpkg",
}
// non-existing namespace
n2 := database.Namespace{
Name: "debian:non",
VersionFormat: "dpkg",
}
// existing namespaced feature
nf1 := database.NamespacedFeature{
Namespace: n1,
Feature: f1,
}
// invalid namespaced feature
nf2 := database.NamespacedFeature{
Namespace: n2,
Feature: f2,
}
// new namespaced feature affected by vulnerability
nf3 := database.NamespacedFeature{
Namespace: n3,
Feature: f3,
}
// namespaced features with namespaces or features not in the database will
// generate error.
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{}))
assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf2}))
// valid case: insert nf3
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf3}))
all := listNamespacedFeatures(t, tx)
assert.Contains(t, all, nf1)
assert.Contains(t, all, nf3)
}
func TestVulnerableFeature(t *testing.T) {
datastore, tx := openSessionForTest(t, "VulnerableFeature", true)
defer closeTest(t, datastore, tx)
f1 := database.Feature{
Name: "openssl",
Version: "1.3",
VersionFormat: "dpkg",
}
n1 := database.Namespace{
Name: "debian:7",
VersionFormat: "dpkg",
}
nf1 := database.NamespacedFeature{
Namespace: n1,
Feature: f1,
}
assert.Nil(t, tx.PersistFeatures([]database.Feature{f1}))
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1}))
assert.Nil(t, tx.CacheAffectedNamespacedFeatures([]database.NamespacedFeature{nf1}))
// ensure the namespaced feature is affected correctly
anf, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{nf1})
if assert.Nil(t, err) &&
assert.Len(t, anf, 1) &&
assert.True(t, anf[0].Valid) &&
assert.Len(t, anf[0].AffectedBy, 1) {
assert.Equal(t, "CVE-OPENSSL-1-DEB7", anf[0].AffectedBy[0].Name)
}
}
func TestFindAffectedNamespacedFeatures(t *testing.T) {
datastore, tx := openSessionForTest(t, "FindAffectedNamespacedFeatures", true)
defer closeTest(t, datastore, tx)
ns := database.NamespacedFeature{
Feature: database.Feature{
Name: "openssl",
Version: "1.0",
VersionFormat: "dpkg",
},
Namespace: database.Namespace{
Name: "debian:7",
VersionFormat: "dpkg",
},
}
ans, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{ns})
if assert.Nil(t, err) &&
assert.Len(t, ans, 1) &&
assert.True(t, ans[0].Valid) &&
assert.Len(t, ans[0].AffectedBy, 1) {
assert.Equal(t, "CVE-OPENSSL-1-DEB7", ans[0].AffectedBy[0].Name)
}
}
func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFeature {
rows, err := tx.Query(`SELECT f.name, f.version, f.version_format, n.name, n.version_format
FROM feature AS f, namespace AS n, namespaced_feature AS nf
WHERE nf.feature_id = f.id AND nf.namespace_id = n.id`)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
return t.FailNow()
}
defer datastore.Close()
// Invalid Feature.
id0, err := datastore.insertFeature(database.Feature{})
assert.NotNil(t, err)
assert.Zero(t, id0)
id0, err = datastore.insertFeature(database.Feature{
Namespace: database.Namespace{},
Name: "TestInsertFeature0",
})
assert.NotNil(t, err)
assert.Zero(t, id0)
// Insert Feature and ensure we can find it.
feature := database.Feature{
Namespace: database.Namespace{
Name: "TestInsertFeatureNamespace1",
VersionFormat: dpkg.ParserName,
},
Name: "TestInsertFeature1",
}
id1, err := datastore.insertFeature(feature)
assert.Nil(t, err)
id2, err := datastore.insertFeature(feature)
assert.Nil(t, err)
assert.Equal(t, id1, id2)
// Insert invalid FeatureVersion.
for _, invalidFeatureVersion := range []database.FeatureVersion{
{
Feature: database.Feature{},
Version: "1.0",
},
{
Feature: database.Feature{
Namespace: database.Namespace{},
Name: "TestInsertFeature2",
},
Version: "1.0",
},
{
Feature: database.Feature{
Namespace: database.Namespace{
Name: "TestInsertFeatureNamespace2",
VersionFormat: dpkg.ParserName,
},
Name: "TestInsertFeature2",
},
Version: "",
},
{
Feature: database.Feature{
Namespace: database.Namespace{
Name: "TestInsertFeatureNamespace2",
VersionFormat: dpkg.ParserName,
},
Name: "TestInsertFeature2",
},
Version: "bad version",
},
} {
id3, err := datastore.insertFeatureVersion(invalidFeatureVersion)
assert.Error(t, err)
assert.Zero(t, id3)
} }
// Insert FeatureVersion and ensure we can find it. nf := []database.NamespacedFeature{}
featureVersion := database.FeatureVersion{ for rows.Next() {
Feature: database.Feature{ f := database.NamespacedFeature{}
Namespace: database.Namespace{ err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &f.Namespace.Name, &f.Namespace.VersionFormat)
Name: "TestInsertFeatureNamespace1", if err != nil {
VersionFormat: dpkg.ParserName, t.Error(err)
}, t.FailNow()
Name: "TestInsertFeature1",
},
Version: "2:3.0-imba",
} }
id4, err := datastore.insertFeatureVersion(featureVersion) nf = append(nf, f)
assert.Nil(t, err) }
id5, err := datastore.insertFeatureVersion(featureVersion)
assert.Nil(t, err) return nf
assert.Equal(t, id4, id5) }
func listFeatures(t *testing.T, tx *pgSession) []database.Feature {
rows, err := tx.Query("SELECT name, version, version_format FROM feature")
if err != nil {
t.FailNow()
}
fs := []database.Feature{}
for rows.Next() {
f := database.Feature{}
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat)
if err != nil {
t.FailNow()
}
fs = append(fs, f)
}
return fs
}
func assertFeaturesEqual(t *testing.T, expected []database.Feature, actual []database.Feature) bool {
if assert.Len(t, actual, len(expected)) {
has := map[database.Feature]bool{}
for _, nf := range expected {
has[nf] = false
}
for _, nf := range actual {
has[nf] = true
}
for nf, visited := range has {
if !assert.True(t, visited, nf.Name+" is expected") {
return false
}
return true
}
}
return false
}
func assertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool {
if assert.Len(t, actual, len(expected)) {
has := map[database.NamespacedFeature]bool{}
for _, nf := range expected {
has[nf] = false
}
for _, nf := range actual {
has[nf] = true
}
for nf, visited := range has {
if !assert.True(t, visited, nf.Namespace.Name+":"+nf.Name+" is expected") {
return false
}
}
return true
}
return false
} }

View File

@ -23,63 +23,35 @@ import (
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
) )
// InsertKeyValue stores (or updates) a single key / value tuple. func (tx *pgSession) UpdateKeyValue(key, value string) (err error) {
func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) {
if key == "" || value == "" { if key == "" || value == "" {
log.Warning("could not insert a flag which has an empty name or value") log.Warning("could not insert a flag which has an empty name or value")
return commonerr.NewBadRequestError("could not insert a flag which has an empty name or value") return commonerr.NewBadRequestError("could not insert a flag which has an empty name or value")
} }
defer observeQueryTime("InsertKeyValue", "all", time.Now()) defer observeQueryTime("PersistKeyValue", "all", time.Now())
// Upsert. _, err = tx.Exec(upsertKeyValue, key, value)
//
// Note: UPSERT works only on >= PostgreSQL 9.5 which is not yet supported by AWS RDS.
// The best solution is currently the use of http://dba.stackexchange.com/a/13477
// but the key/value storage doesn't need to be super-efficient and super-safe at the
// moment so we can just use a client-side solution with transactions, based on
// http://postgresql.org/docs/current/static/plpgsql-control-structures.html.
// TODO(Quentin-M): Enable Upsert as soon as 9.5 is stable.
for {
// First, try to update.
r, err := pgSQL.Exec(updateKeyValue, value, key)
if err != nil { if err != nil {
return handleError("updateKeyValue", err)
}
if n, _ := r.RowsAffected(); n > 0 {
// Updated successfully.
return nil
}
// Try to insert the key.
// If someone else inserts the same key concurrently, we could get a unique-key violation error.
_, err = pgSQL.Exec(insertKeyValue, key, value)
if err != nil {
if isErrUniqueViolation(err) {
// Got unique constraint violation, retry.
continue
}
return handleError("insertKeyValue", err) return handleError("insertKeyValue", err)
} }
return nil return nil
}
} }
// GetValue reads a single key / value tuple and returns an empty string if the key doesn't exist. func (tx *pgSession) FindKeyValue(key string) (string, bool, error) {
func (pgSQL *pgSQL) GetKeyValue(key string) (string, error) { defer observeQueryTime("FindKeyValue", "all", time.Now())
defer observeQueryTime("GetKeyValue", "all", time.Now())
var value string var value string
err := pgSQL.QueryRow(searchKeyValue, key).Scan(&value) err := tx.QueryRow(searchKeyValue, key).Scan(&value)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return "", nil return "", false, nil
}
if err != nil {
return "", handleError("searchKeyValue", err)
} }
return value, nil if err != nil {
return "", false, handleError("searchKeyValue", err)
}
return value, true, nil
} }

View File

@ -21,32 +21,30 @@ import (
) )
func TestKeyValue(t *testing.T) { func TestKeyValue(t *testing.T) {
datastore, err := openDatabaseForTest("KeyValue", false) datastore, tx := openSessionForTest(t, "KeyValue", true)
if err != nil { defer closeTest(t, datastore, tx)
t.Error(err)
return
}
defer datastore.Close()
// Get non-existing key/value // Get non-existing key/value
f, err := datastore.GetKeyValue("test") f, ok, err := tx.FindKeyValue("test")
assert.Nil(t, err) assert.Nil(t, err)
assert.Empty(t, "", f) assert.False(t, ok)
// Try to insert invalid key/value. // Try to insert invalid key/value.
assert.Error(t, datastore.InsertKeyValue("test", "")) assert.Error(t, tx.UpdateKeyValue("test", ""))
assert.Error(t, datastore.InsertKeyValue("", "test")) assert.Error(t, tx.UpdateKeyValue("", "test"))
assert.Error(t, datastore.InsertKeyValue("", "")) assert.Error(t, tx.UpdateKeyValue("", ""))
// Insert and verify. // Insert and verify.
assert.Nil(t, datastore.InsertKeyValue("test", "test1")) assert.Nil(t, tx.UpdateKeyValue("test", "test1"))
f, err = datastore.GetKeyValue("test") f, ok, err = tx.FindKeyValue("test")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok)
assert.Equal(t, "test1", f) assert.Equal(t, "test1", f)
// Update and verify. // Update and verify.
assert.Nil(t, datastore.InsertKeyValue("test", "test2")) assert.Nil(t, tx.UpdateKeyValue("test", "test2"))
f, err = datastore.GetKeyValue("test") f, ok, err = tx.FindKeyValue("test")
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, ok)
assert.Equal(t, "test2", f) assert.Equal(t, "test2", f)
} }

View File

@ -16,464 +16,293 @@ package pgsql
import ( import (
"database/sql" "database/sql"
"strings" "sort"
"time"
"github.com/guregu/null/zero"
log "github.com/sirupsen/logrus"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
) )
func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities bool) (database.Layer, error) { func (tx *pgSession) FindLayer(hash string) (database.Layer, database.Processors, bool, error) {
subquery := "all" l, p, _, ok, err := tx.findLayer(hash)
if withFeatures { return l, p, ok, err
subquery += "/features" }
} else if withVulnerabilities {
subquery += "/features+vulnerabilities"
}
defer observeQueryTime("FindLayer", subquery, time.Now())
// Find the layer func (tx *pgSession) FindLayerWithContent(hash string) (database.LayerWithContent, bool, error) {
var ( var (
layer database.Layer layer database.LayerWithContent
parentID zero.Int layerID int64
parentName zero.String ok bool
nsID zero.Int err error
nsName sql.NullString
nsVersionFormat sql.NullString
) )
t := time.Now() layer.Layer, layer.ProcessedBy, layerID, ok, err = tx.findLayer(hash)
err := pgSQL.QueryRow(searchLayer, name).Scan(
&layer.ID,
&layer.Name,
&layer.EngineVersion,
&parentID,
&parentName,
)
observeQueryTime("FindLayer", "searchLayer", t)
if err != nil { if err != nil {
return layer, handleError("searchLayer", err) return layer, false, err
} }
if !parentID.IsZero() { if !ok {
layer.Parent = &database.Layer{ return layer, false, nil
Model: database.Model{ID: int(parentID.Int64)},
Name: parentName.String,
}
} }
rows, err := pgSQL.Query(searchLayerNamespace, layer.ID) layer.Features, err = tx.findLayerFeatures(layerID)
defer rows.Close() layer.Namespaces, err = tx.findLayerNamespaces(layerID)
return layer, true, nil
}
func (tx *pgSession) PersistLayer(layer database.Layer) error {
if layer.Hash == "" {
return commonerr.NewBadRequestError("Empty Layer Hash is not allowed")
}
_, err := tx.Exec(queryPersistLayer(1), layer.Hash)
if err != nil { if err != nil {
return layer, handleError("searchLayerNamespace", err) return handleError("queryPersistLayer", err)
} }
return nil
}
// PersistLayerContent relates layer identified by hash with namespaces,
// features and processors provided. If the layer, namespaces, features are not
// in database, the function returns an error.
func (tx *pgSession) PersistLayerContent(hash string, namespaces []database.Namespace, features []database.Feature, processedBy database.Processors) error {
if hash == "" {
return commonerr.NewBadRequestError("Empty layer hash is not allowed")
}
var layerID int64
err := tx.QueryRow(searchLayer, hash).Scan(&layerID)
if err != nil {
return err
}
if err = tx.persistLayerNamespace(layerID, namespaces); err != nil {
return err
}
if err = tx.persistLayerFeatures(layerID, features); err != nil {
return err
}
if err = tx.persistLayerDetectors(layerID, processedBy.Detectors); err != nil {
return err
}
if err = tx.persistLayerListers(layerID, processedBy.Listers); err != nil {
return err
}
return nil
}
func (tx *pgSession) persistLayerDetectors(id int64, detectors []string) error {
if len(detectors) == 0 {
return nil
}
// Sorting is needed before inserting into database to prevent deadlock.
sort.Strings(detectors)
keys := make([]interface{}, len(detectors)*2)
for i, d := range detectors {
keys[i*2] = id
keys[i*2+1] = d
}
_, err := tx.Exec(queryPersistLayerDetectors(len(detectors)), keys...)
if err != nil {
return handleError("queryPersistLayerDetectors", err)
}
return nil
}
func (tx *pgSession) persistLayerListers(id int64, listers []string) error {
if len(listers) == 0 {
return nil
}
sort.Strings(listers)
keys := make([]interface{}, len(listers)*2)
for i, d := range listers {
keys[i*2] = id
keys[i*2+1] = d
}
_, err := tx.Exec(queryPersistLayerListers(len(listers)), keys...)
if err != nil {
return handleError("queryPersistLayerDetectors", err)
}
return nil
}
func (tx *pgSession) persistLayerFeatures(id int64, features []database.Feature) error {
if len(features) == 0 {
return nil
}
fIDs, err := tx.findFeatureIDs(features)
if err != nil {
return err
}
ids := make([]int, len(fIDs))
for i, fID := range fIDs {
if !fID.Valid {
return errNamespaceNotFound
}
ids[i] = int(fID.Int64)
}
sort.IntSlice(ids).Sort()
keys := make([]interface{}, len(features)*2)
for i, fID := range ids {
keys[i*2] = id
keys[i*2+1] = fID
}
_, err = tx.Exec(queryPersistLayerFeature(len(features)), keys...)
if err != nil {
return handleError("queryPersistLayerFeature", err)
}
return nil
}
func (tx *pgSession) persistLayerNamespace(id int64, namespaces []database.Namespace) error {
if len(namespaces) == 0 {
return nil
}
nsIDs, err := tx.findNamespaceIDs(namespaces)
if err != nil {
return err
}
// for every bulk persist operation, the input data should be sorted.
ids := make([]int, len(nsIDs))
for i, nsID := range nsIDs {
if !nsID.Valid {
panic(errNamespaceNotFound)
}
ids[i] = int(nsID.Int64)
}
sort.IntSlice(ids).Sort()
keys := make([]interface{}, len(namespaces)*2)
for i, nsID := range ids {
keys[i*2] = id
keys[i*2+1] = nsID
}
_, err = tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...)
if err != nil {
return handleError("queryPersistLayerNamespace", err)
}
return nil
}
func (tx *pgSession) persistProcessors(listerQuery, listerQueryName, detectorQuery, detectorQueryName string, id int64, processors database.Processors) error {
stmt, err := tx.Prepare(listerQuery)
if err != nil {
return handleError(listerQueryName, err)
}
for _, l := range processors.Listers {
_, err := stmt.Exec(id, l)
if err != nil {
stmt.Close()
return handleError(listerQueryName, err)
}
}
if err := stmt.Close(); err != nil {
return handleError(listerQueryName, err)
}
stmt, err = tx.Prepare(detectorQuery)
if err != nil {
return handleError(detectorQueryName, err)
}
for _, d := range processors.Detectors {
_, err := stmt.Exec(id, d)
if err != nil {
stmt.Close()
return handleError(detectorQueryName, err)
}
}
if err := stmt.Close(); err != nil {
return handleError(detectorQueryName, err)
}
return nil
}
func (tx *pgSession) findLayerNamespaces(layerID int64) ([]database.Namespace, error) {
var namespaces []database.Namespace
rows, err := tx.Query(searchLayerNamespaces, layerID)
if err != nil {
return nil, handleError("searchLayerFeatures", err)
}
for rows.Next() { for rows.Next() {
err = rows.Scan(&nsID, &nsName, &nsVersionFormat) ns := database.Namespace{}
err := rows.Scan(&ns.Name, &ns.VersionFormat)
if err != nil { if err != nil {
return layer, handleError("searchLayerNamespace", err) return nil, err
} }
if !nsID.IsZero() { namespaces = append(namespaces, ns)
layer.Namespaces = append(layer.Namespaces, database.Namespace{
Model: database.Model{ID: int(nsID.Int64)},
Name: nsName.String,
VersionFormat: nsVersionFormat.String,
})
} }
} return namespaces, nil
// Find its features
if withFeatures || withVulnerabilities {
// Create a transaction to disable hash/merge joins as our experiments have shown that
// PostgreSQL 9.4 makes bad planning decisions about:
// - joining the layer tree to feature versions and feature
// - joining the feature versions to affected/fixed feature version and vulnerabilities
// It would for instance do a merge join between affected feature versions (300 rows, estimated
// 3000 rows) and fixed in feature version (100k rows). In this case, it is much more
// preferred to use a nested loop.
tx, err := pgSQL.Begin()
if err != nil {
return layer, handleError("FindLayer.Begin()", err)
}
defer tx.Commit()
_, err = tx.Exec(disableHashJoin)
if err != nil {
log.WithError(err).Warningf("FindLayer: could not disable hash join")
}
_, err = tx.Exec(disableMergeJoin)
if err != nil {
log.WithError(err).Warningf("FindLayer: could not disable merge join")
}
t = time.Now()
featureVersions, err := getLayerFeatureVersions(tx, layer.ID)
observeQueryTime("FindLayer", "getLayerFeatureVersions", t)
if err != nil {
return layer, err
}
layer.Features = featureVersions
if withVulnerabilities {
// Load the vulnerabilities that affect the FeatureVersions.
t = time.Now()
err := loadAffectedBy(tx, layer.Features)
observeQueryTime("FindLayer", "loadAffectedBy", t)
if err != nil {
return layer, err
}
}
}
return layer, nil
} }
// getLayerFeatureVersions returns list of database.FeatureVersion that a database.Layer has. func (tx *pgSession) findLayerFeatures(layerID int64) ([]database.Feature, error) {
func getLayerFeatureVersions(tx *sql.Tx, layerID int) ([]database.FeatureVersion, error) { var features []database.Feature
var featureVersions []database.FeatureVersion
// Query. rows, err := tx.Query(searchLayerFeatures, layerID)
rows, err := tx.Query(searchLayerFeatureVersion, layerID)
if err != nil { if err != nil {
return featureVersions, handleError("searchLayerFeatureVersion", err) return nil, handleError("searchLayerFeatures", err)
} }
defer rows.Close()
// Scan query.
var modification string
mapFeatureVersions := make(map[int]database.FeatureVersion)
for rows.Next() { for rows.Next() {
var fv database.FeatureVersion f := database.Feature{}
err = rows.Scan( err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat)
&fv.ID, if err != nil {
&modification, return nil, err
&fv.Feature.Namespace.ID, }
&fv.Feature.Namespace.Name, features = append(features, f)
&fv.Feature.Namespace.VersionFormat, }
&fv.Feature.ID, return features, nil
&fv.Feature.Name, }
&fv.ID,
&fv.Version, func (tx *pgSession) findLayer(hash string) (database.Layer, database.Processors, int64, bool, error) {
&fv.AddedBy.ID, var (
&fv.AddedBy.Name, layerID int64
layer = database.Layer{Hash: hash}
processors database.Processors
) )
if hash == "" {
return layer, processors, layerID, false, commonerr.NewBadRequestError("Empty Layer Hash is not allowed")
}
err := tx.QueryRow(searchLayer, hash).Scan(&layerID)
if err != nil { if err != nil {
return featureVersions, handleError("searchLayerFeatureVersion.Scan()", err) if err == sql.ErrNoRows {
return layer, processors, layerID, false, nil
}
return layer, processors, layerID, false, err
} }
// Do transitive closure. processors.Detectors, err = tx.findProcessors(searchLayerDetectors, "searchLayerDetectors", "detector", layerID)
switch modification { if err != nil {
case "add": return layer, processors, layerID, false, err
mapFeatureVersions[fv.ID] = fv
case "del":
delete(mapFeatureVersions, fv.ID)
default:
log.WithField("modification", modification).Warning("unknown Layer_diff_FeatureVersion's modification")
return featureVersions, database.ErrInconsistent
}
}
if err = rows.Err(); err != nil {
return featureVersions, handleError("searchLayerFeatureVersion.Rows()", err)
} }
// Build result by converting our map to a slice. processors.Listers, err = tx.findProcessors(searchLayerListers, "searchLayerListers", "lister", layerID)
for _, featureVersion := range mapFeatureVersions { if err != nil {
featureVersions = append(featureVersions, featureVersion) return layer, processors, layerID, false, err
} }
return featureVersions, nil return layer, processors, layerID, true, nil
}
// loadAffectedBy returns the list of database.Vulnerability that affect the given
// FeatureVersion.
func loadAffectedBy(tx *sql.Tx, featureVersions []database.FeatureVersion) error {
if len(featureVersions) == 0 {
return nil
}
// Construct list of FeatureVersion IDs, we will do a single query
featureVersionIDs := make([]int, 0, len(featureVersions))
for i := 0; i < len(featureVersions); i++ {
featureVersionIDs = append(featureVersionIDs, featureVersions[i].ID)
}
rows, err := tx.Query(searchFeatureVersionVulnerability,
buildInputArray(featureVersionIDs))
if err != nil && err != sql.ErrNoRows {
return handleError("searchFeatureVersionVulnerability", err)
}
defer rows.Close()
vulnerabilities := make(map[int][]database.Vulnerability, len(featureVersions))
var featureversionID int
for rows.Next() {
var vulnerability database.Vulnerability
err := rows.Scan(
&featureversionID,
&vulnerability.ID,
&vulnerability.Name,
&vulnerability.Description,
&vulnerability.Link,
&vulnerability.Severity,
&vulnerability.Metadata,
&vulnerability.Namespace.Name,
&vulnerability.Namespace.VersionFormat,
&vulnerability.FixedBy,
)
if err != nil {
return handleError("searchFeatureVersionVulnerability.Scan()", err)
}
vulnerabilities[featureversionID] = append(vulnerabilities[featureversionID], vulnerability)
}
if err = rows.Err(); err != nil {
return handleError("searchFeatureVersionVulnerability.Rows()", err)
}
// Assign vulnerabilities to every FeatureVersions
for i := 0; i < len(featureVersions); i++ {
featureVersions[i].AffectedBy = vulnerabilities[featureVersions[i].ID]
}
return nil
}
// Internally, only Feature additions/removals are stored for each layer. If a layer has a parent,
// the Feature list will be compared to the parent's Feature list and the difference will be stored.
// Note that when the Namespace of a layer differs from its parent, it is expected that several
// Feature that were already included a parent will have their Namespace updated as well
// (happens when Feature detectors relies on the detected layer Namespace). However, if the listed
// Feature has the same Name/Version as its parent, InsertLayer considers that the Feature hasn't
// been modified.
func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error {
tf := time.Now()
// Verify parameters
if layer.Name == "" {
log.Warning("could not insert a layer which has an empty Name")
return commonerr.NewBadRequestError("could not insert a layer which has an empty Name")
}
// Get a potentially existing layer.
existingLayer, err := pgSQL.FindLayer(layer.Name, true, false)
if err != nil && err != commonerr.ErrNotFound {
return err
} else if err == nil {
if existingLayer.EngineVersion >= layer.EngineVersion {
// The layer exists and has an equal or higher engine version, do nothing.
return nil
}
layer.ID = existingLayer.ID
}
// We do `defer observeQueryTime` here because we don't want to observe existing layers.
defer observeQueryTime("InsertLayer", "all", tf)
// Get parent ID.
var parentID zero.Int
if layer.Parent != nil {
if layer.Parent.ID == 0 {
log.Warning("Parent is expected to be retrieved from database when inserting a layer.")
return commonerr.NewBadRequestError("Parent is expected to be retrieved from database when inserting a layer.")
}
parentID = zero.IntFrom(int64(layer.Parent.ID))
}
// namespaceIDs will contain inherited and new namespaces
namespaceIDs := make(map[int]struct{})
// try to insert the new namespaces
for _, ns := range layer.Namespaces {
n, err := pgSQL.insertNamespace(ns)
if err != nil {
return handleError("pgSQL.insertNamespace", err)
}
namespaceIDs[n] = struct{}{}
}
// inherit namespaces from parent layer
if layer.Parent != nil {
for _, ns := range layer.Parent.Namespaces {
namespaceIDs[ns.ID] = struct{}{}
}
}
// Begin transaction.
tx, err := pgSQL.Begin()
if err != nil {
tx.Rollback()
return handleError("InsertLayer.Begin()", err)
}
if layer.ID == 0 {
// Insert a new layer.
err = tx.QueryRow(insertLayer, layer.Name, layer.EngineVersion, parentID).
Scan(&layer.ID)
if err != nil {
tx.Rollback()
if isErrUniqueViolation(err) {
// Ignore this error, another process collided.
log.Debug("Attempted to insert duplicate layer.")
return nil
}
return handleError("insertLayer", err)
}
} else {
// Update an existing layer.
_, err = tx.Exec(updateLayer, layer.ID, layer.EngineVersion)
if err != nil {
tx.Rollback()
return handleError("updateLayer", err)
}
// replace the old namespace in the database
_, err := tx.Exec(removeLayerNamespace, layer.ID)
if err != nil {
tx.Rollback()
return handleError("removeLayerNamespace", err)
}
// Remove all existing Layer_diff_FeatureVersion.
_, err = tx.Exec(removeLayerDiffFeatureVersion, layer.ID)
if err != nil {
tx.Rollback()
return handleError("removeLayerDiffFeatureVersion", err)
}
}
// insert the layer's namespaces
stmt, err := tx.Prepare(insertLayerNamespace)
if err != nil {
tx.Rollback()
return handleError("failed to prepare statement", err)
}
defer func() {
err = stmt.Close()
if err != nil {
tx.Rollback()
log.WithError(err).Error("failed to close prepared statement")
}
}()
for nsid := range namespaceIDs {
_, err := stmt.Exec(layer.ID, nsid)
if err != nil {
tx.Rollback()
return handleError("insertLayerNamespace", err)
}
}
// Update Layer_diff_FeatureVersion now.
err = pgSQL.updateDiffFeatureVersions(tx, &layer, &existingLayer)
if err != nil {
tx.Rollback()
return err
}
// Commit transaction.
err = tx.Commit()
if err != nil {
tx.Rollback()
return handleError("InsertLayer.Commit()", err)
}
return nil
}
func (pgSQL *pgSQL) updateDiffFeatureVersions(tx *sql.Tx, layer, existingLayer *database.Layer) error {
// add and del are the FeatureVersion diff we should insert.
var add []database.FeatureVersion
var del []database.FeatureVersion
if layer.Parent == nil {
// There is no parent, every Features are added.
add = append(add, layer.Features...)
} else if layer.Parent != nil {
// There is a parent, we need to diff the Features with it.
// Build name:version structures.
layerFeaturesMapNV, layerFeaturesNV := createNV(layer.Features)
parentLayerFeaturesMapNV, parentLayerFeaturesNV := createNV(layer.Parent.Features)
// Calculate the added and deleted FeatureVersions name:version.
addNV := compareStringLists(layerFeaturesNV, parentLayerFeaturesNV)
delNV := compareStringLists(parentLayerFeaturesNV, layerFeaturesNV)
// Fill the structures containing the added and deleted FeatureVersions.
for _, nv := range addNV {
add = append(add, *layerFeaturesMapNV[nv])
}
for _, nv := range delNV {
del = append(del, *parentLayerFeaturesMapNV[nv])
}
}
// Insert FeatureVersions in the database.
addIDs, err := pgSQL.insertFeatureVersions(add)
if err != nil {
return err
}
delIDs, err := pgSQL.insertFeatureVersions(del)
if err != nil {
return err
}
// Insert diff in the database.
if len(addIDs) > 0 {
_, err = tx.Exec(insertLayerDiffFeatureVersion, layer.ID, "add", buildInputArray(addIDs))
if err != nil {
return handleError("insertLayerDiffFeatureVersion.Add", err)
}
}
if len(delIDs) > 0 {
_, err = tx.Exec(insertLayerDiffFeatureVersion, layer.ID, "del", buildInputArray(delIDs))
if err != nil {
return handleError("insertLayerDiffFeatureVersion.Del", err)
}
}
return nil
}
func createNV(features []database.FeatureVersion) (map[string]*database.FeatureVersion, []string) {
mapNV := make(map[string]*database.FeatureVersion, 0)
sliceNV := make([]string, 0, len(features))
for i := 0; i < len(features); i++ {
fv := &features[i]
nv := strings.Join([]string{fv.Feature.Namespace.Name, fv.Feature.Name, fv.Version}, ":")
mapNV[nv] = fv
sliceNV = append(sliceNV, nv)
}
return mapNV, sliceNV
}
func (pgSQL *pgSQL) DeleteLayer(name string) error {
defer observeQueryTime("DeleteLayer", "all", time.Now())
result, err := pgSQL.Exec(removeLayer, name)
if err != nil {
return handleError("removeLayer", err)
}
affected, err := result.RowsAffected()
if err != nil {
return handleError("removeLayer.RowsAffected()", err)
}
if affected <= 0 {
return commonerr.ErrNotFound
}
return nil
} }

View File

@ -15,423 +15,100 @@
package pgsql package pgsql
import ( import (
"fmt"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/ext/versionfmt/dpkg"
"github.com/coreos/clair/pkg/commonerr"
) )
func TestPersistLayer(t *testing.T) {
datastore, tx := openSessionForTest(t, "PersistLayer", false)
defer closeTest(t, datastore, tx)
l1 := database.Layer{}
l2 := database.Layer{Hash: "HESOYAM"}
// invalid
assert.NotNil(t, tx.PersistLayer(l1))
// valid
assert.Nil(t, tx.PersistLayer(l2))
// duplicated
assert.Nil(t, tx.PersistLayer(l2))
}
func TestPersistLayerProcessors(t *testing.T) {
datastore, tx := openSessionForTest(t, "PersistLayerProcessors", true)
defer closeTest(t, datastore, tx)
// invalid
assert.NotNil(t, tx.PersistLayerContent("hash", []database.Namespace{}, []database.Feature{}, database.Processors{}))
// valid
assert.Nil(t, tx.PersistLayerContent("layer-4", []database.Namespace{}, []database.Feature{}, database.Processors{Detectors: []string{"new detector!"}}))
}
func TestFindLayer(t *testing.T) { func TestFindLayer(t *testing.T) {
datastore, err := openDatabaseForTest("FindLayer", true) datastore, tx := openSessionForTest(t, "FindLayer", true)
if err != nil { defer closeTest(t, datastore, tx)
t.Error(err)
return
}
defer datastore.Close()
// Layer-0: no parent, no namespace, no feature, no vulnerability expected := database.Layer{Hash: "layer-4"}
layer, err := datastore.FindLayer("layer-0", false, false) expectedProcessors := database.Processors{
if assert.Nil(t, err) && assert.NotNil(t, layer) { Detectors: []string{"os-release", "apt-sources"},
assert.Equal(t, "layer-0", layer.Name) Listers: []string{"dpkg", "rpm"},
assert.Len(t, layer.Namespaces, 0)
assert.Nil(t, layer.Parent)
assert.Equal(t, 1, layer.EngineVersion)
assert.Len(t, layer.Features, 0)
} }
layer, err = datastore.FindLayer("layer-0", true, false) // invalid
if assert.Nil(t, err) && assert.NotNil(t, layer) { _, _, _, err := tx.FindLayer("")
assert.Len(t, layer.Features, 0) assert.NotNil(t, err)
} _, _, ok, err := tx.FindLayer("layer-non")
// Layer-1: one parent, adds two features, one vulnerability
layer, err = datastore.FindLayer("layer-1", false, false)
if assert.Nil(t, err) && assert.NotNil(t, layer) {
assert.Equal(t, layer.Name, "layer-1")
assertExpectedNamespaceName(t, &layer, []string{"debian:7"})
if assert.NotNil(t, layer.Parent) {
assert.Equal(t, "layer-0", layer.Parent.Name)
}
assert.Equal(t, 1, layer.EngineVersion)
assert.Len(t, layer.Features, 0)
}
layer, err = datastore.FindLayer("layer-1", true, false)
if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) {
for _, featureVersion := range layer.Features {
assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name)
switch featureVersion.Feature.Name {
case "wechat":
assert.Equal(t, "0.5", featureVersion.Version)
case "openssl":
assert.Equal(t, "1.0", featureVersion.Version)
default:
t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name)
}
}
}
layer, err = datastore.FindLayer("layer-1", true, true)
if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) {
for _, featureVersion := range layer.Features {
assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name)
switch featureVersion.Feature.Name {
case "wechat":
assert.Equal(t, "0.5", featureVersion.Version)
case "openssl":
assert.Equal(t, "1.0", featureVersion.Version)
if assert.Len(t, featureVersion.AffectedBy, 1) {
assert.Equal(t, "debian:7", featureVersion.AffectedBy[0].Namespace.Name)
assert.Equal(t, "CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Name)
assert.Equal(t, database.HighSeverity, featureVersion.AffectedBy[0].Severity)
assert.Equal(t, "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", featureVersion.AffectedBy[0].Description)
assert.Equal(t, "http://google.com/#q=CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Link)
assert.Equal(t, "2.0", featureVersion.AffectedBy[0].FixedBy)
}
default:
t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name)
}
}
}
// Testing Multiple namespaces layer-3b has debian:7 and debian:8 namespaces
layer, err = datastore.FindLayer("layer-3b", true, true)
if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) {
assert.Equal(t, "layer-3b", layer.Name)
// validate the namespace
assertExpectedNamespaceName(t, &layer, []string{"debian:7", "debian:8"})
for _, featureVersion := range layer.Features {
switch featureVersion.Feature.Namespace.Name {
case "debian:7":
assert.Equal(t, "wechat", featureVersion.Feature.Name)
assert.Equal(t, "0.5", featureVersion.Version)
case "debian:8":
assert.Equal(t, "openssl", featureVersion.Feature.Name)
assert.Equal(t, "1.0", featureVersion.Version)
default:
t.Errorf("unexpected package %s for layer-3b", featureVersion.Feature.Name)
}
}
}
}
func TestInsertLayer(t *testing.T) {
datastore, err := openDatabaseForTest("InsertLayer", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Insert invalid layer.
testInsertLayerInvalid(t, datastore)
// Insert a layer tree.
testInsertLayerTree(t, datastore)
// Update layer.
testInsertLayerUpdate(t, datastore)
// Delete layer.
testInsertLayerDelete(t, datastore)
}
func testInsertLayerInvalid(t *testing.T, datastore database.Datastore) {
invalidLayers := []database.Layer{
{},
{Name: "layer0", Parent: &database.Layer{}},
{Name: "layer0", Parent: &database.Layer{Name: "UnknownLayer"}},
}
for _, invalidLayer := range invalidLayers {
err := datastore.InsertLayer(invalidLayer)
assert.Error(t, err)
}
}
func testInsertLayerTree(t *testing.T, datastore database.Datastore) {
f1 := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{
Name: "TestInsertLayerNamespace2",
VersionFormat: dpkg.ParserName,
},
Name: "TestInsertLayerFeature1",
},
Version: "1.0",
}
f2 := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{
Name: "TestInsertLayerNamespace2",
VersionFormat: dpkg.ParserName,
},
Name: "TestInsertLayerFeature2",
},
Version: "0.34",
}
f3 := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{
Name: "TestInsertLayerNamespace2",
VersionFormat: dpkg.ParserName,
},
Name: "TestInsertLayerFeature3",
},
Version: "0.56",
}
f4 := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{
Name: "TestInsertLayerNamespace3",
VersionFormat: dpkg.ParserName,
},
Name: "TestInsertLayerFeature2",
},
Version: "0.34",
}
f5 := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{
Name: "TestInsertLayerNamespace3",
VersionFormat: dpkg.ParserName,
},
Name: "TestInsertLayerFeature3",
},
Version: "0.56",
}
f6 := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{
Name: "TestInsertLayerNamespace3",
VersionFormat: dpkg.ParserName,
},
Name: "TestInsertLayerFeature4",
},
Version: "0.666",
}
layers := []database.Layer{
{
Name: "TestInsertLayer1",
},
{
Name: "TestInsertLayer2",
Parent: &database.Layer{Name: "TestInsertLayer1"},
Namespaces: []database.Namespace{database.Namespace{
Name: "TestInsertLayerNamespace1",
VersionFormat: dpkg.ParserName,
}},
},
// This layer changes the namespace and adds Features.
{
Name: "TestInsertLayer3",
Parent: &database.Layer{Name: "TestInsertLayer2"},
Namespaces: []database.Namespace{database.Namespace{
Name: "TestInsertLayerNamespace2",
VersionFormat: dpkg.ParserName,
}},
Features: []database.FeatureVersion{f1, f2, f3},
},
// This layer covers the case where the last layer doesn't provide any new Feature.
{
Name: "TestInsertLayer4a",
Parent: &database.Layer{Name: "TestInsertLayer3"},
Features: []database.FeatureVersion{f1, f2, f3},
},
// This layer covers the case where the last layer provides Features.
// It also modifies the Namespace ("upgrade") but keeps some Features not upgraded, their
// Namespaces should then remain unchanged.
{
Name: "TestInsertLayer4b",
Parent: &database.Layer{Name: "TestInsertLayer3"},
Namespaces: []database.Namespace{database.Namespace{
Name: "TestInsertLayerNamespace3",
VersionFormat: dpkg.ParserName,
}},
Features: []database.FeatureVersion{
// Deletes TestInsertLayerFeature1.
// Keep TestInsertLayerFeature2 (old Namespace should be kept):
f4,
// Upgrades TestInsertLayerFeature3 (with new Namespace):
f5,
// Adds TestInsertLayerFeature4:
f6,
},
},
}
var err error
retrievedLayers := make(map[string]database.Layer)
for _, layer := range layers {
if layer.Parent != nil {
// Retrieve from database its parent and assign.
parent := retrievedLayers[layer.Parent.Name]
layer.Parent = &parent
}
err = datastore.InsertLayer(layer)
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, ok)
retrievedLayers[layer.Name], err = datastore.FindLayer(layer.Name, true, false) // valid
assert.Nil(t, err) layer, processors, ok2, err := tx.FindLayer("layer-4")
} if assert.Nil(t, err) && assert.True(t, ok2) {
assert.Equal(t, expected, layer)
// layer inherits all namespaces from its ancestries assertProcessorsEqual(t, expectedProcessors, processors)
l4a := retrievedLayers["TestInsertLayer4a"]
assertExpectedNamespaceName(t, &l4a, []string{"TestInsertLayerNamespace2", "TestInsertLayerNamespace1"})
assert.Len(t, l4a.Features, 3)
for _, featureVersion := range l4a.Features {
if cmpFV(featureVersion, f1) && cmpFV(featureVersion, f2) && cmpFV(featureVersion, f3) {
assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f1, f2, f3))
}
}
l4b := retrievedLayers["TestInsertLayer4b"]
assertExpectedNamespaceName(t, &l4b, []string{"TestInsertLayerNamespace1", "TestInsertLayerNamespace2", "TestInsertLayerNamespace3"})
assert.Len(t, l4b.Features, 3)
for _, featureVersion := range l4b.Features {
if cmpFV(featureVersion, f2) && cmpFV(featureVersion, f5) && cmpFV(featureVersion, f6) {
assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f2, f4, f6))
}
} }
} }
func testInsertLayerUpdate(t *testing.T, datastore database.Datastore) { func TestFindLayerWithContent(t *testing.T) {
f7 := database.FeatureVersion{ datastore, tx := openSessionForTest(t, "FindLayerWithContent", true)
Feature: database.Feature{ defer closeTest(t, datastore, tx)
Namespace: database.Namespace{
Name: "TestInsertLayerNamespace3", _, _, err := tx.FindLayerWithContent("")
VersionFormat: dpkg.ParserName, assert.NotNil(t, err)
_, ok, err := tx.FindLayerWithContent("layer-non")
assert.Nil(t, err)
assert.False(t, ok)
expectedL := database.LayerWithContent{
Layer: database.Layer{
Hash: "layer-4",
}, },
Name: "TestInsertLayerFeature7", Features: []database.Feature{
{Name: "fake", Version: "2.0", VersionFormat: "rpm"},
{Name: "openssl", Version: "2.0", VersionFormat: "dpkg"},
},
Namespaces: []database.Namespace{
{Name: "debian:7", VersionFormat: "dpkg"},
{Name: "fake:1.0", VersionFormat: "rpm"},
},
ProcessedBy: database.Processors{
Detectors: []string{"os-release", "apt-sources"},
Listers: []string{"dpkg", "rpm"},
}, },
Version: "0.01",
} }
l3, _ := datastore.FindLayer("TestInsertLayer3", true, false) layer, ok2, err := tx.FindLayerWithContent("layer-4")
l3u := database.Layer{ if assert.Nil(t, err) && assert.True(t, ok2) {
Name: l3.Name, assertLayerWithContentEqual(t, expectedL, layer)
Parent: l3.Parent,
Namespaces: []database.Namespace{database.Namespace{
Name: "TestInsertLayerNamespaceUpdated1",
VersionFormat: dpkg.ParserName,
}},
Features: []database.FeatureVersion{f7},
}
l4u := database.Layer{
Name: "TestInsertLayer4",
Parent: &database.Layer{Name: "TestInsertLayer3"},
Features: []database.FeatureVersion{f7},
EngineVersion: 2,
}
// Try to re-insert without increasing the EngineVersion.
err := datastore.InsertLayer(l3u)
assert.Nil(t, err)
l3uf, err := datastore.FindLayer(l3u.Name, true, false)
if assert.Nil(t, err) {
assertSameNamespaceName(t, &l3, &l3uf)
assert.Equal(t, l3.EngineVersion, l3uf.EngineVersion)
assert.Len(t, l3uf.Features, len(l3.Features))
}
// Update layer l3.
// Verify that the Namespace, EngineVersion and FeatureVersions got updated.
l3u.EngineVersion = 2
err = datastore.InsertLayer(l3u)
assert.Nil(t, err)
l3uf, err = datastore.FindLayer(l3u.Name, true, false)
if assert.Nil(t, err) {
assertSameNamespaceName(t, &l3u, &l3uf)
assert.Equal(t, l3u.EngineVersion, l3uf.EngineVersion)
if assert.Len(t, l3uf.Features, 1) {
assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l3uf.Features[0])
}
}
// Update layer l4.
// Verify that the Namespace got updated from its new Parent's, and also verify the
// EnginVersion and FeatureVersions.
l4u.Parent = &l3uf
err = datastore.InsertLayer(l4u)
assert.Nil(t, err)
l4uf, err := datastore.FindLayer(l3u.Name, true, false)
if assert.Nil(t, err) {
assertSameNamespaceName(t, &l3u, &l4uf)
assert.Equal(t, l4u.EngineVersion, l4uf.EngineVersion)
if assert.Len(t, l4uf.Features, 1) {
assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l4uf.Features[0])
}
} }
} }
func assertSameNamespaceName(t *testing.T, layer1 *database.Layer, layer2 *database.Layer) { func assertLayerWithContentEqual(t *testing.T, expected database.LayerWithContent, actual database.LayerWithContent) bool {
assert.Len(t, compareStringLists(extractNamespaceName(layer1), extractNamespaceName(layer2)), 0) return assert.Equal(t, expected.Layer, actual.Layer) &&
} assertFeaturesEqual(t, expected.Features, actual.Features) &&
assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy) &&
func assertExpectedNamespaceName(t *testing.T, layer *database.Layer, expectedNames []string) { assertNamespacesEqual(t, expected.Namespaces, actual.Namespaces)
assert.Len(t, compareStringLists(extractNamespaceName(layer), expectedNames), 0)
}
func extractNamespaceName(layer *database.Layer) []string {
slist := make([]string, 0, len(layer.Namespaces))
for _, ns := range layer.Namespaces {
slist = append(slist, ns.Name)
}
return slist
}
func testInsertLayerDelete(t *testing.T, datastore database.Datastore) {
err := datastore.DeleteLayer("TestInsertLayerX")
assert.Equal(t, commonerr.ErrNotFound, err)
// ensure layer_namespace table is cleaned up once a layer is removed
layer3, err := datastore.FindLayer("TestInsertLayer3", false, false)
layer4a, err := datastore.FindLayer("TestInsertLayer4a", false, false)
layer4b, err := datastore.FindLayer("TestInsertLayer4b", false, false)
err = datastore.DeleteLayer("TestInsertLayer3")
assert.Nil(t, err)
_, err = datastore.FindLayer("TestInsertLayer3", false, false)
assert.Equal(t, commonerr.ErrNotFound, err)
assertNotInLayerNamespace(t, layer3.ID, datastore)
_, err = datastore.FindLayer("TestInsertLayer4a", false, false)
assert.Equal(t, commonerr.ErrNotFound, err)
assertNotInLayerNamespace(t, layer4a.ID, datastore)
_, err = datastore.FindLayer("TestInsertLayer4b", true, false)
assert.Equal(t, commonerr.ErrNotFound, err)
assertNotInLayerNamespace(t, layer4b.ID, datastore)
}
func assertNotInLayerNamespace(t *testing.T, layerID int, datastore database.Datastore) {
pg, ok := datastore.(*pgSQL)
if !assert.True(t, ok) {
return
}
tx, err := pg.Begin()
if !assert.Nil(t, err) {
return
}
rows, err := tx.Query(searchLayerNamespace, layerID)
assert.False(t, rows.Next())
}
func cmpFV(a, b database.FeatureVersion) bool {
return a.Feature.Name == b.Feature.Name &&
a.Feature.Namespace.Name == b.Feature.Namespace.Name &&
a.Version == b.Version
} }

View File

@ -15,6 +15,7 @@
package pgsql package pgsql
import ( import (
"errors"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -22,86 +23,91 @@ import (
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
) )
var (
errLockNotFound = errors.New("lock is not in database")
)
// Lock tries to set a temporary lock in the database. // Lock tries to set a temporary lock in the database.
// //
// Lock does not block, instead, it returns true and its expiration time // Lock does not block, instead, it returns true and its expiration time
// is the lock has been successfully acquired or false otherwise // is the lock has been successfully acquired or false otherwise.
func (pgSQL *pgSQL) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time) { func (tx *pgSession) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time, error) {
if name == "" || owner == "" || duration == 0 { if name == "" || owner == "" || duration == 0 {
log.Warning("could not create an invalid lock") log.Warning("could not create an invalid lock")
return false, time.Time{} return false, time.Time{}, commonerr.NewBadRequestError("Invalid Lock Parameters")
} }
defer observeQueryTime("Lock", "all", time.Now())
// Compute expiration.
until := time.Now().Add(duration) until := time.Now().Add(duration)
if renew { if renew {
defer observeQueryTime("Lock", "update", time.Now())
// Renew lock. // Renew lock.
r, err := pgSQL.Exec(updateLock, name, owner, until) r, err := tx.Exec(updateLock, name, owner, until)
if err != nil { if err != nil {
handleError("updateLock", err) return false, until, handleError("updateLock", err)
return false, until
} }
if n, _ := r.RowsAffected(); n > 0 {
// Updated successfully. if n, err := r.RowsAffected(); err == nil {
return true, until return n > 0, until, nil
} }
} else { return false, until, handleError("updateLock", err)
// Prune locks. } else if err := tx.pruneLocks(); err != nil {
pgSQL.pruneLocks() return false, until, err
} }
// Lock. // Lock.
_, err := pgSQL.Exec(insertLock, name, owner, until) defer observeQueryTime("Lock", "soiLock", time.Now())
_, err := tx.Exec(soiLock, name, owner, until)
if err != nil { if err != nil {
if !isErrUniqueViolation(err) { if isErrUniqueViolation(err) {
handleError("insertLock", err) return false, until, nil
} }
return false, until return false, until, handleError("insertLock", err)
} }
return true, until, nil
return true, until
} }
// Unlock unlocks a lock specified by its name if I own it // Unlock unlocks a lock specified by its name if I own it
func (pgSQL *pgSQL) Unlock(name, owner string) { func (tx *pgSession) Unlock(name, owner string) error {
if name == "" || owner == "" { if name == "" || owner == "" {
log.Warning("could not delete an invalid lock") return commonerr.NewBadRequestError("Invalid Lock Parameters")
return
} }
defer observeQueryTime("Unlock", "all", time.Now()) defer observeQueryTime("Unlock", "all", time.Now())
pgSQL.Exec(removeLock, name, owner) _, err := tx.Exec(removeLock, name, owner)
return err
} }
// FindLock returns the owner of a lock specified by its name and its // FindLock returns the owner of a lock specified by its name and its
// expiration time. // expiration time.
func (pgSQL *pgSQL) FindLock(name string) (string, time.Time, error) { func (tx *pgSession) FindLock(name string) (string, time.Time, bool, error) {
if name == "" { if name == "" {
log.Warning("could not find an invalid lock") return "", time.Time{}, false, commonerr.NewBadRequestError("could not find an invalid lock")
return "", time.Time{}, commonerr.NewBadRequestError("could not find an invalid lock")
} }
defer observeQueryTime("FindLock", "all", time.Now()) defer observeQueryTime("FindLock", "all", time.Now())
var owner string var owner string
var until time.Time var until time.Time
err := pgSQL.QueryRow(searchLock, name).Scan(&owner, &until) err := tx.QueryRow(searchLock, name).Scan(&owner, &until)
if err != nil { if err != nil {
return owner, until, handleError("searchLock", err) return owner, until, false, handleError("searchLock", err)
} }
return owner, until, nil return owner, until, true, nil
} }
// pruneLocks removes every expired locks from the database // pruneLocks removes every expired locks from the database
func (pgSQL *pgSQL) pruneLocks() { func (tx *pgSession) pruneLocks() error {
defer observeQueryTime("pruneLocks", "all", time.Now()) defer observeQueryTime("pruneLocks", "all", time.Now())
if _, err := pgSQL.Exec(removeLockExpired); err != nil { if r, err := tx.Exec(removeLockExpired); err != nil {
handleError("removeLockExpired", err) return handleError("removeLockExpired", err)
} else if affected, err := r.RowsAffected(); err != nil {
return handleError("removeLockExpired", err)
} else {
log.Debugf("Pruned %d Locks", affected)
} }
return nil
} }

View File

@ -22,48 +22,72 @@ import (
) )
func TestLock(t *testing.T) { func TestLock(t *testing.T) {
datastore, err := openDatabaseForTest("InsertNamespace", false) datastore, tx := openSessionForTest(t, "Lock", true)
if err != nil {
t.Error(err)
return
}
defer datastore.Close() defer datastore.Close()
var l bool var l bool
var et time.Time var et time.Time
// Create a first lock. // Create a first lock.
l, _ = datastore.Lock("test1", "owner1", time.Minute, false) l, _, err := tx.Lock("test1", "owner1", time.Minute, false)
assert.Nil(t, err)
assert.True(t, l) assert.True(t, l)
tx = restartSession(t, datastore, tx, true)
// Try to lock the same lock with another owner. // lock again by itself, the previous lock is not expired yet.
l, _ = datastore.Lock("test1", "owner2", time.Minute, true) l, _, err = tx.Lock("test1", "owner1", time.Minute, false)
assert.Nil(t, err)
assert.False(t, l) assert.False(t, l)
tx = restartSession(t, datastore, tx, false)
l, _ = datastore.Lock("test1", "owner2", time.Minute, false) // Try to renew the same lock with another owner.
l, _, err = tx.Lock("test1", "owner2", time.Minute, true)
assert.Nil(t, err)
assert.False(t, l) assert.False(t, l)
tx = restartSession(t, datastore, tx, false)
l, _, err = tx.Lock("test1", "owner2", time.Minute, false)
assert.Nil(t, err)
assert.False(t, l)
tx = restartSession(t, datastore, tx, false)
// Renew the lock. // Renew the lock.
l, _ = datastore.Lock("test1", "owner1", 2*time.Minute, true) l, _, err = tx.Lock("test1", "owner1", 2*time.Minute, true)
assert.Nil(t, err)
assert.True(t, l) assert.True(t, l)
tx = restartSession(t, datastore, tx, true)
// Unlock and then relock by someone else. // Unlock and then relock by someone else.
datastore.Unlock("test1", "owner1") err = tx.Unlock("test1", "owner1")
assert.Nil(t, err)
tx = restartSession(t, datastore, tx, true)
l, et = datastore.Lock("test1", "owner2", time.Minute, false) l, et, err = tx.Lock("test1", "owner2", time.Minute, false)
assert.Nil(t, err)
assert.True(t, l) assert.True(t, l)
tx = restartSession(t, datastore, tx, true)
// LockInfo // LockInfo
o, et2, err := datastore.FindLock("test1") o, et2, ok, err := tx.FindLock("test1")
assert.True(t, ok)
assert.Nil(t, err) assert.Nil(t, err)
assert.Equal(t, "owner2", o) assert.Equal(t, "owner2", o)
assert.Equal(t, et.Second(), et2.Second()) assert.Equal(t, et.Second(), et2.Second())
tx = restartSession(t, datastore, tx, true)
// Create a second lock which is actually already expired ... // Create a second lock which is actually already expired ...
l, _ = datastore.Lock("test2", "owner1", -time.Minute, false) l, _, err = tx.Lock("test2", "owner1", -time.Minute, false)
assert.Nil(t, err)
assert.True(t, l) assert.True(t, l)
tx = restartSession(t, datastore, tx, true)
// Take over the lock // Take over the lock
l, _ = datastore.Lock("test2", "owner2", time.Minute, false) l, _, err = tx.Lock("test2", "owner2", time.Minute, false)
assert.Nil(t, err)
assert.True(t, l) assert.True(t, l)
tx = restartSession(t, datastore, tx, true)
if !assert.Nil(t, tx.Rollback()) {
t.FailNow()
}
} }

View File

@ -1,53 +0,0 @@
// Copyright 2016 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package migrations
import (
"database/sql"
"github.com/remind101/migrate"
)
func init() {
// This migration removes the data maintained by the previous migration tool
// (liamstask/goose), and if it was present, mark the 00002_initial_schema
// migration as done.
RegisterMigration(migrate.Migration{
ID: 1,
Up: func(tx *sql.Tx) error {
// Verify that goose was in use before, otherwise skip this migration.
var e bool
err := tx.QueryRow("SELECT true FROM pg_class WHERE relname = $1", "goose_db_version").Scan(&e)
if err == sql.ErrNoRows {
return nil
}
if err != nil {
return err
}
// Delete goose's data.
_, err = tx.Exec("DROP TABLE goose_db_version CASCADE")
if err != nil {
return err
}
// Mark the '00002_initial_schema' as done.
_, err = tx.Exec("INSERT INTO schema_migrations (version) VALUES (2)")
return err
},
Down: migrate.Queries([]string{}),
})
}

View File

@ -0,0 +1,192 @@
// Copyright 2016 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package migrations
import "github.com/remind101/migrate"
func init() {
RegisterMigration(migrate.Migration{
ID: 1,
Up: migrate.Queries([]string{
// namespaces
`CREATE TABLE IF NOT EXISTS namespace (
id SERIAL PRIMARY KEY,
name TEXT NULL,
version_format TEXT,
UNIQUE (name, version_format));`,
`CREATE INDEX ON namespace(name);`,
// features
`CREATE TABLE IF NOT EXISTS feature (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL,
version TEXT NOT NULL,
version_format TEXT NOT NULL,
UNIQUE (name, version, version_format));`,
`CREATE INDEX ON feature(name);`,
`CREATE TABLE IF NOT EXISTS namespaced_feature (
id SERIAL PRIMARY KEY,
namespace_id INT REFERENCES namespace,
feature_id INT REFERENCES feature,
UNIQUE (namespace_id, feature_id));`,
// layers
`CREATE TABLE IF NOT EXISTS layer(
id SERIAL PRIMARY KEY,
hash TEXT NOT NULL UNIQUE);`,
`CREATE TABLE IF NOT EXISTS layer_feature (
id SERIAL PRIMARY KEY,
layer_id INT REFERENCES layer ON DELETE CASCADE,
feature_id INT REFERENCES feature ON DELETE CASCADE,
UNIQUE (layer_id, feature_id));`,
`CREATE INDEX ON layer_feature(layer_id);`,
`CREATE TABLE IF NOT EXISTS layer_lister (
id SERIAL PRIMARY KEY,
layer_id INT REFERENCES layer ON DELETE CASCADE,
lister TEXT NOT NULL,
UNIQUE (layer_id, lister));`,
`CREATE INDEX ON layer_lister(layer_id);`,
`CREATE TABLE IF NOT EXISTS layer_detector (
id SERIAL PRIMARY KEY,
layer_id INT REFERENCES layer ON DELETE CASCADE,
detector TEXT,
UNIQUE (layer_id, detector));`,
`CREATE INDEX ON layer_detector(layer_id);`,
`CREATE TABLE IF NOT EXISTS layer_namespace (
id SERIAL PRIMARY KEY,
layer_id INT REFERENCES layer ON DELETE CASCADE,
namespace_id INT REFERENCES namespace ON DELETE CASCADE,
UNIQUE (layer_id, namespace_id));`,
`CREATE INDEX ON layer_namespace(layer_id);`,
// ancestry
`CREATE TABLE IF NOT EXISTS ancestry (
id SERIAL PRIMARY KEY,
name TEXT NOT NULL UNIQUE);`,
`CREATE TABLE IF NOT EXISTS ancestry_layer (
id SERIAL PRIMARY KEY,
ancestry_id INT REFERENCES ancestry ON DELETE CASCADE,
ancestry_index INT NOT NULL,
layer_id INT REFERENCES layer ON DELETE RESTRICT,
UNIQUE (ancestry_id, ancestry_index));`,
`CREATE INDEX ON ancestry_layer(ancestry_id);`,
`CREATE TABLE IF NOT EXISTS ancestry_feature (
id SERIAL PRIMARY KEY,
ancestry_id INT REFERENCES ancestry ON DELETE CASCADE,
namespaced_feature_id INT REFERENCES namespaced_feature ON DELETE CASCADE,
UNIQUE (ancestry_id, namespaced_feature_id));`,
`CREATE TABLE IF NOT EXISTS ancestry_lister (
id SERIAL PRIMARY KEY,
ancestry_id INT REFERENCES ancestry ON DELETE CASCADE,
lister TEXT,
UNIQUE (ancestry_id, lister));`,
`CREATE INDEX ON ancestry_lister(ancestry_id);`,
`CREATE TABLE IF NOT EXISTS ancestry_detector (
id SERIAL PRIMARY KEY,
ancestry_id INT REFERENCES ancestry ON DELETE CASCADE,
detector TEXT,
UNIQUE (ancestry_id, detector));`,
`CREATE INDEX ON ancestry_detector(ancestry_id);`,
`CREATE TYPE severity AS ENUM ('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1');`,
// vulnerability
`CREATE TABLE IF NOT EXISTS vulnerability (
id SERIAL PRIMARY KEY,
namespace_id INT NOT NULL REFERENCES Namespace,
name TEXT NOT NULL,
description TEXT NULL,
link TEXT NULL,
severity severity NOT NULL,
metadata TEXT NULL,
created_at TIMESTAMP WITH TIME ZONE,
deleted_at TIMESTAMP WITH TIME ZONE NULL);`,
`CREATE INDEX ON vulnerability(namespace_id, name);`,
`CREATE INDEX ON vulnerability(namespace_id);`,
`CREATE TABLE IF NOT EXISTS vulnerability_affected_feature (
id SERIAL PRIMARY KEY,
vulnerability_id INT NOT NULL REFERENCES vulnerability ON DELETE CASCADE,
feature_name TEXT NOT NULL,
affected_version TEXT,
fixedin TEXT);`,
`CREATE INDEX ON vulnerability_affected_feature(vulnerability_id, feature_name);`,
`CREATE TABLE IF NOT EXISTS vulnerability_affected_namespaced_feature(
id SERIAL PRIMARY KEY,
vulnerability_id INT NOT NULL REFERENCES vulnerability ON DELETE CASCADE,
namespaced_feature_id INT NOT NULL REFERENCES namespaced_feature ON DELETE CASCADE,
added_by INT NOT NULL REFERENCES vulnerability_affected_feature ON DELETE CASCADE,
UNIQUE (vulnerability_id, namespaced_feature_id));`,
`CREATE INDEX ON vulnerability_affected_namespaced_feature(namespaced_feature_id);`,
`CREATE TABLE IF NOT EXISTS KeyValue (
id SERIAL PRIMARY KEY,
key TEXT NOT NULL UNIQUE,
value TEXT);`,
`CREATE TABLE IF NOT EXISTS Lock (
id SERIAL PRIMARY KEY,
name VARCHAR(64) NOT NULL UNIQUE,
owner VARCHAR(64) NOT NULL,
until TIMESTAMP WITH TIME ZONE);`,
`CREATE INDEX ON Lock (owner);`,
// Notification
`CREATE TABLE IF NOT EXISTS Vulnerability_Notification (
id SERIAL PRIMARY KEY,
name VARCHAR(64) NOT NULL UNIQUE,
created_at TIMESTAMP WITH TIME ZONE,
notified_at TIMESTAMP WITH TIME ZONE NULL,
deleted_at TIMESTAMP WITH TIME ZONE NULL,
old_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE,
new_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE);`,
`CREATE INDEX ON Vulnerability_Notification (notified_at);`,
}),
Down: migrate.Queries([]string{
`DROP TABLE IF EXISTS
ancestry,
ancestry_layer,
ancestry_feature,
ancestry_detector,
ancestry_lister,
feature,
namespaced_feature,
keyvalue,
layer,
layer_detector,
layer_feature,
layer_lister,
layer_namespace,
lock,
namespace,
vulnerability,
vulnerability_affected_feature,
vulnerability_affected_namespaced_feature,
vulnerability_notification
CASCADE;`,
`DROP TYPE IF EXISTS severity;`,
}),
})
}

View File

@ -1,128 +0,0 @@
// Copyright 2016 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package migrations
import "github.com/remind101/migrate"
func init() {
// This migration creates the initial Clair's schema.
RegisterMigration(migrate.Migration{
ID: 2,
Up: migrate.Queries([]string{
`CREATE TABLE IF NOT EXISTS Namespace (
id SERIAL PRIMARY KEY,
name VARCHAR(128) NULL);`,
`CREATE TABLE IF NOT EXISTS Layer (
id SERIAL PRIMARY KEY,
name VARCHAR(128) NOT NULL UNIQUE,
engineversion SMALLINT NOT NULL,
parent_id INT NULL REFERENCES Layer ON DELETE CASCADE,
namespace_id INT NULL REFERENCES Namespace,
created_at TIMESTAMP WITH TIME ZONE);`,
`CREATE INDEX ON Layer (parent_id);`,
`CREATE INDEX ON Layer (namespace_id);`,
`CREATE TABLE IF NOT EXISTS Feature (
id SERIAL PRIMARY KEY,
namespace_id INT NOT NULL REFERENCES Namespace,
name VARCHAR(128) NOT NULL,
UNIQUE (namespace_id, name));`,
`CREATE TABLE IF NOT EXISTS FeatureVersion (
id SERIAL PRIMARY KEY,
feature_id INT NOT NULL REFERENCES Feature,
version VARCHAR(128) NOT NULL);`,
`CREATE INDEX ON FeatureVersion (feature_id);`,
`CREATE TYPE modification AS ENUM ('add', 'del');`,
`CREATE TABLE IF NOT EXISTS Layer_diff_FeatureVersion (
id SERIAL PRIMARY KEY,
layer_id INT NOT NULL REFERENCES Layer ON DELETE CASCADE,
featureversion_id INT NOT NULL REFERENCES FeatureVersion,
modification modification NOT NULL,
UNIQUE (layer_id, featureversion_id));`,
`CREATE INDEX ON Layer_diff_FeatureVersion (layer_id);`,
`CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id);`,
`CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id, layer_id);`,
`CREATE TYPE severity AS ENUM ('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1');`,
`CREATE TABLE IF NOT EXISTS Vulnerability (
id SERIAL PRIMARY KEY,
namespace_id INT NOT NULL REFERENCES Namespace,
name VARCHAR(128) NOT NULL,
description TEXT NULL,
link VARCHAR(128) NULL,
severity severity NOT NULL,
metadata TEXT NULL,
created_at TIMESTAMP WITH TIME ZONE,
deleted_at TIMESTAMP WITH TIME ZONE NULL);`,
`CREATE TABLE IF NOT EXISTS Vulnerability_FixedIn_Feature (
id SERIAL PRIMARY KEY,
vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE,
feature_id INT NOT NULL REFERENCES Feature,
version VARCHAR(128) NOT NULL,
UNIQUE (vulnerability_id, feature_id));`,
`CREATE INDEX ON Vulnerability_FixedIn_Feature (feature_id, vulnerability_id);`,
`CREATE TABLE IF NOT EXISTS Vulnerability_Affects_FeatureVersion (
id SERIAL PRIMARY KEY,
vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE,
featureversion_id INT NOT NULL REFERENCES FeatureVersion,
fixedin_id INT NOT NULL REFERENCES Vulnerability_FixedIn_Feature ON DELETE CASCADE,
UNIQUE (vulnerability_id, featureversion_id));`,
`CREATE INDEX ON Vulnerability_Affects_FeatureVersion (fixedin_id);`,
`CREATE INDEX ON Vulnerability_Affects_FeatureVersion (featureversion_id, vulnerability_id);`,
`CREATE TABLE IF NOT EXISTS KeyValue (
id SERIAL PRIMARY KEY,
key VARCHAR(128) NOT NULL UNIQUE,
value TEXT);`,
`CREATE TABLE IF NOT EXISTS Lock (
id SERIAL PRIMARY KEY,
name VARCHAR(64) NOT NULL UNIQUE,
owner VARCHAR(64) NOT NULL,
until TIMESTAMP WITH TIME ZONE);`,
`CREATE INDEX ON Lock (owner);`,
`CREATE TABLE IF NOT EXISTS Vulnerability_Notification (
id SERIAL PRIMARY KEY,
name VARCHAR(64) NOT NULL UNIQUE,
created_at TIMESTAMP WITH TIME ZONE,
notified_at TIMESTAMP WITH TIME ZONE NULL,
deleted_at TIMESTAMP WITH TIME ZONE NULL,
old_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE,
new_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE);`,
`CREATE INDEX ON Vulnerability_Notification (notified_at);`,
}),
Down: migrate.Queries([]string{
`DROP TABLE IF EXISTS
Namespace,
Layer,
Feature,
FeatureVersion,
Layer_diff_FeatureVersion,
Vulnerability,
Vulnerability_FixedIn_Feature,
Vulnerability_Affects_FeatureVersion,
Vulnerability_Notification,
KeyValue,
Lock
CASCADE;`,
}),
})
}

View File

@ -1,35 +0,0 @@
// Copyright 2016 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package migrations
import "github.com/remind101/migrate"
func init() {
RegisterMigration(migrate.Migration{
ID: 3,
Up: migrate.Queries([]string{
`CREATE UNIQUE INDEX namespace_name_key ON Namespace (name);`,
`CREATE INDEX vulnerability_name_idx ON Vulnerability (name);`,
`CREATE INDEX vulnerability_namespace_id_name_idx ON Vulnerability (namespace_id, name);`,
`CREATE UNIQUE INDEX featureversion_feature_id_version_key ON FeatureVersion (feature_id, version);`,
}),
Down: migrate.Queries([]string{
`DROP INDEX namespace_name_key;`,
`DROP INDEX vulnerability_name_idx;`,
`DROP INDEX vulnerability_namespace_id_name_idx;`,
`DROP INDEX featureversion_feature_id_version_key;`,
}),
})
}

View File

@ -1,29 +0,0 @@
// Copyright 2016 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package migrations
import "github.com/remind101/migrate"
func init() {
RegisterMigration(migrate.Migration{
ID: 4,
Up: migrate.Queries([]string{
`CREATE INDEX vulnerability_notification_deleted_at_idx ON Vulnerability_Notification (deleted_at);`,
}),
Down: migrate.Queries([]string{
`DROP INDEX vulnerability_notification_deleted_at_idx;`,
}),
})
}

View File

@ -1,29 +0,0 @@
// Copyright 2016 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package migrations
import "github.com/remind101/migrate"
func init() {
RegisterMigration(migrate.Migration{
ID: 5,
Up: migrate.Queries([]string{
`CREATE INDEX layer_diff_featureversion_layer_id_modification_idx ON Layer_diff_FeatureVersion (layer_id, modification);`,
}),
Down: migrate.Queries([]string{
`DROP INDEX layer_diff_featureversion_layer_id_modification_idx;`,
}),
})
}

View File

@ -1,31 +0,0 @@
// Copyright 2016 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package migrations
import "github.com/remind101/migrate"
func init() {
RegisterMigration(migrate.Migration{
ID: 6,
Up: migrate.Queries([]string{
`ALTER TABLE Namespace ADD COLUMN version_format varchar(128);`,
`UPDATE Namespace SET version_format = 'rpm' WHERE name LIKE 'rhel%' OR name LIKE 'centos%' OR name LIKE 'fedora%' OR name LIKE 'amzn%' OR name LIKE 'scientific%' OR name LIKE 'ol%' OR name LIKE 'oracle%';`,
`UPDATE Namespace SET version_format = 'dpkg' WHERE version_format is NULL;`,
}),
Down: migrate.Queries([]string{
`ALTER TABLE Namespace DROP COLUMN version_format;`,
}),
})
}

View File

@ -1,31 +0,0 @@
// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package migrations
import "github.com/remind101/migrate"
func init() {
RegisterMigration(migrate.Migration{
ID: 7,
Up: migrate.Queries([]string{
`ALTER TABLE Namespace ALTER COLUMN version_format SET DATA TYPE varchar(256);`,
`ALTER TABLE Layer ALTER COLUMN name SET DATA TYPE varchar(256);`,
}),
Down: migrate.Queries([]string{
`ALTER TABLE Namespace ALTER COLUMN version_format SET DATA TYPE varchar(128);`,
`ALTER TABLE Layer ALTER COLUMN name SET DATA TYPE varchar(128);`,
}),
})
}

View File

@ -1,44 +0,0 @@
// Copyright 2016 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package migrations
import "github.com/remind101/migrate"
func init() {
RegisterMigration(migrate.Migration{
ID: 8,
Up: migrate.Queries([]string{
// set on deletion, remove the corresponding rows in database
`CREATE TABLE IF NOT EXISTS Layer_Namespace(
id SERIAL PRIMARY KEY,
layer_id INT REFERENCES Layer(id) ON DELETE CASCADE,
namespace_id INT REFERENCES Namespace(id) ON DELETE CASCADE,
unique(layer_id, namespace_id)
);`,
`CREATE INDEX ON Layer_Namespace (namespace_id);`,
`CREATE INDEX ON Layer_Namespace (layer_id);`,
// move the namespace_id to the table
`INSERT INTO Layer_Namespace (layer_id, namespace_id) SELECT id, namespace_id FROM Layer;`,
// alter the Layer table to remove the column
`ALTER TABLE IF EXISTS Layer DROP namespace_id;`,
}),
Down: migrate.Queries([]string{
`ALTER TABLE IF EXISTS Layer ADD namespace_id INT NULL REFERENCES Namespace;`,
`CREATE INDEX ON Layer (namespace_id);`,
`UPDATE IF EXISTS Layer SET namespace_id = (SELECT lns.namespace_id FROM Layer_Namespace lns WHERE Layer.id = lns.layer_id LIMIT 1);`,
`DROP TABLE IF EXISTS Layer_Namespace;`,
}),
})
}

View File

@ -15,61 +15,82 @@
package pgsql package pgsql
import ( import (
"time" "database/sql"
"errors"
"sort"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
) )
func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (int, error) { var (
if namespace.Name == "" { errNamespaceNotFound = errors.New("Requested Namespace is not in database")
return 0, commonerr.NewBadRequestError("could not find/insert invalid Namespace") )
// PersistNamespaces soi namespaces into database.
func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error {
if len(namespaces) == 0 {
return nil
} }
if pgSQL.cache != nil { // Sorting is needed before inserting into database to prevent deadlock.
promCacheQueriesTotal.WithLabelValues("namespace").Inc() sort.Slice(namespaces, func(i, j int) bool {
if id, found := pgSQL.cache.Get("namespace:" + namespace.Name); found { return namespaces[i].Name < namespaces[j].Name &&
promCacheHitsTotal.WithLabelValues("namespace").Inc() namespaces[i].VersionFormat < namespaces[j].VersionFormat
return id.(int), nil })
keys := make([]interface{}, len(namespaces)*2)
for i, ns := range namespaces {
if ns.Name == "" || ns.VersionFormat == "" {
return commonerr.NewBadRequestError("Empty namespace name or version format is not allowed")
} }
keys[i*2] = ns.Name
keys[i*2+1] = ns.VersionFormat
} }
// We do `defer observeQueryTime` here because we don't want to observe cached namespaces. _, err := tx.Exec(queryPersistNamespace(len(namespaces)), keys...)
defer observeQueryTime("insertNamespace", "all", time.Now())
var id int
err := pgSQL.QueryRow(soiNamespace, namespace.Name, namespace.VersionFormat).Scan(&id)
if err != nil { if err != nil {
return 0, handleError("soiNamespace", err) return handleError("queryPersistNamespace", err)
} }
return nil
if pgSQL.cache != nil {
pgSQL.cache.Add("namespace:"+namespace.Name, id)
}
return id, nil
} }
func (pgSQL *pgSQL) ListNamespaces() (namespaces []database.Namespace, err error) { func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.NullInt64, error) {
rows, err := pgSQL.Query(listNamespace) if len(namespaces) == 0 {
if err != nil { return nil, nil
return namespaces, handleError("listNamespace", err)
} }
keys := make([]interface{}, len(namespaces)*2)
nsMap := map[database.Namespace]sql.NullInt64{}
for i, n := range namespaces {
keys[i*2] = n.Name
keys[i*2+1] = n.VersionFormat
nsMap[n] = sql.NullInt64{}
}
rows, err := tx.Query(querySearchNamespace(len(namespaces)), keys...)
if err != nil {
return nil, handleError("searchNamespace", err)
}
defer rows.Close() defer rows.Close()
var (
id sql.NullInt64
ns database.Namespace
)
for rows.Next() { for rows.Next() {
var ns database.Namespace err := rows.Scan(&id, &ns.Name, &ns.VersionFormat)
err = rows.Scan(&ns.ID, &ns.Name, &ns.VersionFormat)
if err != nil { if err != nil {
return namespaces, handleError("listNamespace.Scan()", err) return nil, handleError("searchNamespace", err)
}
nsMap[ns] = id
} }
namespaces = append(namespaces, ns) ids := make([]sql.NullInt64, len(namespaces))
} for i, ns := range namespaces {
if err = rows.Err(); err != nil { ids[i] = nsMap[ns]
return namespaces, handleError("listNamespace.Rows()", err)
} }
return namespaces, err return ids, nil
} }

View File

@ -15,60 +15,69 @@
package pgsql package pgsql
import ( import (
"fmt"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/ext/versionfmt/dpkg"
) )
func TestInsertNamespace(t *testing.T) { func TestPersistNamespaces(t *testing.T) {
datastore, err := openDatabaseForTest("InsertNamespace", false) datastore, tx := openSessionForTest(t, "PersistNamespaces", false)
if err != nil { defer closeTest(t, datastore, tx)
t.Error(err)
return
}
defer datastore.Close()
// Invalid Namespace. ns1 := database.Namespace{}
id0, err := datastore.insertNamespace(database.Namespace{}) ns2 := database.Namespace{Name: "t", VersionFormat: "b"}
assert.NotNil(t, err)
assert.Zero(t, id0)
// Insert Namespace and ensure we can find it. // Empty Case
id1, err := datastore.insertNamespace(database.Namespace{ assert.Nil(t, tx.PersistNamespaces([]database.Namespace{}))
Name: "TestInsertNamespace1", // Invalid Case
VersionFormat: dpkg.ParserName, assert.NotNil(t, tx.PersistNamespaces([]database.Namespace{ns1}))
}) // Duplicated Case
assert.Nil(t, err) assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2, ns2}))
id2, err := datastore.insertNamespace(database.Namespace{ // Existing Case
Name: "TestInsertNamespace1", assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2}))
VersionFormat: dpkg.ParserName,
}) nsList := listNamespaces(t, tx)
assert.Nil(t, err) assert.Len(t, nsList, 1)
assert.Equal(t, id1, id2) assert.Equal(t, ns2, nsList[0])
} }
func TestListNamespace(t *testing.T) { func assertNamespacesEqual(t *testing.T, expected []database.Namespace, actual []database.Namespace) bool {
datastore, err := openDatabaseForTest("ListNamespaces", true) if assert.Len(t, actual, len(expected)) {
if err != nil { has := map[database.Namespace]bool{}
t.Error(err) for _, i := range expected {
return has[i] = false
} }
defer datastore.Close() for _, i := range actual {
has[i] = true
namespaces, err := datastore.ListNamespaces() }
assert.Nil(t, err) for key, v := range has {
if assert.Len(t, namespaces, 2) { if !assert.True(t, v, key.Name+"is expected") {
for _, namespace := range namespaces { return false
switch namespace.Name {
case "debian:7", "debian:8":
continue
default:
assert.Error(t, fmt.Errorf("ListNamespaces should not have returned '%s'", namespace.Name))
} }
} }
return true
} }
return false
}
func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace {
rows, err := tx.Query("SELECT name, version_format FROM namespace")
if err != nil {
t.FailNow()
}
defer rows.Close()
namespaces := []database.Namespace{}
for rows.Next() {
var ns database.Namespace
err := rows.Scan(&ns.Name, &ns.VersionFormat)
if err != nil {
t.FailNow()
}
namespaces = append(namespaces, ns)
}
return namespaces
} }

View File

@ -16,235 +16,320 @@ package pgsql
import ( import (
"database/sql" "database/sql"
"errors"
"time" "time"
"github.com/guregu/null/zero" "github.com/guregu/null/zero"
"github.com/pborman/uuid"
log "github.com/sirupsen/logrus"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
) )
// do it in tx so we won't insert/update a vuln without notification and vice-versa. var (
// name and created doesn't matter. errNotificationNotFound = errors.New("requested notification is not found")
func createNotification(tx *sql.Tx, oldVulnerabilityID, newVulnerabilityID int) error { )
defer observeQueryTime("createNotification", "all", time.Now())
// Insert Notification. func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.VulnerabilityNotification) error {
oldVulnerabilityNullableID := sql.NullInt64{Int64: int64(oldVulnerabilityID), Valid: oldVulnerabilityID != 0} if len(notifications) == 0 {
newVulnerabilityNullableID := sql.NullInt64{Int64: int64(newVulnerabilityID), Valid: newVulnerabilityID != 0} return nil
_, err := tx.Exec(insertNotification, uuid.New(), oldVulnerabilityNullableID, newVulnerabilityNullableID) }
var (
newVulnIDMap = make(map[database.VulnerabilityID]sql.NullInt64)
oldVulnIDMap = make(map[database.VulnerabilityID]sql.NullInt64)
)
invalidCreationTime := time.Time{}
for _, noti := range notifications {
if noti.Name == "" {
return commonerr.NewBadRequestError("notification should not have empty name")
}
if noti.Created == invalidCreationTime {
return commonerr.NewBadRequestError("notification should not have empty created time")
}
if noti.New != nil {
key := database.VulnerabilityID{
Name: noti.New.Name,
Namespace: noti.New.Namespace.Name,
}
newVulnIDMap[key] = sql.NullInt64{}
}
if noti.Old != nil {
key := database.VulnerabilityID{
Name: noti.Old.Name,
Namespace: noti.Old.Namespace.Name,
}
oldVulnIDMap[key] = sql.NullInt64{}
}
}
var (
newVulnIDs = make([]database.VulnerabilityID, 0, len(newVulnIDMap))
oldVulnIDs = make([]database.VulnerabilityID, 0, len(oldVulnIDMap))
)
for vulnID := range newVulnIDMap {
newVulnIDs = append(newVulnIDs, vulnID)
}
for vulnID := range oldVulnIDMap {
oldVulnIDs = append(oldVulnIDs, vulnID)
}
ids, err := tx.findNotDeletedVulnerabilityIDs(newVulnIDs)
if err != nil { if err != nil {
tx.Rollback() return err
return handleError("insertNotification", err) }
for i, id := range ids {
if !id.Valid {
return handleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound)
}
newVulnIDMap[newVulnIDs[i]] = id
}
ids, err = tx.findLatestDeletedVulnerabilityIDs(oldVulnIDs)
if err != nil {
return err
}
for i, id := range ids {
if !id.Valid {
return handleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound)
}
oldVulnIDMap[oldVulnIDs[i]] = id
}
var (
newVulnID sql.NullInt64
oldVulnID sql.NullInt64
)
keys := make([]interface{}, len(notifications)*4)
for i, noti := range notifications {
if noti.New != nil {
newVulnID = newVulnIDMap[database.VulnerabilityID{
Name: noti.New.Name,
Namespace: noti.New.Namespace.Name,
}]
}
if noti.Old != nil {
oldVulnID = oldVulnIDMap[database.VulnerabilityID{
Name: noti.Old.Name,
Namespace: noti.Old.Namespace.Name,
}]
}
keys[4*i] = noti.Name
keys[4*i+1] = noti.Created
keys[4*i+2] = oldVulnID
keys[4*i+3] = newVulnID
}
// NOTE(Sida): The data is not sorted before inserting into database under
// the fact that there's only one updater running at a time. If there are
// multiple updaters, deadlock may happen.
_, err = tx.Exec(queryInsertNotifications(len(notifications)), keys...)
if err != nil {
return handleError("queryInsertNotifications", err)
} }
return nil return nil
} }
// Get one available notification name (!locked && !deleted && (!notified || notified_but_timed-out)). func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.NotificationHook, bool, error) {
// Does not fill new/old vuln. var (
func (pgSQL *pgSQL) GetAvailableNotification(renotifyInterval time.Duration) (database.VulnerabilityNotification, error) { notification database.NotificationHook
defer observeQueryTime("GetAvailableNotification", "all", time.Now()) created zero.Time
notified zero.Time
before := time.Now().Add(-renotifyInterval) deleted zero.Time
row := pgSQL.QueryRow(searchNotificationAvailable, before)
notification, err := pgSQL.scanNotification(row, false)
return notification, handleError("searchNotificationAvailable", err)
}
func (pgSQL *pgSQL) GetNotification(name string, limit int, page database.VulnerabilityNotificationPageNumber) (database.VulnerabilityNotification, database.VulnerabilityNotificationPageNumber, error) {
defer observeQueryTime("GetNotification", "all", time.Now())
// Get Notification.
notification, err := pgSQL.scanNotification(pgSQL.QueryRow(searchNotification, name), true)
if err != nil {
return notification, page, handleError("searchNotification", err)
}
// Load vulnerabilities' LayersIntroducingVulnerability.
page.OldVulnerability, err = pgSQL.loadLayerIntroducingVulnerability(
notification.OldVulnerability,
limit,
page.OldVulnerability,
) )
err := tx.QueryRow(searchNotificationAvailable, notifiedBefore).Scan(&notification.Name, &created, &notified, &deleted)
if err != nil { if err != nil {
return notification, page, err if err == sql.ErrNoRows {
} return notification, false, nil
page.NewVulnerability, err = pgSQL.loadLayerIntroducingVulnerability(
notification.NewVulnerability,
limit,
page.NewVulnerability,
)
if err != nil {
return notification, page, err
}
return notification, page, nil
}
func (pgSQL *pgSQL) scanNotification(row *sql.Row, hasVulns bool) (database.VulnerabilityNotification, error) {
var notification database.VulnerabilityNotification
var created zero.Time
var notified zero.Time
var deleted zero.Time
var oldVulnerabilityNullableID sql.NullInt64
var newVulnerabilityNullableID sql.NullInt64
// Scan notification.
if hasVulns {
err := row.Scan(
&notification.ID,
&notification.Name,
&created,
&notified,
&deleted,
&oldVulnerabilityNullableID,
&newVulnerabilityNullableID,
)
if err != nil {
return notification, err
}
} else {
err := row.Scan(&notification.ID, &notification.Name, &created, &notified, &deleted)
if err != nil {
return notification, err
} }
return notification, false, handleError("searchNotificationAvailable", err)
} }
notification.Created = created.Time notification.Created = created.Time
notification.Notified = notified.Time notification.Notified = notified.Time
notification.Deleted = deleted.Time notification.Deleted = deleted.Time
if hasVulns { return notification, true, nil
if oldVulnerabilityNullableID.Valid {
vulnerability, err := pgSQL.findVulnerabilityByIDWithDeleted(int(oldVulnerabilityNullableID.Int64))
if err != nil {
return notification, err
}
notification.OldVulnerability = &vulnerability
}
if newVulnerabilityNullableID.Valid {
vulnerability, err := pgSQL.findVulnerabilityByIDWithDeleted(int(newVulnerabilityNullableID.Int64))
if err != nil {
return notification, err
}
notification.NewVulnerability = &vulnerability
}
}
return notification, nil
} }
// Fills Vulnerability.LayersIntroducingVulnerability. func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentPage database.PageNumber) (database.PagedVulnerableAncestries, error) {
// limit -1: won't do anything vulnPage := database.PagedVulnerableAncestries{Limit: limit}
// limit 0: will just get the startID of the second page current := idPageNumber{0}
func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vulnerability, limit, startID int) (int, error) { if currentPage != "" {
tf := time.Now() var err error
current, err = decryptPage(currentPage, tx.paginationKey)
if vulnerability == nil {
return -1, nil
}
// A startID equals to -1 means that we reached the end already.
if startID == -1 || limit == -1 {
return -1, nil
}
// Create a transaction to disable hash joins as our experience shows that
// PostgreSQL plans in certain cases a sequential scan and a hash on
// Layer_diff_FeatureVersion for the condition `ldfv.layer_id >= $2 AND
// ldfv.modification = 'add'` before realizing a hash inner join with
// Vulnerability_Affects_FeatureVersion. By disabling explictly hash joins,
// we force PostgreSQL to perform a bitmap index scan with
// `ldfv.featureversion_id = fv.id` on Layer_diff_FeatureVersion, followed by
// a bitmap heap scan on `ldfv.layer_id >= $2 AND ldfv.modification = 'add'`,
// thus avoiding a sequential scan on the biggest database table and
// allowing a small nested loop join instead.
tx, err := pgSQL.Begin()
if err != nil { if err != nil {
return -1, handleError("searchNotificationLayerIntroducingVulnerability.Begin()", err) return vulnPage, err
} }
defer tx.Commit()
_, err = tx.Exec(disableHashJoin)
if err != nil {
log.WithError(err).Warning("searchNotificationLayerIntroducingVulnerability: could not disable hash join")
} }
// We do `defer observeQueryTime` here because we don't want to observe invalid calls. err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan(
defer observeQueryTime("loadLayerIntroducingVulnerability", "all", tf) &vulnPage.Name,
&vulnPage.Description,
// Query with limit + 1, the last item will be used to know the next starting ID. &vulnPage.Link,
rows, err := tx.Query(searchNotificationLayerIntroducingVulnerability, &vulnPage.Severity,
vulnerability.ID, startID, limit+1) &vulnPage.Metadata,
&vulnPage.Namespace.Name,
&vulnPage.Namespace.VersionFormat,
)
if err != nil { if err != nil {
return 0, handleError("searchNotificationLayerIntroducingVulnerability", err) return vulnPage, handleError("searchVulnerabilityByID", err)
}
// the last result is used for the next page's startID
rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, current.StartID, limit+1)
if err != nil {
return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
} }
defer rows.Close() defer rows.Close()
var layers []database.Layer ancestries := []affectedAncestry{}
for rows.Next() { for rows.Next() {
var layer database.Layer var ancestry affectedAncestry
err := rows.Scan(&ancestry.id, &ancestry.name)
if err := rows.Scan(&layer.ID, &layer.Name); err != nil { if err != nil {
return -1, handleError("searchNotificationLayerIntroducingVulnerability.Scan()", err) return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
}
ancestries = append(ancestries, ancestry)
} }
layers = append(layers, layer) lastIndex := 0
if len(ancestries)-1 < limit {
lastIndex = len(ancestries)
vulnPage.End = true
} else {
// Use the last ancestry's ID as the next PageNumber.
lastIndex = len(ancestries) - 1
vulnPage.Next, err = encryptPage(
idPageNumber{
ancestries[len(ancestries)-1].id,
}, tx.paginationKey)
if err != nil {
return vulnPage, err
} }
if err = rows.Err(); err != nil {
return -1, handleError("searchNotificationLayerIntroducingVulnerability.Rows()", err)
} }
size := limit vulnPage.Affected = map[int]string{}
if len(layers) < limit { for _, ancestry := range ancestries[0:lastIndex] {
size = len(layers) vulnPage.Affected[int(ancestry.id)] = ancestry.name
}
vulnerability.LayersIntroducingVulnerability = layers[:size]
nextID := -1
if len(layers) > limit {
nextID = layers[limit].ID
} }
return nextID, nil vulnPage.Current, err = encryptPage(current, tx.paginationKey)
if err != nil {
return vulnPage, err
}
return vulnPage, nil
} }
func (pgSQL *pgSQL) SetNotificationNotified(name string) error { func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPage database.PageNumber, newPage database.PageNumber) (
defer observeQueryTime("SetNotificationNotified", "all", time.Now()) database.VulnerabilityNotificationWithVulnerable, bool, error) {
var (
noti database.VulnerabilityNotificationWithVulnerable
oldVulnID sql.NullInt64
newVulnID sql.NullInt64
created zero.Time
notified zero.Time
deleted zero.Time
)
if _, err := pgSQL.Exec(updatedNotificationNotified, name); err != nil { if name == "" {
return noti, false, commonerr.NewBadRequestError("Empty notification name is not allowed")
}
noti.Name = name
err := tx.QueryRow(searchNotification, name).Scan(&created, &notified,
&deleted, &oldVulnID, &newVulnID)
if err != nil {
if err == sql.ErrNoRows {
return noti, false, nil
}
return noti, false, handleError("searchNotification", err)
}
if created.Valid {
noti.Created = created.Time
}
if notified.Valid {
noti.Notified = notified.Time
}
if deleted.Valid {
noti.Deleted = deleted.Time
}
if oldVulnID.Valid {
page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPage)
if err != nil {
return noti, false, err
}
noti.Old = &page
}
if newVulnID.Valid {
page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPage)
if err != nil {
return noti, false, err
}
noti.New = &page
}
return noti, true, nil
}
func (tx *pgSession) MarkNotificationNotified(name string) error {
if name == "" {
return commonerr.NewBadRequestError("Empty notification name is not allowed")
}
r, err := tx.Exec(updatedNotificationNotified, name)
if err != nil {
return handleError("updatedNotificationNotified", err) return handleError("updatedNotificationNotified", err)
} }
affected, err := r.RowsAffected()
if err != nil {
return handleError("updatedNotificationNotified", err)
}
if affected <= 0 {
return handleError("updatedNotificationNotified", errNotificationNotFound)
}
return nil return nil
} }
func (pgSQL *pgSQL) DeleteNotification(name string) error { func (tx *pgSession) DeleteNotification(name string) error {
defer observeQueryTime("DeleteNotification", "all", time.Now()) if name == "" {
return commonerr.NewBadRequestError("Empty notification name is not allowed")
}
result, err := pgSQL.Exec(removeNotification, name) result, err := tx.Exec(removeNotification, name)
if err != nil { if err != nil {
return handleError("removeNotification", err) return handleError("removeNotification", err)
} }
affected, err := result.RowsAffected() affected, err := result.RowsAffected()
if err != nil { if err != nil {
return handleError("removeNotification.RowsAffected()", err) return handleError("removeNotification", err)
} }
if affected <= 0 { if affected <= 0 {
return commonerr.ErrNotFound return handleError("removeNotification", commonerr.ErrNotFound)
} }
return nil return nil

View File

@ -21,211 +21,225 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/ext/versionfmt"
"github.com/coreos/clair/ext/versionfmt/dpkg"
"github.com/coreos/clair/pkg/commonerr"
) )
func TestNotification(t *testing.T) { func TestPagination(t *testing.T) {
datastore, err := openDatabaseForTest("Notification", false) datastore, tx := openSessionForTest(t, "Pagination", true)
if err != nil { defer closeTest(t, datastore, tx)
t.Error(err)
return
}
defer datastore.Close()
// Try to get a notification when there is none. ns := database.Namespace{
_, err = datastore.GetAvailableNotification(time.Second) Name: "debian:7",
assert.Equal(t, commonerr.ErrNotFound, err) VersionFormat: "dpkg",
// Create some data.
f1 := database.Feature{
Name: "TestNotificationFeature1",
Namespace: database.Namespace{
Name: "TestNotificationNamespace1",
VersionFormat: dpkg.ParserName,
},
} }
f2 := database.Feature{ vNew := database.Vulnerability{
Name: "TestNotificationFeature2", Namespace: ns,
Namespace: database.Namespace{ Name: "CVE-OPENSSL-1-DEB7",
Name: "TestNotificationNamespace1", Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
VersionFormat: dpkg.ParserName, Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7",
}, Severity: database.HighSeverity,
} }
l1 := database.Layer{ vOld := database.Vulnerability{
Name: "TestNotificationLayer1", Namespace: ns,
Features: []database.FeatureVersion{ Name: "CVE-NOPE",
{ Description: "A vulnerability affecting nothing",
Feature: f1, Severity: database.UnknownSeverity,
Version: "0.1",
},
},
} }
l2 := database.Layer{ noti, ok, err := tx.FindVulnerabilityNotification("test", 1, "", "")
Name: "TestNotificationLayer2", oldPage := database.PagedVulnerableAncestries{
Features: []database.FeatureVersion{ Vulnerability: vOld,
{ Limit: 1,
Feature: f1, Affected: make(map[int]string),
Version: "0.2", End: true,
},
},
} }
l3 := database.Layer{ newPage1 := database.PagedVulnerableAncestries{
Name: "TestNotificationLayer3", Vulnerability: vNew,
Features: []database.FeatureVersion{ Limit: 1,
{ Affected: map[int]string{3: "ancestry-3"},
Feature: f1, End: false,
Version: "0.3",
},
},
} }
l4 := database.Layer{ newPage2 := database.PagedVulnerableAncestries{
Name: "TestNotificationLayer4", Vulnerability: vNew,
Features: []database.FeatureVersion{ Limit: 1,
{ Affected: map[int]string{4: "ancestry-4"},
Feature: f2, Next: "",
Version: "0.1", End: true,
},
},
} }
if !assert.Nil(t, datastore.InsertLayer(l1)) || if assert.Nil(t, err) && assert.True(t, ok) {
!assert.Nil(t, datastore.InsertLayer(l2)) || assert.Equal(t, "test", noti.Name)
!assert.Nil(t, datastore.InsertLayer(l3)) || if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
!assert.Nil(t, datastore.InsertLayer(l4)) { oldPageNum, err := decryptPage(noti.Old.Current, tx.paginationKey)
return if !assert.Nil(t, err) {
assert.FailNow(t, "")
} }
// Insert a new vulnerability that is introduced by three layers. assert.Equal(t, int64(0), oldPageNum.StartID)
v1 := database.Vulnerability{ newPageNum, err := decryptPage(noti.New.Current, tx.paginationKey)
Name: "TestNotificationVulnerability1", if !assert.Nil(t, err) {
Namespace: f1.Namespace, assert.FailNow(t, "")
Description: "TestNotificationDescription1",
Link: "TestNotificationLink1",
Severity: "Unknown",
FixedIn: []database.FeatureVersion{
{
Feature: f1,
Version: "1.0",
},
},
} }
assert.Nil(t, datastore.insertVulnerability(v1, false, true)) newPageNextNum, err := decryptPage(noti.New.Next, tx.paginationKey)
if !assert.Nil(t, err) {
// Get the notification associated to the previously inserted vulnerability. assert.FailNow(t, "")
notification, err := datastore.GetAvailableNotification(time.Second)
if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) {
// Verify the renotify behaviour.
if assert.Nil(t, datastore.SetNotificationNotified(notification.Name)) {
_, err := datastore.GetAvailableNotification(time.Second)
assert.Equal(t, commonerr.ErrNotFound, err)
time.Sleep(50 * time.Millisecond)
notificationB, err := datastore.GetAvailableNotification(20 * time.Millisecond)
assert.Nil(t, err)
assert.Equal(t, notification.Name, notificationB.Name)
datastore.SetNotificationNotified(notification.Name)
} }
assert.Equal(t, int64(0), newPageNum.StartID)
assert.Equal(t, int64(4), newPageNextNum.StartID)
// Get notification. noti.Old.Current = ""
filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage) noti.New.Current = ""
if assert.Nil(t, err) { noti.New.Next = ""
assert.NotEqual(t, database.NoVulnerabilityNotificationPage, nextPage) assert.Equal(t, oldPage, *noti.Old)
assert.Nil(t, filledNotification.OldVulnerability) assert.Equal(t, newPage1, *noti.New)
if assert.NotNil(t, filledNotification.NewVulnerability) {
assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name)
assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 2)
} }
} }
// Get second page. page1, err := encryptPage(idPageNumber{0}, tx.paginationKey)
filledNotification, nextPage, err = datastore.GetNotification(notification.Name, 2, nextPage) if !assert.Nil(t, err) {
if assert.Nil(t, err) { assert.FailNow(t, "")
assert.Equal(t, database.NoVulnerabilityNotificationPage, nextPage)
assert.Nil(t, filledNotification.OldVulnerability)
if assert.NotNil(t, filledNotification.NewVulnerability) {
assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name)
assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1)
}
} }
// Delete notification. page2, err := encryptPage(idPageNumber{4}, tx.paginationKey)
assert.Nil(t, datastore.DeleteNotification(notification.Name)) if !assert.Nil(t, err) {
assert.FailNow(t, "")
_, err = datastore.GetAvailableNotification(time.Millisecond)
assert.Equal(t, commonerr.ErrNotFound, err)
} }
// Update a vulnerability and ensure that the old/new vulnerabilities are correct. noti, ok, err = tx.FindVulnerabilityNotification("test", 1, page1, page2)
v1b := v1 if assert.Nil(t, err) && assert.True(t, ok) {
v1b.Severity = database.HighSeverity assert.Equal(t, "test", noti.Name)
v1b.FixedIn = []database.FeatureVersion{ if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
{ oldCurrentPage, err := decryptPage(noti.Old.Current, tx.paginationKey)
Feature: f1, if !assert.Nil(t, err) {
Version: versionfmt.MinVersion, assert.FailNow(t, "")
},
{
Feature: f2,
Version: versionfmt.MaxVersion,
},
} }
if assert.Nil(t, datastore.insertVulnerability(v1b, false, true)) { newCurrentPage, err := decryptPage(noti.New.Current, tx.paginationKey)
notification, err = datastore.GetAvailableNotification(time.Second) if !assert.Nil(t, err) {
assert.Nil(t, err) assert.FailNow(t, "")
assert.NotEmpty(t, notification.Name)
if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) {
filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage)
if assert.Nil(t, err) {
if assert.NotNil(t, filledNotification.OldVulnerability) {
assert.Equal(t, v1.Name, filledNotification.OldVulnerability.Name)
assert.Equal(t, v1.Severity, filledNotification.OldVulnerability.Severity)
assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 2)
} }
if assert.NotNil(t, filledNotification.NewVulnerability) { assert.Equal(t, int64(0), oldCurrentPage.StartID)
assert.Equal(t, v1b.Name, filledNotification.NewVulnerability.Name) assert.Equal(t, int64(4), newCurrentPage.StartID)
assert.Equal(t, v1b.Severity, filledNotification.NewVulnerability.Severity) noti.Old.Current = ""
assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1) noti.New.Current = ""
} assert.Equal(t, oldPage, *noti.Old)
assert.Equal(t, newPage2, *noti.New)
assert.Equal(t, -1, nextPage.NewVulnerability)
}
assert.Nil(t, datastore.DeleteNotification(notification.Name))
}
}
// Delete a vulnerability and verify the notification.
if assert.Nil(t, datastore.DeleteVulnerability(v1b.Namespace.Name, v1b.Name)) {
notification, err = datastore.GetAvailableNotification(time.Second)
assert.Nil(t, err)
assert.NotEmpty(t, notification.Name)
if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) {
filledNotification, _, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage)
if assert.Nil(t, err) {
assert.Nil(t, filledNotification.NewVulnerability)
if assert.NotNil(t, filledNotification.OldVulnerability) {
assert.Equal(t, v1b.Name, filledNotification.OldVulnerability.Name)
assert.Equal(t, v1b.Severity, filledNotification.OldVulnerability.Severity)
assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 1)
}
}
assert.Nil(t, datastore.DeleteNotification(notification.Name))
} }
} }
} }
func TestInsertVulnerabilityNotifications(t *testing.T) {
datastore, tx := openSessionForTest(t, "InsertVulnerabilityNotifications", true)
n1 := database.VulnerabilityNotification{}
n3 := database.VulnerabilityNotification{
NotificationHook: database.NotificationHook{
Name: "random name",
Created: time.Now(),
},
Old: nil,
New: &database.Vulnerability{},
}
n4 := database.VulnerabilityNotification{
NotificationHook: database.NotificationHook{
Name: "random name",
Created: time.Now(),
},
Old: nil,
New: &database.Vulnerability{
Name: "CVE-OPENSSL-1-DEB7",
Namespace: database.Namespace{
Name: "debian:7",
VersionFormat: "dpkg",
},
},
}
// invalid case
err := tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n1})
assert.NotNil(t, err)
// invalid case: unknown vulnerability
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n3})
assert.NotNil(t, err)
// invalid case: duplicated input notification
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4, n4})
assert.NotNil(t, err)
tx = restartSession(t, datastore, tx, false)
// valid case
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4})
assert.Nil(t, err)
// invalid case: notification is already in database
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4})
assert.NotNil(t, err)
closeTest(t, datastore, tx)
}
func TestFindNewNotification(t *testing.T) {
datastore, tx := openSessionForTest(t, "FindNewNotification", true)
defer closeTest(t, datastore, tx)
noti, ok, err := tx.FindNewNotification(time.Now())
if assert.Nil(t, err) && assert.True(t, ok) {
assert.Equal(t, "test", noti.Name)
assert.Equal(t, time.Time{}, noti.Notified)
assert.Equal(t, time.Time{}, noti.Created)
assert.Equal(t, time.Time{}, noti.Deleted)
}
// can't find the notified
assert.Nil(t, tx.MarkNotificationNotified("test"))
// if the notified time is before
noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(10 * time.Second)))
assert.Nil(t, err)
assert.False(t, ok)
// can find the notified after a period of time
noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000)))
if assert.Nil(t, err) && assert.True(t, ok) {
assert.Equal(t, "test", noti.Name)
assert.NotEqual(t, time.Time{}, noti.Notified)
assert.Equal(t, time.Time{}, noti.Created)
assert.Equal(t, time.Time{}, noti.Deleted)
}
assert.Nil(t, tx.DeleteNotification("test"))
// can't find in any time
noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(1000)))
assert.Nil(t, err)
assert.False(t, ok)
noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000)))
assert.Nil(t, err)
assert.False(t, ok)
}
func TestMarkNotificationNotified(t *testing.T) {
datastore, tx := openSessionForTest(t, "MarkNotificationNotified", true)
defer closeTest(t, datastore, tx)
// invalid case: notification doesn't exist
assert.NotNil(t, tx.MarkNotificationNotified("non-existing"))
// valid case
assert.Nil(t, tx.MarkNotificationNotified("test"))
// valid case
assert.Nil(t, tx.MarkNotificationNotified("test"))
}
func TestDeleteNotification(t *testing.T) {
datastore, tx := openSessionForTest(t, "DeleteNotification", true)
defer closeTest(t, datastore, tx)
// invalid case: notification doesn't exist
assert.NotNil(t, tx.DeleteNotification("non-existing"))
// valid case
assert.Nil(t, tx.DeleteNotification("test"))
// invalid case: notification is already deleted
assert.NotNil(t, tx.DeleteNotification("test"))
}

View File

@ -31,6 +31,7 @@ import (
"github.com/remind101/migrate" "github.com/remind101/migrate"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/coreos/clair/api/token"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/migrations" "github.com/coreos/clair/database/pgsql/migrations"
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
@ -59,7 +60,7 @@ var (
promConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{ promConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "clair_pgsql_concurrent_lock_vafv_total", Name: "clair_pgsql_concurrent_lock_vafv_total",
Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_FeatureVersion lock.", Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_Feature lock.",
}) })
) )
@ -73,17 +74,65 @@ func init() {
database.Register("pgsql", openDatabase) database.Register("pgsql", openDatabase)
} }
type Queryer interface { // pgSessionCache is the session's cache, which holds the pgSQL's cache and the
Query(query string, args ...interface{}) (*sql.Rows, error) // individual session's cache. Only when session.Commit is called, all the
QueryRow(query string, args ...interface{}) *sql.Row // changes to pgSQL cache will be applied.
type pgSessionCache struct {
c *lru.ARCCache
} }
type pgSQL struct { type pgSQL struct {
*sql.DB *sql.DB
cache *lru.ARCCache cache *lru.ARCCache
config Config config Config
} }
type pgSession struct {
*sql.Tx
paginationKey string
}
type idPageNumber struct {
// StartID is an implementation detail for paginating by an ID required to
// be unique to every ancestry and always increasing.
//
// StartID is used to search for ancestry with ID >= StartID
StartID int64
}
func encryptPage(page idPageNumber, paginationKey string) (result database.PageNumber, err error) {
resultBytes, err := token.Marshal(page, paginationKey)
if err != nil {
return result, err
}
result = database.PageNumber(resultBytes)
return result, nil
}
func decryptPage(page database.PageNumber, paginationKey string) (result idPageNumber, err error) {
err = token.Unmarshal(string(page), paginationKey, &result)
return
}
// Begin initiates a transaction to database. The expected transaction isolation
// level in this implementation is "Read Committed".
func (pgSQL *pgSQL) Begin() (database.Session, error) {
tx, err := pgSQL.DB.Begin()
if err != nil {
return nil, err
}
return &pgSession{
Tx: tx,
paginationKey: pgSQL.config.PaginationKey,
}, nil
}
func (tx *pgSession) Commit() error {
return tx.Tx.Commit()
}
// Close closes the database and destroys if ManageDatabaseLifecycle has been specified in // Close closes the database and destroys if ManageDatabaseLifecycle has been specified in
// the configuration. // the configuration.
func (pgSQL *pgSQL) Close() { func (pgSQL *pgSQL) Close() {
@ -109,6 +158,7 @@ type Config struct {
ManageDatabaseLifecycle bool ManageDatabaseLifecycle bool
FixturePath string FixturePath string
PaginationKey string
} }
// openDatabase opens a PostgresSQL-backed Datastore using the given // openDatabase opens a PostgresSQL-backed Datastore using the given
@ -134,6 +184,10 @@ func openDatabase(registrableComponentConfig database.RegistrableComponentConfig
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
} }
if pg.config.PaginationKey == "" {
panic("pagination key should be given")
}
dbName, pgSourceURL, err := parseConnectionString(pg.config.Source) dbName, pgSourceURL, err := parseConnectionString(pg.config.Source)
if err != nil { if err != nil {
return nil, err return nil, err
@ -179,7 +233,7 @@ func openDatabase(registrableComponentConfig database.RegistrableComponentConfig
_, err = pg.DB.Exec(string(d)) _, err = pg.DB.Exec(string(d))
if err != nil { if err != nil {
pg.Close() pg.Close()
return nil, fmt.Errorf("pgsql: an error occured while importing fixtures: %v", err) return nil, fmt.Errorf("pgsql: an error occurred while importing fixtures: %v", err)
} }
} }
@ -217,7 +271,7 @@ func migrateDatabase(db *sql.DB) error {
err := migrate.NewPostgresMigrator(db).Exec(migrate.Up, migrations.Migrations...) err := migrate.NewPostgresMigrator(db).Exec(migrate.Up, migrations.Migrations...)
if err != nil { if err != nil {
return fmt.Errorf("pgsql: an error occured while running migrations: %v", err) return fmt.Errorf("pgsql: an error occurred while running migrations: %v", err)
} }
log.Info("database migration ran successfully") log.Info("database migration ran successfully")
@ -271,7 +325,8 @@ func dropDatabase(source, dbName string) error {
} }
// handleError logs an error with an extra description and masks the error if it's an SQL one. // handleError logs an error with an extra description and masks the error if it's an SQL one.
// This ensures we never return plain SQL errors and leak anything. // The function ensures we never return plain SQL errors and leak anything.
// The function should be used for every database query error.
func handleError(desc string, err error) error { func handleError(desc string, err error) error {
if err == nil { if err == nil {
return nil return nil
@ -297,6 +352,11 @@ func isErrUniqueViolation(err error) bool {
return ok && pqErr.Code == "23505" return ok && pqErr.Code == "23505"
} }
// observeQueryTime computes the time elapsed since `start` to represent the
// query time.
// 1. `query` is a pgSession function name.
// 2. `subquery` is a specific query or a batched query.
// 3. `start` is the time right before query is executed.
func observeQueryTime(query, subquery string, start time.Time) { func observeQueryTime(query, subquery string, start time.Time) {
promQueryDurationMilliseconds. promQueryDurationMilliseconds.
WithLabelValues(query, subquery). WithLabelValues(query, subquery).

View File

@ -15,27 +15,193 @@
package pgsql package pgsql
import ( import (
"database/sql"
"fmt" "fmt"
"io/ioutil"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
"testing"
fernet "github.com/fernet/fernet-go"
"github.com/pborman/uuid" "github.com/pborman/uuid"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
yaml "gopkg.in/yaml.v2"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
) )
func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) { var (
ds, err := openDatabase(generateTestConfig(testName, loadFixture)) withFixtureName, withoutFixtureName string
)
func genTemplateDatabase(name string, loadFixture bool) (sourceURL string, dbName string) {
config := generateTestConfig(name, loadFixture, false)
source := config.Options["source"].(string)
name, url, err := parseConnectionString(source)
if err != nil {
panic(err)
}
fixturePath := config.Options["fixturepath"].(string)
if err := createDatabase(url, name); err != nil {
panic(err)
}
// migration and fixture
db, err := sql.Open("postgres", source)
if err != nil {
panic(err)
}
// Verify database state.
if err := db.Ping(); err != nil {
panic(err)
}
// Run migrations.
if err := migrateDatabase(db); err != nil {
panic(err)
}
if loadFixture {
log.Info("pgsql: loading fixtures")
d, err := ioutil.ReadFile(fixturePath)
if err != nil {
panic(err)
}
_, err = db.Exec(string(d))
if err != nil {
panic(err)
}
}
db.Exec("UPDATE pg_database SET datistemplate=True WHERE datname=$1", name)
db.Close()
log.Info("Generated Template database ", name)
return url, name
}
func dropTemplateDatabase(url string, name string) {
db, err := sql.Open("postgres", url)
if err != nil {
panic(err)
}
if _, err := db.Exec("UPDATE pg_database SET datistemplate=False WHERE datname=$1", name); err != nil {
panic(err)
}
if err := db.Close(); err != nil {
panic(err)
}
if err := dropDatabase(url, name); err != nil {
panic(err)
}
}
func TestMain(m *testing.M) {
fURL, fName := genTemplateDatabase("fixture", true)
nfURL, nfName := genTemplateDatabase("nonfixture", false)
withFixtureName = fName
withoutFixtureName = nfName
m.Run()
dropTemplateDatabase(fURL, fName)
dropTemplateDatabase(nfURL, nfName)
}
func openCopiedDatabase(testConfig database.RegistrableComponentConfig, fixture bool) (database.Datastore, error) {
var fixtureName string
if fixture {
fixtureName = withFixtureName
} else {
fixtureName = withoutFixtureName
}
// copy the database into new database
var pg pgSQL
// Parse configuration.
pg.config = Config{
CacheSize: 16384,
}
bytes, err := yaml.Marshal(testConfig.Options)
if err != nil {
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
}
err = yaml.Unmarshal(bytes, &pg.config)
if err != nil {
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
}
dbName, pgSourceURL, err := parseConnectionString(pg.config.Source)
if err != nil { if err != nil {
return nil, err return nil, err
} }
datastore := ds.(*pgSQL)
// Create database.
if pg.config.ManageDatabaseLifecycle {
if err = copyDatabase(pgSourceURL, dbName, fixtureName); err != nil {
return nil, err
}
}
// Open database.
pg.DB, err = sql.Open("postgres", pg.config.Source)
fmt.Println("database", pg.config.Source)
if err != nil {
pg.Close()
return nil, fmt.Errorf("pgsql: could not open database: %v", err)
}
return &pg, nil
}
// copyDatabase creates a new database with
func copyDatabase(url, name string, templateName string) error {
// Open database.
db, err := sql.Open("postgres", url)
if err != nil {
return fmt.Errorf("pgsql: could not open 'postgres' database for creation: %v", err)
}
defer db.Close()
// Create database with copy
_, err = db.Exec("CREATE DATABASE " + name + " WITH TEMPLATE " + templateName)
if err != nil {
return fmt.Errorf("pgsql: could not create database: %v", err)
}
return nil
}
func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) {
var (
db database.Datastore
err error
testConfig = generateTestConfig(testName, loadFixture, true)
)
db, err = openCopiedDatabase(testConfig, loadFixture)
if err != nil {
return nil, err
}
datastore := db.(*pgSQL)
return datastore, nil return datastore, nil
} }
func generateTestConfig(testName string, loadFixture bool) database.RegistrableComponentConfig { func generateTestConfig(testName string, loadFixture bool, manageLife bool) database.RegistrableComponentConfig {
dbName := "test_" + strings.ToLower(testName) + "_" + strings.Replace(uuid.New(), "-", "_", -1) dbName := "test_" + strings.ToLower(testName) + "_" + strings.Replace(uuid.New(), "-", "_", -1)
var fixturePath string var fixturePath string
@ -49,12 +215,60 @@ func generateTestConfig(testName string, loadFixture bool) database.RegistrableC
source = fmt.Sprintf(sourceEnv, dbName) source = fmt.Sprintf(sourceEnv, dbName)
} }
var key fernet.Key
if err := key.Generate(); err != nil {
panic("failed to generate pagination key" + err.Error())
}
return database.RegistrableComponentConfig{ return database.RegistrableComponentConfig{
Options: map[string]interface{}{ Options: map[string]interface{}{
"source": source, "source": source,
"cachesize": 0, "cachesize": 0,
"managedatabaselifecycle": true, "managedatabaselifecycle": manageLife,
"fixturepath": fixturePath, "fixturepath": fixturePath,
"paginationkey": key.Encode(),
}, },
} }
} }
func closeTest(t *testing.T, store database.Datastore, session database.Session) {
err := session.Rollback()
if err != nil {
t.Error(err)
t.FailNow()
}
store.Close()
}
func openSessionForTest(t *testing.T, name string, loadFixture bool) (*pgSQL, *pgSession) {
store, err := openDatabaseForTest(name, loadFixture)
if err != nil {
t.Error(err)
t.FailNow()
}
tx, err := store.Begin()
if err != nil {
t.Error(err)
t.FailNow()
}
return store, tx.(*pgSession)
}
func restartSession(t *testing.T, datastore *pgSQL, tx *pgSession, commit bool) *pgSession {
var err error
if !commit {
err = tx.Rollback()
} else {
err = tx.Commit()
}
if assert.Nil(t, err) {
session, err := datastore.Begin()
if assert.Nil(t, err) {
return session.(*pgSession)
}
}
t.FailNow()
return nil
}

View File

@ -14,172 +14,146 @@
package pgsql package pgsql
import "strconv" import (
"fmt"
"strings"
"github.com/lib/pq"
)
const ( const (
lockVulnerabilityAffects = `LOCK Vulnerability_Affects_FeatureVersion IN SHARE ROW EXCLUSIVE MODE` lockVulnerabilityAffects = `LOCK vulnerability_affected_namespaced_feature IN SHARE ROW EXCLUSIVE MODE`
disableHashJoin = `SET LOCAL enable_hashjoin = off`
disableMergeJoin = `SET LOCAL enable_mergejoin = off`
// keyvalue.go // keyvalue.go
updateKeyValue = `UPDATE KeyValue SET value = $1 WHERE key = $2`
insertKeyValue = `INSERT INTO KeyValue(key, value) VALUES($1, $2)`
searchKeyValue = `SELECT value FROM KeyValue WHERE key = $1` searchKeyValue = `SELECT value FROM KeyValue WHERE key = $1`
upsertKeyValue = `
INSERT INTO KeyValue(key, value)
VALUES ($1, $2)
ON CONFLICT ON CONSTRAINT keyvalue_key_key
DO UPDATE SET key=$1, value=$2`
// namespace.go // namespace.go
soiNamespace = `
WITH new_namespace AS (
INSERT INTO Namespace(name, version_format)
SELECT CAST($1 AS VARCHAR), CAST($2 AS VARCHAR)
WHERE NOT EXISTS (SELECT name FROM Namespace WHERE name = $1)
RETURNING id
)
SELECT id FROM Namespace WHERE name = $1
UNION
SELECT id FROM new_namespace`
searchNamespace = `SELECT id FROM Namespace WHERE name = $1` searchNamespaceID = `SELECT id FROM Namespace WHERE name = $1 AND version_format = $2`
listNamespace = `SELECT id, name, version_format FROM Namespace`
// feature.go // feature.go
soiFeature = ` soiNamespacedFeature = `
WITH new_feature AS ( WITH new_feature_ns AS (
INSERT INTO Feature(name, namespace_id) INSERT INTO namespaced_feature(feature_id, namespace_id)
SELECT CAST($1 AS VARCHAR), CAST($2 AS INTEGER) SELECT CAST ($1 AS INTEGER), CAST ($2 AS INTEGER)
WHERE NOT EXISTS (SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2) WHERE NOT EXISTS ( SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2)
RETURNING id RETURNING id
) )
SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2 SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2
UNION UNION
SELECT id FROM new_feature` SELECT id FROM new_feature_ns`
searchFeatureVersion = ` searchPotentialAffectingVulneraibilities = `
SELECT id FROM FeatureVersion WHERE feature_id = $1 AND version = $2` SELECT nf.id, v.id, vaf.affected_version, vaf.id
FROM vulnerability_affected_feature AS vaf, vulnerability AS v,
soiFeatureVersion = ` namespaced_feature AS nf, feature AS f
WITH new_featureversion AS ( WHERE nf.id = ANY($1)
INSERT INTO FeatureVersion(feature_id, version) AND nf.feature_id = f.id
SELECT CAST($1 AS INTEGER), CAST($2 AS VARCHAR) AND nf.namespace_id = v.namespace_id
WHERE NOT EXISTS (SELECT id FROM FeatureVersion WHERE feature_id = $1 AND version = $2) AND vaf.feature_name = f.name
RETURNING id AND vaf.vulnerability_id = v.id
)
SELECT false, id FROM FeatureVersion WHERE feature_id = $1 AND version = $2
UNION
SELECT true, id FROM new_featureversion`
searchVulnerabilityFixedInFeature = `
SELECT id, vulnerability_id, version FROM Vulnerability_FixedIn_Feature
WHERE feature_id = $1`
insertVulnerabilityAffectsFeatureVersion = `
INSERT INTO Vulnerability_Affects_FeatureVersion(vulnerability_id, featureversion_id, fixedin_id)
VALUES($1, $2, $3)`
// layer.go
searchLayer = `
SELECT l.id, l.name, l.engineversion, p.id, p.name
FROM Layer l
LEFT JOIN Layer p ON l.parent_id = p.id
WHERE l.name = $1;`
searchLayerNamespace = `
SELECT n.id, n.name, n.version_format
FROM Namespace n
JOIN Layer_Namespace lns ON lns.namespace_id = n.id
WHERE lns.layer_id = $1`
searchLayerFeatureVersion = `
WITH RECURSIVE layer_tree(id, name, parent_id, depth, path, cycle) AS(
SELECT l.id, l.name, l.parent_id, 1, ARRAY[l.id], false
FROM Layer l
WHERE l.id = $1
UNION ALL
SELECT l.id, l.name, l.parent_id, lt.depth + 1, path || l.id, l.id = ANY(path)
FROM Layer l, layer_tree lt
WHERE l.id = lt.parent_id
)
SELECT ldf.featureversion_id, ldf.modification, fn.id, fn.name, fn.version_format, f.id, f.name, fv.id, fv.version, ltree.id, ltree.name
FROM Layer_diff_FeatureVersion ldf
JOIN (
SELECT row_number() over (ORDER BY depth DESC), id, name FROM layer_tree
) AS ltree (ordering, id, name) ON ldf.layer_id = ltree.id, FeatureVersion fv, Feature f, Namespace fn
WHERE ldf.featureversion_id = fv.id AND fv.feature_id = f.id AND f.namespace_id = fn.id
ORDER BY ltree.ordering`
searchFeatureVersionVulnerability = `
SELECT vafv.featureversion_id, v.id, v.name, v.description, v.link, v.severity, v.metadata,
vn.name, vn.version_format, vfif.version
FROM Vulnerability_Affects_FeatureVersion vafv, Vulnerability v,
Namespace vn, Vulnerability_FixedIn_Feature vfif
WHERE vafv.featureversion_id = ANY($1::integer[])
AND vfif.vulnerability_id = v.id
AND vafv.fixedin_id = vfif.id
AND v.namespace_id = vn.id
AND v.deleted_at IS NULL` AND v.deleted_at IS NULL`
insertLayer = ` searchNamespacedFeaturesVulnerabilities = `
INSERT INTO Layer(name, engineversion, parent_id, created_at) SELECT vanf.namespaced_feature_id, v.name, v.description, v.link,
VALUES($1, $2, $3, CURRENT_TIMESTAMP) v.severity, v.metadata, vaf.fixedin, n.name, n.version_format
RETURNING id` FROM vulnerability_affected_namespaced_feature AS vanf,
Vulnerability AS v,
vulnerability_affected_feature AS vaf,
namespace AS n
WHERE vanf.namespaced_feature_id = ANY($1)
AND vaf.id = vanf.added_by
AND v.id = vanf.vulnerability_id
AND n.id = v.namespace_id
AND v.deleted_at IS NULL`
insertLayerNamespace = `INSERT INTO Layer_Namespace(layer_id, namespace_id) VALUES($1, $2)` // layer.go
removeLayerNamespace = `DELETE FROM Layer_Namespace WHERE layer_id = $1` searchLayerIDs = `SELECT id, hash FROM layer WHERE hash = ANY($1);`
updateLayer = `UPDATE LAYER SET engineversion = $2 WHERE id = $1` searchLayerFeatures = `
SELECT feature.Name, feature.Version, feature.version_format
FROM feature, layer_feature
WHERE layer_feature.layer_id = $1
AND layer_feature.feature_id = feature.id`
removeLayerDiffFeatureVersion = ` searchLayerNamespaces = `
DELETE FROM Layer_diff_FeatureVersion SELECT namespace.Name, namespace.version_format
WHERE layer_id = $1` FROM namespace, layer_namespace
WHERE layer_namespace.layer_id = $1
AND layer_namespace.namespace_id = namespace.id`
insertLayerDiffFeatureVersion = ` searchLayer = `SELECT id FROM layer WHERE hash = $1`
INSERT INTO Layer_diff_FeatureVersion(layer_id, featureversion_id, modification) searchLayerDetectors = `SELECT detector FROM layer_detector WHERE layer_id = $1`
SELECT $1, fv.id, $2 searchLayerListers = `SELECT lister FROM layer_lister WHERE layer_id = $1`
FROM FeatureVersion fv
WHERE fv.id = ANY($3::integer[])`
removeLayer = `DELETE FROM Layer WHERE name = $1`
// lock.go // lock.go
insertLock = `INSERT INTO Lock(name, owner, until) VALUES($1, $2, $3)` soiLock = `INSERT INTO lock(name, owner, until) VALUES ($1, $2, $3)`
searchLock = `SELECT owner, until FROM Lock WHERE name = $1` searchLock = `SELECT owner, until FROM Lock WHERE name = $1`
updateLock = `UPDATE Lock SET until = $3 WHERE name = $1 AND owner = $2` updateLock = `UPDATE Lock SET until = $3 WHERE name = $1 AND owner = $2`
removeLock = `DELETE FROM Lock WHERE name = $1 AND owner = $2` removeLock = `DELETE FROM Lock WHERE name = $1 AND owner = $2`
removeLockExpired = `DELETE FROM LOCK WHERE until < CURRENT_TIMESTAMP` removeLockExpired = `DELETE FROM LOCK WHERE until < CURRENT_TIMESTAMP`
// vulnerability.go // vulnerability.go
searchVulnerabilityBase = ` searchVulnerability = `
SELECT v.id, v.name, n.id, n.name, n.version_format, v.description, v.link, v.severity, v.metadata SELECT v.id, v.description, v.link, v.severity, v.metadata, n.version_format
FROM Vulnerability v JOIN Namespace n ON v.namespace_id = n.id` FROM vulnerability AS v, namespace AS n
searchVulnerabilityForUpdate = ` FOR UPDATE OF v` WHERE v.namespace_id = n.id
searchVulnerabilityByNamespaceAndName = ` WHERE n.name = $1 AND v.name = $2 AND v.deleted_at IS NULL` AND v.name = $1
searchVulnerabilityByID = ` WHERE v.id = $1` AND n.name = $2
searchVulnerabilityByNamespace = ` WHERE n.name = $1 AND v.deleted_at IS NULL AND v.deleted_at IS NULL
AND v.id >= $2 `
ORDER BY v.id
LIMIT $3`
searchVulnerabilityFixedIn = ` insertVulnerabilityAffected = `
SELECT vfif.version, f.id, f.Name INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, fixedin)
FROM Vulnerability_FixedIn_Feature vfif JOIN Feature f ON vfif.feature_id = f.id VALUES ($1, $2, $3, $4)
WHERE vfif.vulnerability_id = $1` RETURNING ID
`
searchVulnerabilityAffected = `
SELECT vulnerability_id, feature_name, affected_version, fixedin
FROM vulnerability_affected_feature
WHERE vulnerability_id = ANY($1)
`
searchVulnerabilityByID = `
SELECT v.name, v.description, v.link, v.severity, v.metadata, n.name, n.version_format
FROM vulnerability AS v, namespace AS n
WHERE v.namespace_id = n.id
AND v.id = $1`
searchVulnerabilityPotentialAffected = `
WITH req AS (
SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, v.id AS vulnerability_id
FROM vulnerability_affected_feature AS vaf,
vulnerability AS v,
namespace AS n
WHERE vaf.vulnerability_id = ANY($1)
AND v.id = vaf.vulnerability_id
AND n.id = v.namespace_id
)
SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by
FROM feature AS f, namespaced_feature AS nf, req
WHERE f.name = req.name
AND nf.namespace_id = req.n_id
AND nf.feature_id = f.id`
insertVulnerabilityAffectedNamespacedFeature = `
INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by)
VALUES ($1, $2, $3)`
insertVulnerability = ` insertVulnerability = `
INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at) WITH ns AS (
VALUES($1, $2, $3, $4, $5, $6, CURRENT_TIMESTAMP) SELECT id FROM namespace WHERE name = $6 AND version_format = $7
RETURNING id`
soiVulnerabilityFixedInFeature = `
WITH new_fixedinfeature AS (
INSERT INTO Vulnerability_FixedIn_Feature(vulnerability_id, feature_id, version)
SELECT CAST($1 AS INTEGER), CAST($2 AS INTEGER), CAST($3 AS VARCHAR)
WHERE NOT EXISTS (SELECT id FROM Vulnerability_FixedIn_Feature WHERE vulnerability_id = $1 AND feature_id = $2)
RETURNING id
) )
SELECT false, id FROM Vulnerability_FixedIn_Feature WHERE vulnerability_id = $1 AND feature_id = $2 INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at)
UNION VALUES((SELECT id FROM ns), $1, $2, $3, $4, $5, CURRENT_TIMESTAMP)
SELECT true, id FROM new_fixedinfeature` RETURNING id`
searchFeatureVersionByFeature = `SELECT id, version FROM FeatureVersion WHERE feature_id = $1`
removeVulnerability = ` removeVulnerability = `
UPDATE Vulnerability UPDATE Vulnerability
@ -192,7 +166,7 @@ const (
// notification.go // notification.go
insertNotification = ` insertNotification = `
INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id) INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id)
VALUES($1, CURRENT_TIMESTAMP, $2, $3)` VALUES ($1, $2, $3, $4)`
updatedNotificationNotified = ` updatedNotificationNotified = `
UPDATE Vulnerability_Notification UPDATE Vulnerability_Notification
@ -202,10 +176,10 @@ const (
removeNotification = ` removeNotification = `
UPDATE Vulnerability_Notification UPDATE Vulnerability_Notification
SET deleted_at = CURRENT_TIMESTAMP SET deleted_at = CURRENT_TIMESTAMP
WHERE name = $1` WHERE name = $1 AND deleted_at IS NULL`
searchNotificationAvailable = ` searchNotificationAvailable = `
SELECT id, name, created_at, notified_at, deleted_at SELECT name, created_at, notified_at, deleted_at
FROM Vulnerability_Notification FROM Vulnerability_Notification
WHERE (notified_at IS NULL OR notified_at < $1) WHERE (notified_at IS NULL OR notified_at < $1)
AND deleted_at IS NULL AND deleted_at IS NULL
@ -214,43 +188,231 @@ const (
LIMIT 1` LIMIT 1`
searchNotification = ` searchNotification = `
SELECT id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id SELECT created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id
FROM Vulnerability_Notification FROM Vulnerability_Notification
WHERE name = $1` WHERE name = $1`
searchNotificationLayerIntroducingVulnerability = ` searchNotificationVulnerableAncestry = `
WITH LDFV AS ( SELECT DISTINCT ON (a.id)
SELECT DISTINCT ldfv.layer_id a.id, a.name
FROM Vulnerability_Affects_FeatureVersion vafv, FeatureVersion fv, Layer_diff_FeatureVersion ldfv FROM vulnerability_affected_namespaced_feature AS vanf,
WHERE ldfv.layer_id >= $2 ancestry AS a, ancestry_feature AS af
AND vafv.vulnerability_id = $1 WHERE vanf.vulnerability_id = $1
AND vafv.featureversion_id = fv.id AND a.id >= $2
AND ldfv.featureversion_id = fv.id AND a.id = af.ancestry_id
AND ldfv.modification = 'add' AND af.namespaced_feature_id = vanf.namespaced_feature_id
ORDER BY ldfv.layer_id ORDER BY a.id ASC
) LIMIT $3;`
SELECT l.id, l.name
FROM LDFV, Layer l
WHERE LDFV.layer_id = l.id
LIMIT $3`
// complex_test.go // ancestry.go
searchComplexTestFeatureVersionAffects = ` persistAncestryLister = `
SELECT v.name INSERT INTO ancestry_lister (ancestry_id, lister)
FROM FeatureVersion fv SELECT CAST ($1 AS INTEGER), CAST ($2 AS TEXT)
LEFT JOIN Vulnerability_Affects_FeatureVersion vaf ON fv.id = vaf.featureversion_id WHERE NOT EXISTS (SELECT id FROM ancestry_lister WHERE ancestry_id = $1 AND lister = $2) ON CONFLICT DO NOTHING`
JOIN Vulnerability v ON vaf.vulnerability_id = v.id
WHERE featureversion_id = $1` persistAncestryDetector = `
INSERT INTO ancestry_detector (ancestry_id, detector)
SELECT CAST ($1 AS INTEGER), CAST ($2 AS TEXT)
WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector = $2) ON CONFLICT DO NOTHING`
insertAncestry = `INSERT INTO ancestry (name) VALUES ($1) RETURNING id`
searchAncestryLayer = `
SELECT layer.hash
FROM layer, ancestry_layer
WHERE ancestry_layer.ancestry_id = $1
AND ancestry_layer.layer_id = layer.id
ORDER BY ancestry_layer.ancestry_index ASC`
searchAncestryFeatures = `
SELECT namespace.name, namespace.version_format, feature.name, feature.version
FROM namespace, feature, ancestry, namespaced_feature, ancestry_feature
WHERE ancestry.name = $1
AND ancestry.id = ancestry_feature.ancestry_id
AND ancestry_feature.namespaced_feature_id = namespaced_feature.id
AND namespaced_feature.feature_id = feature.id
AND namespaced_feature.namespace_id = namespace.id`
searchAncestry = `SELECT id FROM ancestry WHERE name = $1`
searchAncestryDetectors = `SELECT detector FROM ancestry_detector WHERE ancestry_id = $1`
searchAncestryListers = `SELECT lister FROM ancestry_lister WHERE ancestry_id = $1`
removeAncestry = `DELETE FROM ancestry WHERE name = $1`
insertAncestryLayer = `INSERT INTO ancestry_layer(ancestry_id, ancestry_index, layer_id) VALUES($1,$2,$3)`
insertAncestryFeature = `INSERT INTO ancestry_feature(ancestry_id, namespaced_feature_id) VALUES ($1, $2)`
) )
// buildInputArray constructs a PostgreSQL input array from the specified integers. // NOTE(Sida): Every search query can only have count less than postgres set
// Useful to use the `= ANY($1::integer[])` syntax that let us use a IN clause while using // stack depth. IN will be resolved to nested OR_s and the parser might exceed
// a single placeholder. // stack depth. TODO(Sida): Generate different queries for different count: if
func buildInputArray(ints []int) string { // count < 5120, use IN; for count > 5120 and < 65536, use temporary table; for
str := "{" // count > 65535, use is expected to split data into batches.
for i := 0; i < len(ints)-1; i++ { func querySearchLastDeletedVulnerabilityID(count int) string {
str = str + strconv.Itoa(ints[i]) + "," return fmt.Sprintf(`
} SELECT vid, vname, nname FROM (
str = str + strconv.Itoa(ints[len(ints)-1]) + "}" SELECT v.id AS vid, v.name AS vname, n.name AS nname,
return str row_number() OVER (
PARTITION by (v.name, n.name)
ORDER BY v.deleted_at DESC
) AS rownum
FROM vulnerability AS v, namespace AS n
WHERE v.namespace_id = n.id
AND (v.name, n.name) IN ( %s )
AND v.deleted_at IS NOT NULL
) tmp WHERE rownum <= 1`,
queryString(2, count))
}
func querySearchNotDeletedVulnerabilityID(count int) string {
return fmt.Sprintf(`
SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n
WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s)
AND v.deleted_at IS NULL`,
queryString(2, count))
}
func querySearchFeatureID(featureCount int) string {
return fmt.Sprintf(`
SELECT id, name, version, version_format
FROM Feature WHERE (name, version, version_format) IN (%s)`,
queryString(3, featureCount),
)
}
func querySearchNamespacedFeature(nsfCount int) string {
return fmt.Sprintf(`
SELECT nf.id, f.name, f.version, f.version_format, n.name
FROM namespaced_feature AS nf, feature AS f, namespace AS n
WHERE nf.feature_id = f.id
AND nf.namespace_id = n.id
AND n.version_format = f.version_format
AND (f.name, f.version, f.version_format, n.name) IN (%s)`,
queryString(4, nsfCount),
)
}
func querySearchNamespace(nsCount int) string {
return fmt.Sprintf(
`SELECT id, name, version_format
FROM namespace WHERE (name, version_format) IN (%s)`,
queryString(2, nsCount),
)
}
func queryInsert(count int, table string, columns ...string) string {
base := `INSERT INTO %s (%s) VALUES %s`
t := pq.QuoteIdentifier(table)
cols := make([]string, len(columns))
for i, c := range columns {
cols[i] = pq.QuoteIdentifier(c)
}
colsQuoted := strings.Join(cols, ",")
return fmt.Sprintf(base, t, colsQuoted, queryString(len(columns), count))
}
func queryPersist(count int, table, constraint string, columns ...string) string {
ct := ""
if constraint != "" {
ct = fmt.Sprintf("ON CONSTRAINT %s", constraint)
}
return fmt.Sprintf("%s ON CONFLICT %s DO NOTHING", queryInsert(count, table, columns...), ct)
}
func queryInsertNotifications(count int) string {
return queryInsert(count,
"vulnerability_notification",
"name",
"created_at",
"old_vulnerability_id",
"new_vulnerability_id",
)
}
func queryPersistFeature(count int) string {
return queryPersist(count,
"feature",
"feature_name_version_version_format_key",
"name",
"version",
"version_format")
}
func queryPersistLayerFeature(count int) string {
return queryPersist(count,
"layer_feature",
"layer_feature_layer_id_feature_id_key",
"layer_id",
"feature_id")
}
func queryPersistNamespace(count int) string {
return queryPersist(count,
"namespace",
"namespace_name_version_format_key",
"name",
"version_format")
}
func queryPersistLayerListers(count int) string {
return queryPersist(count,
"layer_lister",
"layer_lister_layer_id_lister_key",
"layer_id",
"lister")
}
func queryPersistLayerDetectors(count int) string {
return queryPersist(count,
"layer_detector",
"layer_detector_layer_id_detector_key",
"layer_id",
"detector")
}
func queryPersistLayerNamespace(count int) string {
return queryPersist(count,
"layer_namespace",
"layer_namespace_layer_id_namespace_id_key",
"layer_id",
"namespace_id")
}
// size of key and array should be both greater than 0
func queryString(keySize, arraySize int) string {
if arraySize <= 0 || keySize <= 0 {
panic("Bulk Query requires size of element tuple and number of elements to be greater than 0")
}
keys := make([]string, 0, arraySize)
for i := 0; i < arraySize; i++ {
key := make([]string, keySize)
for j := 0; j < keySize; j++ {
key[j] = fmt.Sprintf("$%d", i*keySize+j+1)
}
keys = append(keys, fmt.Sprintf("(%s)", strings.Join(key, ",")))
}
return strings.Join(keys, ",")
}
func queryPersistNamespacedFeature(count int) string {
return queryPersist(count, "namespaced_feature",
"namespaced_feature_namespace_id_feature_id_key",
"feature_id",
"namespace_id")
}
func queryPersistVulnerabilityAffectedNamespacedFeature(count int) string {
return queryPersist(count, "vulnerability_affected_namespaced_feature",
"vulnerability_affected_namesp_vulnerability_id_namespaced_f_key",
"vulnerability_id",
"namespaced_feature_id",
"added_by")
}
func queryPersistLayer(count int) string {
return queryPersist(count, "layer", "", "hash")
}
func queryInvalidateVulnerabilityCache(count int) string {
return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature
WHERE vulnerability_id = (%s)`,
queryString(1, count))
} }

View File

@ -1,73 +1,117 @@
-- Copyright 2015 clair authors
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
INSERT INTO namespace (id, name, version_format) VALUES INSERT INTO namespace (id, name, version_format) VALUES
(1, 'debian:7', 'dpkg'), (1, 'debian:7', 'dpkg'),
(2, 'debian:8', 'dpkg'); (2, 'debian:8', 'dpkg'),
(3, 'fake:1.0', 'rpm');
INSERT INTO feature (id, namespace_id, name) VALUES INSERT INTO feature (id, name, version, version_format) VALUES
(1, 1, 'wechat'), (1, 'wechat', '0.5', 'dpkg'),
(2, 1, 'openssl'), (2, 'openssl', '1.0', 'dpkg'),
(4, 1, 'libssl'), (3, 'openssl', '2.0', 'dpkg'),
(3, 2, 'openssl'); (4, 'fake', '2.0', 'rpm');
INSERT INTO featureversion (id, feature_id, version) VALUES INSERT INTO layer (id, hash) VALUES
(1, 1, '0.5'), (1, 'layer-0'), -- blank
(2, 2, '1.0'), (2, 'layer-1'), -- debian:7; wechat 0.5, openssl 1.0
(3, 2, '2.0'), (3, 'layer-2'), -- debian:7; wechat 0.5, openssl 2.0
(4, 3, '1.0'); (4, 'layer-3a'),-- debian:7;
(5, 'layer-3b'),-- debian:8; wechat 0.5, openssl 1.0
(6, 'layer-4'); -- debian:7, fake:1.0; openssl 2.0 (debian), fake 2.0 (fake)
INSERT INTO layer (id, name, engineversion, parent_id) VALUES INSERT INTO layer_namespace(id, layer_id, namespace_id) VALUES
(1, 'layer-0', 1, NULL),
(2, 'layer-1', 1, 1),
(3, 'layer-2', 1, 2),
(4, 'layer-3a', 1, 3),
(5, 'layer-3b', 1, 3);
INSERT INTO layer_namespace (id, layer_id, namespace_id) VALUES
(1, 2, 1), (1, 2, 1),
(2, 3, 1), (2, 3, 1),
(3, 4, 1), (3, 4, 1),
(4, 5, 2), (4, 5, 2),
(5, 5, 1); (5, 6, 1),
(6, 6, 3);
INSERT INTO layer_diff_featureversion (id, layer_id, featureversion_id, modification) VALUES INSERT INTO layer_feature(id, layer_id, feature_id) VALUES
(1, 2, 1, 'add'), (1, 2, 1),
(2, 2, 2, 'add'), (2, 2, 2),
(3, 3, 2, 'del'), -- layer-2: Update Debian:7 OpenSSL 1.0 -> 2.0 (3, 3, 1),
(4, 3, 3, 'add'), -- ^ (4, 3, 3),
(5, 5, 3, 'del'), -- layer-3b: Delete Debian:7 OpenSSL 2.0 (5, 5, 1),
(6, 5, 4, 'add'); -- layer-3b: Add Debian:8 OpenSSL 1.0 (6, 5, 2),
(7, 6, 4),
(8, 6, 3);
INSERT INTO layer_lister(id, layer_id, lister) VALUES
(1, 1, 'dpkg'),
(2, 2, 'dpkg'),
(3, 3, 'dpkg'),
(4, 4, 'dpkg'),
(5, 5, 'dpkg'),
(6, 6, 'dpkg'),
(7, 6, 'rpm');
INSERT INTO layer_detector(id, layer_id, detector) VALUES
(1, 1, 'os-release'),
(2, 2, 'os-release'),
(3, 3, 'os-release'),
(4, 4, 'os-release'),
(5, 5, 'os-release'),
(6, 6, 'os-release'),
(7, 6, 'apt-sources');
INSERT INTO ancestry (id, name) VALUES
(1, 'ancestry-1'), -- layer-0, layer-1, layer-2, layer-3a
(2, 'ancestry-2'), -- layer-0, layer-1, layer-2, layer-3b
(3, 'ancestry-3'), -- empty; just for testing the vulnerable ancestry
(4, 'ancestry-4'); -- empty; just for testing the vulnerable ancestry
INSERT INTO ancestry_lister (id, ancestry_id, lister) VALUES
(1, 1, 'dpkg'),
(2, 2, 'dpkg');
INSERT INTO ancestry_detector (id, ancestry_id, detector) VALUES
(1, 1, 'os-release'),
(2, 2, 'os-release');
INSERT INTO ancestry_layer (id, ancestry_id, layer_id, ancestry_index) VALUES
(1, 1, 1, 0),(2, 1, 2, 1),(3, 1, 3, 2),(4, 1, 4, 3),
(5, 2, 1, 0),(6, 2, 2, 1),(7, 2, 3, 2),(8, 2, 5, 3);
INSERT INTO namespaced_feature(id, feature_id, namespace_id) VALUES
(1, 1, 1), -- wechat 0.5, debian:7
(2, 2, 1), -- openssl 1.0, debian:7
(3, 2, 2), -- openssl 1.0, debian:8
(4, 3, 1); -- openssl 2.0, debian:7
INSERT INTO ancestry_feature (id, ancestry_id, namespaced_feature_id) VALUES
(1, 1, 1), (2, 1, 4),
(3, 2, 1), (4, 2, 3),
(5, 3, 2), (6, 4, 2); -- assume that ancestry-3 and ancestry-4 are vulnerable.
INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES
(1, 1, 'CVE-OPENSSL-1-DEB7', 'A vulnerability affecting OpenSSL < 2.0 on Debian 7.0', 'http://google.com/#q=CVE-OPENSSL-1-DEB7', 'High'), (1, 1, 'CVE-OPENSSL-1-DEB7', 'A vulnerability affecting OpenSSL < 2.0 on Debian 7.0', 'http://google.com/#q=CVE-OPENSSL-1-DEB7', 'High'),
(2, 1, 'CVE-NOPE', 'A vulnerability affecting nothing', '', 'Unknown'); (2, 1, 'CVE-NOPE', 'A vulnerability affecting nothing', '', 'Unknown');
INSERT INTO vulnerability_fixedin_feature (id, vulnerability_id, feature_id, version) VALUES INSERT INTO vulnerability (id, namespace_id, name, description, link, severity, deleted_at) VALUES
(1, 1, 2, '2.0'), (3, 1, 'CVE-DELETED', '', '', 'Unknown', '2017-08-08 17:49:31.668483');
(2, 1, 4, '1.9-abc');
INSERT INTO vulnerability_affects_featureversion (id, vulnerability_id, featureversion_id, fixedin_id) VALUES INSERT INTO vulnerability_affected_feature(id, vulnerability_id, feature_name, affected_version, fixedin) VALUES
(1, 1, 2, 1); -- CVE-OPENSSL-1-DEB7 affects Debian:7 OpenSSL 1.0 (1, 1, 'openssl', '2.0', '2.0'),
(2, 1, 'libssl', '1.9-abc', '1.9-abc');
INSERT INTO vulnerability_affected_namespaced_feature(id, vulnerability_id, namespaced_feature_id, added_by) VALUES
(1, 1, 2, 1);
INSERT INTO vulnerability_notification(id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id) VALUES
(1, 'test', NULL, NULL, NULL, 2, 1); -- 'CVE-NOPE' -> 'CVE-OPENSSL-1-DEB7'
SELECT pg_catalog.setval(pg_get_serial_sequence('namespace', 'id'), (SELECT MAX(id) FROM namespace)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('namespace', 'id'), (SELECT MAX(id) FROM namespace)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry', 'id'), (SELECT MAX(id) FROM ancestry)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_layer', 'id'), (SELECT MAX(id) FROM ancestry_layer)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_feature', 'id'), (SELECT MAX(id) FROM ancestry_feature)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_detector', 'id'), (SELECT MAX(id) FROM ancestry_detector)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_lister', 'id'), (SELECT MAX(id) FROM ancestry_lister)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('feature', 'id'), (SELECT MAX(id) FROM feature)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('feature', 'id'), (SELECT MAX(id) FROM feature)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('featureversion', 'id'), (SELECT MAX(id) FROM featureversion)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('namespaced_feature', 'id'), (SELECT MAX(id) FROM namespaced_feature)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('layer', 'id'), (SELECT MAX(id) FROM layer)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('layer', 'id'), (SELECT MAX(id) FROM layer)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('layer_namespace', 'id'), (SELECT MAX(id) FROM layer_namespace)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('layer_namespace', 'id'), (SELECT MAX(id) FROM layer_namespace)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('layer_diff_featureversion', 'id'), (SELECT MAX(id) FROM layer_diff_featureversion)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('layer_detector', 'id'), (SELECT MAX(id) FROM layer_detector)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability', 'id'), (SELECT MAX(id) FROM vulnerability)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('layer_lister', 'id'), (SELECT MAX(id) FROM layer_lister)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_fixedin_feature', 'id'), (SELECT MAX(id) FROM vulnerability_fixedin_feature)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affects_featureversion', 'id'), (SELECT MAX(id) FROM vulnerability_affects_featureversion)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability', 'id'), (SELECT MAX(id) FROM vulnerability)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability', 'id'), (SELECT MAX(id) FROM vulnerability)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_feature)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_namespaced_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_namespaced_feature)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_notification', 'id'), (SELECT MAX(id) FROM vulnerability_notification)+1);

View File

@ -17,352 +17,207 @@ package pgsql
import ( import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"reflect" "errors"
"time" "time"
"github.com/guregu/null/zero" "github.com/lib/pq"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/ext/versionfmt"
"github.com/coreos/clair/pkg/commonerr"
) )
// compareStringLists returns the strings that are present in X but not in Y. var (
func compareStringLists(X, Y []string) []string { errVulnerabilityNotFound = errors.New("vulnerability is not in database")
m := make(map[string]bool) )
for _, y := range Y { type affectedAncestry struct {
m[y] = true name string
} id int64
diff := []string{}
for _, x := range X {
if m[x] {
continue
}
diff = append(diff, x)
m[x] = true
}
return diff
} }
func compareStringListsInBoth(X, Y []string) []string { type affectRelation struct {
m := make(map[string]struct{}) vulnerabilityID int64
namespacedFeatureID int64
for _, y := range Y { addedBy int64
m[y] = struct{}{}
}
diff := []string{}
for _, x := range X {
if _, e := m[x]; e {
diff = append(diff, x)
delete(m, x)
}
}
return diff
} }
func (pgSQL *pgSQL) ListVulnerabilities(namespaceName string, limit int, startID int) ([]database.Vulnerability, int, error) { type affectedFeatureRows struct {
defer observeQueryTime("listVulnerabilities", "all", time.Now()) rows map[int64]database.AffectedFeature
// Query Namespace.
var id int
err := pgSQL.QueryRow(searchNamespace, namespaceName).Scan(&id)
if err != nil {
return nil, -1, handleError("searchNamespace", err)
} else if id == 0 {
return nil, -1, commonerr.ErrNotFound
}
// Query.
query := searchVulnerabilityBase + searchVulnerabilityByNamespace
rows, err := pgSQL.Query(query, namespaceName, startID, limit+1)
if err != nil {
return nil, -1, handleError("searchVulnerabilityByNamespace", err)
}
defer rows.Close()
var vulns []database.Vulnerability
nextID := -1
size := 0
// Scan query.
for rows.Next() {
var vulnerability database.Vulnerability
err := rows.Scan(
&vulnerability.ID,
&vulnerability.Name,
&vulnerability.Namespace.ID,
&vulnerability.Namespace.Name,
&vulnerability.Namespace.VersionFormat,
&vulnerability.Description,
&vulnerability.Link,
&vulnerability.Severity,
&vulnerability.Metadata,
)
if err != nil {
return nil, -1, handleError("searchVulnerabilityByNamespace.Scan()", err)
}
size++
if size > limit {
nextID = vulnerability.ID
} else {
vulns = append(vulns, vulnerability)
}
}
if err := rows.Err(); err != nil {
return nil, -1, handleError("searchVulnerabilityByNamespace.Rows()", err)
}
return vulns, nextID, nil
} }
func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vulnerability, error) { func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) {
return findVulnerability(pgSQL, namespaceName, name, false) defer observeQueryTime("findVulnerabilities", "", time.Now())
} resultVuln := make([]database.NullableVulnerability, len(vulnerabilities))
vulnIDMap := map[int64][]*database.NullableVulnerability{}
func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bool) (database.Vulnerability, error) {
defer observeQueryTime("findVulnerability", "all", time.Now())
queryName := "searchVulnerabilityBase+searchVulnerabilityByNamespaceAndName"
query := searchVulnerabilityBase + searchVulnerabilityByNamespaceAndName
if forUpdate {
queryName = queryName + "+searchVulnerabilityForUpdate"
query = query + searchVulnerabilityForUpdate
}
return scanVulnerability(queryer, queryName, queryer.QueryRow(query, namespaceName, name))
}
func (pgSQL *pgSQL) findVulnerabilityByIDWithDeleted(id int) (database.Vulnerability, error) {
defer observeQueryTime("findVulnerabilityByIDWithDeleted", "all", time.Now())
queryName := "searchVulnerabilityBase+searchVulnerabilityByID"
query := searchVulnerabilityBase + searchVulnerabilityByID
return scanVulnerability(pgSQL, queryName, pgSQL.QueryRow(query, id))
}
func scanVulnerability(queryer Queryer, queryName string, vulnerabilityRow *sql.Row) (database.Vulnerability, error) {
var vulnerability database.Vulnerability
err := vulnerabilityRow.Scan(
&vulnerability.ID,
&vulnerability.Name,
&vulnerability.Namespace.ID,
&vulnerability.Namespace.Name,
&vulnerability.Namespace.VersionFormat,
&vulnerability.Description,
&vulnerability.Link,
&vulnerability.Severity,
&vulnerability.Metadata,
)
//TODO(Sida): Change to bulk search.
stmt, err := tx.Prepare(searchVulnerability)
if err != nil { if err != nil {
return vulnerability, handleError(queryName+".Scan()", err) return nil, err
} }
if vulnerability.ID == 0 { // load vulnerabilities
return vulnerability, commonerr.ErrNotFound for i, key := range vulnerabilities {
} var (
id sql.NullInt64
// Query the FixedIn FeatureVersion now. vuln = database.NullableVulnerability{
rows, err := queryer.Query(searchVulnerabilityFixedIn, vulnerability.ID) VulnerabilityWithAffected: database.VulnerabilityWithAffected{
if err != nil { Vulnerability: database.Vulnerability{
return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err) Name: key.Name,
} Namespace: database.Namespace{
defer rows.Close() Name: key.Namespace,
},
for rows.Next() { },
var featureVersionID zero.Int
var featureVersionVersion zero.String
var featureVersionFeatureName zero.String
err := rows.Scan(
&featureVersionVersion,
&featureVersionID,
&featureVersionFeatureName,
)
if err != nil {
return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err)
}
if !featureVersionID.IsZero() {
// Note that the ID we fill in featureVersion is actually a Feature ID, and not
// a FeatureVersion ID.
featureVersion := database.FeatureVersion{
Model: database.Model{ID: int(featureVersionID.Int64)},
Feature: database.Feature{
Model: database.Model{ID: int(featureVersionID.Int64)},
Namespace: vulnerability.Namespace,
Name: featureVersionFeatureName.String,
}, },
Version: featureVersionVersion.String,
} }
vulnerability.FixedIn = append(vulnerability.FixedIn, featureVersion) )
err := stmt.QueryRow(key.Name, key.Namespace).Scan(
&id,
&vuln.Description,
&vuln.Link,
&vuln.Severity,
&vuln.Metadata,
&vuln.Namespace.VersionFormat,
)
if err != nil && err != sql.ErrNoRows {
stmt.Close()
return nil, handleError("searchVulnerability", err)
}
vuln.Valid = id.Valid
resultVuln[i] = vuln
if id.Valid {
vulnIDMap[id.Int64] = append(vulnIDMap[id.Int64], &resultVuln[i])
} }
} }
if err := rows.Err(); err != nil { if err := stmt.Close(); err != nil {
return vulnerability, handleError("searchVulnerabilityFixedIn.Rows()", err) return nil, handleError("searchVulnerability", err)
} }
return vulnerability, nil toQuery := make([]int64, 0, len(vulnIDMap))
for id := range vulnIDMap {
toQuery = append(toQuery, id)
}
// load vulnerability affected features
rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery))
if err != nil {
return nil, handleError("searchVulnerabilityAffected", err)
}
for rows.Next() {
var (
id int64
f database.AffectedFeature
)
err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FixedInVersion)
if err != nil {
return nil, handleError("searchVulnerabilityAffected", err)
}
for _, vuln := range vulnIDMap[id] {
f.Namespace = vuln.Namespace
vuln.Affected = append(vuln.Affected, f)
}
}
return resultVuln, nil
} }
// FixedIn.Namespace are not necessary, they are overwritten by the vuln. func (tx *pgSession) InsertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) error {
// By setting the fixed version to minVersion, we can say that the vuln does'nt affect anymore. defer observeQueryTime("insertVulnerabilities", "all", time.Now())
func (pgSQL *pgSQL) InsertVulnerabilities(vulnerabilities []database.Vulnerability, generateNotifications bool) error { // bulk insert vulnerabilities
for _, vulnerability := range vulnerabilities { vulnIDs, err := tx.insertVulnerabilities(vulnerabilities)
err := pgSQL.insertVulnerability(vulnerability, false, generateNotifications)
if err != nil { if err != nil {
return err return err
} }
// bulk insert vulnerability affected features
vulnFeatureMap, err := tx.insertVulnerabilityAffected(vulnIDs, vulnerabilities)
if err != nil {
return err
} }
return nil
return tx.cacheVulnerabiltyAffectedNamespacedFeature(vulnFeatureMap)
} }
func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability, onlyFixedIn, generateNotification bool) error { // insertVulnerabilityAffected inserts a set of vulnerability affected features for each vulnerability provided.
tf := time.Now() //
// i_th vulnerabilityIDs corresponds to i_th vulnerabilities provided.
func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) {
var (
vulnFeature = map[int64]affectedFeatureRows{}
affectedID int64
)
// Verify parameters //TODO(Sida): Change to bulk insert.
if vulnerability.Name == "" || vulnerability.Namespace.Name == "" { stmt, err := tx.Prepare(insertVulnerabilityAffected)
return commonerr.NewBadRequestError("insertVulnerability needs at least the Name and the Namespace")
}
for i := 0; i < len(vulnerability.FixedIn); i++ {
fifv := &vulnerability.FixedIn[i]
if fifv.Feature.Namespace.Name == "" {
// As there is no Namespace on that FixedIn FeatureVersion, set it to the Vulnerability's
// Namespace.
fifv.Feature.Namespace = vulnerability.Namespace
} else if fifv.Feature.Namespace.Name != vulnerability.Namespace.Name {
msg := "could not insert an invalid vulnerability that contains FixedIn FeatureVersion that are not in the same namespace as the Vulnerability"
log.Warning(msg)
return commonerr.NewBadRequestError(msg)
}
}
// We do `defer observeQueryTime` here because we don't want to observe invalid vulnerabilities.
defer observeQueryTime("insertVulnerability", "all", tf)
// Begin transaction.
tx, err := pgSQL.Begin()
if err != nil { if err != nil {
tx.Rollback() return nil, handleError("insertVulnerabilityAffected", err)
return handleError("insertVulnerability.Begin()", err)
} }
// Find existing vulnerability and its Vulnerability_FixedIn_Features (for update). defer stmt.Close()
existingVulnerability, err := findVulnerability(tx, vulnerability.Namespace.Name, vulnerability.Name, true) for i, vuln := range vulnerabilities {
if err != nil && err != commonerr.ErrNotFound { // affected feature row ID -> affected feature
tx.Rollback() affectedFeatures := map[int64]database.AffectedFeature{}
return err for _, f := range vuln.Affected {
} err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, f.FixedInVersion).Scan(&affectedID)
if onlyFixedIn {
// Because this call tries to update FixedIn FeatureVersion, import all other data from the
// existing one.
if existingVulnerability.ID == 0 {
return commonerr.ErrNotFound
}
fixedIn := vulnerability.FixedIn
vulnerability = existingVulnerability
vulnerability.FixedIn = fixedIn
}
if existingVulnerability.ID != 0 {
updateMetadata := vulnerability.Description != existingVulnerability.Description ||
vulnerability.Link != existingVulnerability.Link ||
vulnerability.Severity != existingVulnerability.Severity ||
!reflect.DeepEqual(castMetadata(vulnerability.Metadata), existingVulnerability.Metadata)
// Construct the entire list of FixedIn FeatureVersion, by using the
// the FixedIn list of the old vulnerability.
//
// TODO(Quentin-M): We could use !updateFixedIn to just copy FixedIn/Affects rows from the
// existing vulnerability in order to make metadata updates much faster.
var updateFixedIn bool
vulnerability.FixedIn, updateFixedIn = applyFixedInDiff(existingVulnerability.FixedIn, vulnerability.FixedIn)
if !updateMetadata && !updateFixedIn {
tx.Commit()
return nil
}
// Mark the old vulnerability as non latest.
_, err = tx.Exec(removeVulnerability, vulnerability.Namespace.Name, vulnerability.Name)
if err != nil { if err != nil {
tx.Rollback() return nil, handleError("insertVulnerabilityAffected", err)
return handleError("removeVulnerability", err)
} }
} else { affectedFeatures[affectedID] = f
// The vulnerability is new, we don't want to have any
// versionfmt.MinVersion as they are only used for diffing existing
// vulnerabilities.
var fixedIn []database.FeatureVersion
for _, fv := range vulnerability.FixedIn {
if fv.Version != versionfmt.MinVersion {
fixedIn = append(fixedIn, fv)
} }
} vulnFeature[vulnerabilityIDs[i]] = affectedFeatureRows{rows: affectedFeatures}
vulnerability.FixedIn = fixedIn
} }
// Find or insert Vulnerability's Namespace. return vulnFeature, nil
namespaceID, err := pgSQL.insertNamespace(vulnerability.Namespace) }
// insertVulnerabilities inserts a set of unique vulnerabilities into database,
// under the assumption that all vulnerabilities are valid.
func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) {
var (
vulnID int64
vulnIDs = make([]int64, 0, len(vulnerabilities))
vulnMap = map[database.VulnerabilityID]struct{}{}
)
for _, v := range vulnerabilities {
key := database.VulnerabilityID{
Name: v.Name,
Namespace: v.Namespace.Name,
}
// Ensure uniqueness of vulnerability IDs
if _, ok := vulnMap[key]; ok {
return nil, errors.New("inserting duplicated vulnerabilities is not allowed")
}
vulnMap[key] = struct{}{}
}
//TODO(Sida): Change to bulk insert.
stmt, err := tx.Prepare(insertVulnerability)
if err != nil { if err != nil {
return err return nil, handleError("insertVulnerability", err)
} }
// Insert vulnerability. defer stmt.Close()
err = tx.QueryRow( for _, vuln := range vulnerabilities {
insertVulnerability, err := stmt.QueryRow(vuln.Name, vuln.Description,
namespaceID, vuln.Link, &vuln.Severity, &vuln.Metadata,
vulnerability.Name, vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID)
vulnerability.Description,
vulnerability.Link,
&vulnerability.Severity,
&vulnerability.Metadata,
).Scan(&vulnerability.ID)
if err != nil { if err != nil {
tx.Rollback() return nil, handleError("insertVulnerability", err)
return handleError("insertVulnerability", err)
} }
// Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now. vulnIDs = append(vulnIDs, vulnID)
err = pgSQL.insertVulnerabilityFixedInFeatureVersions(tx, vulnerability.ID, vulnerability.FixedIn)
if err != nil {
tx.Rollback()
return err
} }
// Create a notification. return vulnIDs, nil
if generateNotification {
err = createNotification(tx, existingVulnerability.ID, vulnerability.ID)
if err != nil {
return err
}
}
// Commit transaction.
err = tx.Commit()
if err != nil {
tx.Rollback()
return handleError("insertVulnerability.Commit()", err)
}
return nil
} }
// castMetadata marshals the given database.MetadataMap and unmarshals it again to make sure that // castMetadata marshals the given database.MetadataMap and unmarshals it again to make sure that
@ -376,241 +231,208 @@ func castMetadata(m database.MetadataMap) database.MetadataMap {
return c return c
} }
// applyFixedInDiff applies a FeatureVersion diff on a FeatureVersion list and returns the result. func (tx *pgSession) lockFeatureVulnerabilityCache() error {
func applyFixedInDiff(currentList, diff []database.FeatureVersion) ([]database.FeatureVersion, bool) { _, err := tx.Exec(lockVulnerabilityAffects)
currentMap, currentNames := createFeatureVersionNameMap(currentList)
diffMap, diffNames := createFeatureVersionNameMap(diff)
addedNames := compareStringLists(diffNames, currentNames)
inBothNames := compareStringListsInBoth(diffNames, currentNames)
different := false
for _, name := range addedNames {
if diffMap[name].Version == versionfmt.MinVersion {
// MinVersion only makes sense when a Feature is already fixed in some version,
// in which case we would be in the "inBothNames".
continue
}
currentMap[name] = diffMap[name]
different = true
}
for _, name := range inBothNames {
fv := diffMap[name]
if fv.Version == versionfmt.MinVersion {
// MinVersion means that the Feature doesn't affect the Vulnerability anymore.
delete(currentMap, name)
different = true
} else if fv.Version != currentMap[name].Version {
// The version got updated.
currentMap[name] = diffMap[name]
different = true
}
}
// Convert currentMap to a slice and return it.
var newList []database.FeatureVersion
for _, fv := range currentMap {
newList = append(newList, fv)
}
return newList, different
}
func createFeatureVersionNameMap(features []database.FeatureVersion) (map[string]database.FeatureVersion, []string) {
m := make(map[string]database.FeatureVersion, 0)
s := make([]string, 0, len(features))
for i := 0; i < len(features); i++ {
featureVersion := features[i]
m[featureVersion.Feature.Name] = featureVersion
s = append(s, featureVersion.Feature.Name)
}
return m, s
}
// insertVulnerabilityFixedInFeatureVersions populates Vulnerability_FixedIn_Feature for the given
// vulnerability with the specified database.FeatureVersion list and uses
// linkVulnerabilityToFeatureVersions to propagate the changes on Vulnerability_FixedIn_Feature to
// Vulnerability_Affects_FeatureVersion.
func (pgSQL *pgSQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulnerabilityID int, fixedIn []database.FeatureVersion) error {
defer observeQueryTime("insertVulnerabilityFixedInFeatureVersions", "all", time.Now())
// Insert or find the Features.
// TODO(Quentin-M): Batch me.
var err error
var features []*database.Feature
for i := 0; i < len(fixedIn); i++ {
features = append(features, &fixedIn[i].Feature)
}
for _, feature := range features {
if feature.ID == 0 {
if feature.ID, err = pgSQL.insertFeature(*feature); err != nil {
return err
}
}
}
// Lock Vulnerability_Affects_FeatureVersion exclusively.
// We want to prevent InsertFeatureVersion to modify it.
promConcurrentLockVAFV.Inc()
defer promConcurrentLockVAFV.Dec()
t := time.Now()
_, err = tx.Exec(lockVulnerabilityAffects)
observeQueryTime("insertVulnerability", "lock", t)
if err != nil { if err != nil {
tx.Rollback() return handleError("lockVulnerabilityAffects", err)
return handleError("insertVulnerability.lockVulnerabilityAffects", err)
} }
for _, fv := range fixedIn {
var fixedInID int
var created bool
// Find or create entry in Vulnerability_FixedIn_Feature.
err = tx.QueryRow(
soiVulnerabilityFixedInFeature,
vulnerabilityID, fv.Feature.ID,
&fv.Version,
).Scan(&created, &fixedInID)
if err != nil {
return handleError("insertVulnerabilityFixedInFeature", err)
}
if !created {
// The relationship between the feature and the vulnerability already
// existed, no need to update Vulnerability_Affects_FeatureVersion.
continue
}
// Insert Vulnerability_Affects_FeatureVersion.
err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerabilityID, fv.Feature.ID, fv.Feature.Namespace.VersionFormat, fv.Version)
if err != nil {
return err
}
}
return nil return nil
} }
func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, versionFormat, fixedInVersion string) error { // cacheVulnerabiltyAffectedNamespacedFeature takes in a map of vulnerability ID
// Find every FeatureVersions of the Feature that the vulnerability affects. // to affected feature rows and caches them.
// TODO(Quentin-M): LIMIT func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int64]affectedFeatureRows) error {
rows, err := tx.Query(searchFeatureVersionByFeature, featureID) // Prevent InsertNamespacedFeatures to modify it.
err := tx.lockFeatureVulnerabilityCache()
if err != nil { if err != nil {
return handleError("searchFeatureVersionByFeature", err) return err
} }
vulnIDs := []int64{}
for id := range affected {
vulnIDs = append(vulnIDs, id)
}
rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs))
if err != nil {
return handleError("searchVulnerabilityPotentialAffected", err)
}
defer rows.Close() defer rows.Close()
var affecteds []database.FeatureVersion relation := []affectRelation{}
for rows.Next() { for rows.Next() {
var affected database.FeatureVersion var (
vulnID int64
nsfID int64
fVersion string
addedBy int64
)
err := rows.Scan(&affected.ID, &affected.Version) err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy)
if err != nil { if err != nil {
return handleError("searchFeatureVersionByFeature.Scan()", err) return handleError("searchVulnerabilityPotentialAffected", err)
} }
cmp, err := versionfmt.Compare(versionFormat, affected.Version, fixedInVersion) candidate, ok := affected[vulnID].rows[addedBy]
if err != nil {
if !ok {
return errors.New("vulnerability affected feature not found")
}
if in, err := versionfmt.InRange(candidate.Namespace.VersionFormat,
fVersion,
candidate.AffectedVersion); err == nil {
if in {
relation = append(relation,
affectRelation{
vulnerabilityID: vulnID,
namespacedFeatureID: nsfID,
addedBy: addedBy,
})
}
} else {
return err return err
} }
if cmp < 0 {
// The version of the FeatureVersion is lower than the fixed version of this vulnerability,
// thus, this FeatureVersion is affected by it.
affecteds = append(affecteds, affected)
} }
}
if err = rows.Err(); err != nil {
return handleError("searchFeatureVersionByFeature.Rows()", err)
}
rows.Close()
// Insert into Vulnerability_Affects_FeatureVersion. //TODO(Sida): Change to bulk insert.
for _, affected := range affecteds { for _, r := range relation {
// TODO(Quentin-M): Batch me. result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy)
_, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, vulnerabilityID, affected.ID, fixedInID)
if err != nil { if err != nil {
return handleError("insertVulnerabilityAffectsFeatureVersion", err) return handleError("insertVulnerabilityAffectedNamespacedFeature", err)
}
if num, err := result.RowsAffected(); err == nil {
if num <= 0 {
return errors.New("Nothing cached in database")
}
} else {
return err
} }
} }
log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", len(relation))
return nil return nil
} }
func (pgSQL *pgSQL) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []database.FeatureVersion) error { func (tx *pgSession) DeleteVulnerabilities(vulnerabilities []database.VulnerabilityID) error {
defer observeQueryTime("InsertVulnerabilityFixes", "all", time.Now())
v := database.Vulnerability{
Name: vulnerabilityName,
Namespace: database.Namespace{
Name: vulnerabilityNamespace,
},
FixedIn: fixes,
}
return pgSQL.insertVulnerability(v, true, true)
}
func (pgSQL *pgSQL) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error {
defer observeQueryTime("DeleteVulnerabilityFix", "all", time.Now())
v := database.Vulnerability{
Name: vulnerabilityName,
Namespace: database.Namespace{
Name: vulnerabilityNamespace,
},
FixedIn: []database.FeatureVersion{
{
Feature: database.Feature{
Name: featureName,
Namespace: database.Namespace{
Name: vulnerabilityNamespace,
},
},
Version: versionfmt.MinVersion,
},
},
}
return pgSQL.insertVulnerability(v, true, true)
}
func (pgSQL *pgSQL) DeleteVulnerability(namespaceName, name string) error {
defer observeQueryTime("DeleteVulnerability", "all", time.Now()) defer observeQueryTime("DeleteVulnerability", "all", time.Now())
// Begin transaction. vulnIDs, err := tx.markVulnerabilitiesAsDeleted(vulnerabilities)
tx, err := pgSQL.Begin()
if err != nil {
tx.Rollback()
return handleError("DeleteVulnerability.Begin()", err)
}
var vulnerabilityID int
err = tx.QueryRow(removeVulnerability, namespaceName, name).Scan(&vulnerabilityID)
if err != nil {
tx.Rollback()
return handleError("removeVulnerability", err)
}
// Create a notification.
err = createNotification(tx, vulnerabilityID, 0)
if err != nil { if err != nil {
return err return err
} }
// Commit transaction. if err := tx.invalidateVulnerabilityCache(vulnIDs); err != nil {
err = tx.Commit() return err
}
return nil
}
func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) error {
if len(vulnerabilityIDs) == 0 {
return nil
}
// Prevent InsertNamespacedFeatures to modify it.
err := tx.lockFeatureVulnerabilityCache()
if err != nil { if err != nil {
tx.Rollback() return err
return handleError("DeleteVulnerability.Commit()", err) }
//TODO(Sida): Make a nicer interface for bulk inserting.
keys := make([]interface{}, len(vulnerabilityIDs))
for i, id := range vulnerabilityIDs {
keys[i] = id
}
_, err = tx.Exec(queryInvalidateVulnerabilityCache(len(vulnerabilityIDs)), keys...)
if err != nil {
return handleError("removeVulnerabilityAffectedFeature", err)
} }
return nil return nil
} }
func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.VulnerabilityID) ([]int64, error) {
var (
vulnID sql.NullInt64
vulnIDs []int64
)
// mark vulnerabilities deleted
stmt, err := tx.Prepare(removeVulnerability)
if err != nil {
return nil, handleError("removeVulnerability", err)
}
defer stmt.Close()
for _, vuln := range vulnerabilities {
err := stmt.QueryRow(vuln.Namespace, vuln.Name).Scan(&vulnID)
if err != nil {
return nil, handleError("removeVulnerability", err)
}
if !vulnID.Valid {
return nil, handleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database"))
}
vulnIDs = append(vulnIDs, vulnID.Int64)
}
return vulnIDs, nil
}
// findLatestDeletedVulnerabilityIDs requires all elements in vulnIDs are in
// database and the order of output array is not guaranteed.
func (tx *pgSession) findLatestDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
return tx.findVulnerabilityIDs(vulnIDs, true)
}
func (tx *pgSession) findNotDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
return tx.findVulnerabilityIDs(vulnIDs, false)
}
func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) {
if len(vulnIDs) == 0 {
return nil, nil
}
vulnIDMap := map[database.VulnerabilityID]sql.NullInt64{}
keys := make([]interface{}, len(vulnIDs)*2)
for i, vulnID := range vulnIDs {
keys[i*2] = vulnID.Name
keys[i*2+1] = vulnID.Namespace
vulnIDMap[vulnID] = sql.NullInt64{}
}
query := ""
if withLatestDeleted {
query = querySearchLastDeletedVulnerabilityID(len(vulnIDs))
} else {
query = querySearchNotDeletedVulnerabilityID(len(vulnIDs))
}
rows, err := tx.Query(query, keys...)
if err != nil {
return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Query", err)
}
defer rows.Close()
var (
id sql.NullInt64
vulnID database.VulnerabilityID
)
for rows.Next() {
err := rows.Scan(&id, &vulnID.Name, &vulnID.Namespace)
if err != nil {
return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Scan", err)
}
vulnIDMap[vulnID] = id
}
ids := make([]sql.NullInt64, len(vulnIDs))
for i, v := range vulnIDs {
ids[i] = vulnIDMap[v]
}
return ids, nil
}

View File

@ -15,282 +15,329 @@
package pgsql package pgsql
import ( import (
"reflect"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/ext/versionfmt"
"github.com/coreos/clair/ext/versionfmt/dpkg" "github.com/coreos/clair/ext/versionfmt/dpkg"
"github.com/coreos/clair/pkg/commonerr"
) )
func TestFindVulnerability(t *testing.T) { func TestInsertVulnerabilities(t *testing.T) {
datastore, err := openDatabaseForTest("FindVulnerability", true) store, tx := openSessionForTest(t, "InsertVulnerabilities", true)
if err != nil {
t.Error(err) ns1 := database.Namespace{
return Name: "name",
VersionFormat: "random stuff",
} }
defer datastore.Close()
// Find a vulnerability that does not exist. ns2 := database.Namespace{
_, err = datastore.FindVulnerability("", "") Name: "debian:7",
assert.Equal(t, commonerr.ErrNotFound, err) VersionFormat: "dpkg",
}
// Find a normal vulnerability. // invalid vulnerability
v1 := database.Vulnerability{ v1 := database.Vulnerability{
Name: "invalid",
Namespace: ns1,
}
vwa1 := database.VulnerabilityWithAffected{
Vulnerability: v1,
}
// valid vulnerability
v2 := database.Vulnerability{
Name: "valid",
Namespace: ns2,
Severity: database.UnknownSeverity,
}
vwa2 := database.VulnerabilityWithAffected{
Vulnerability: v2,
}
// empty
err := tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{})
assert.Nil(t, err)
// invalid content: vwa1 is invalid
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa1, vwa2})
assert.NotNil(t, err)
tx = restartSession(t, store, tx, false)
// invalid content: duplicated input
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2, vwa2})
assert.NotNil(t, err)
tx = restartSession(t, store, tx, false)
// valid content
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2})
assert.Nil(t, err)
tx = restartSession(t, store, tx, true)
// ensure the content is in database
vulns, err := tx.FindVulnerabilities([]database.VulnerabilityID{{Name: "valid", Namespace: "debian:7"}})
if assert.Nil(t, err) && assert.Len(t, vulns, 1) {
assert.True(t, vulns[0].Valid)
}
tx = restartSession(t, store, tx, false)
// valid content: vwa2 removed and inserted
err = tx.DeleteVulnerabilities([]database.VulnerabilityID{{Name: vwa2.Name, Namespace: vwa2.Namespace.Name}})
assert.Nil(t, err)
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2})
assert.Nil(t, err)
closeTest(t, store, tx)
}
func TestCachingVulnerable(t *testing.T) {
datastore, tx := openSessionForTest(t, "CachingVulnerable", true)
defer closeTest(t, datastore, tx)
ns := database.Namespace{
Name: "debian:8",
VersionFormat: dpkg.ParserName,
}
f := database.NamespacedFeature{
Feature: database.Feature{
Name: "openssl",
Version: "1.0",
VersionFormat: dpkg.ParserName,
},
Namespace: ns,
}
vuln := database.VulnerabilityWithAffected{
Vulnerability: database.Vulnerability{
Name: "CVE-YAY",
Namespace: ns,
Severity: database.HighSeverity,
},
Affected: []database.AffectedFeature{
{
Namespace: ns,
FeatureName: "openssl",
AffectedVersion: "2.0",
FixedInVersion: "2.1",
},
},
}
vuln2 := database.VulnerabilityWithAffected{
Vulnerability: database.Vulnerability{
Name: "CVE-YAY2",
Namespace: ns,
Severity: database.HighSeverity,
},
Affected: []database.AffectedFeature{
{
Namespace: ns,
FeatureName: "openssl",
AffectedVersion: "2.1",
FixedInVersion: "2.2",
},
},
}
vulnFixed1 := database.VulnerabilityWithFixedIn{
Vulnerability: database.Vulnerability{
Name: "CVE-YAY",
Namespace: ns,
Severity: database.HighSeverity,
},
FixedInVersion: "2.1",
}
vulnFixed2 := database.VulnerabilityWithFixedIn{
Vulnerability: database.Vulnerability{
Name: "CVE-YAY2",
Namespace: ns,
Severity: database.HighSeverity,
},
FixedInVersion: "2.2",
}
if !assert.Nil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vuln, vuln2})) {
t.FailNow()
}
r, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{f})
assert.Nil(t, err)
assert.Len(t, r, 1)
for _, anf := range r {
if assert.True(t, anf.Valid) && assert.Len(t, anf.AffectedBy, 2) {
for _, a := range anf.AffectedBy {
if a.Name == "CVE-YAY" {
assert.Equal(t, vulnFixed1, a)
} else if a.Name == "CVE-YAY2" {
assert.Equal(t, vulnFixed2, a)
} else {
t.FailNow()
}
}
}
}
}
func TestFindVulnerabilities(t *testing.T) {
datastore, tx := openSessionForTest(t, "FindVulnerabilities", true)
defer closeTest(t, datastore, tx)
vuln, err := tx.FindVulnerabilities([]database.VulnerabilityID{
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
{Name: "CVE-NOPE", Namespace: "debian:7"},
{Name: "CVE-NOT HERE"},
})
ns := database.Namespace{
Name: "debian:7",
VersionFormat: "dpkg",
}
expectedExisting := []database.VulnerabilityWithAffected{
{
Vulnerability: database.Vulnerability{
Namespace: ns,
Name: "CVE-OPENSSL-1-DEB7", Name: "CVE-OPENSSL-1-DEB7",
Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7",
Severity: database.HighSeverity, Severity: database.HighSeverity,
Namespace: database.Namespace{
Name: "debian:7",
VersionFormat: dpkg.ParserName,
}, },
FixedIn: []database.FeatureVersion{ Affected: []database.AffectedFeature{
{ {
Feature: database.Feature{Name: "openssl"}, FeatureName: "openssl",
Version: "2.0", AffectedVersion: "2.0",
FixedInVersion: "2.0",
Namespace: ns,
}, },
{ {
Feature: database.Feature{Name: "libssl"}, FeatureName: "libssl",
Version: "1.9-abc", AffectedVersion: "1.9-abc",
FixedInVersion: "1.9-abc",
Namespace: ns,
}, },
}, },
} },
{
v1f, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") Vulnerability: database.Vulnerability{
if assert.Nil(t, err) { Namespace: ns,
equalsVuln(t, &v1, &v1f)
}
// Find a vulnerability that has no link, no severity and no FixedIn.
v2 := database.Vulnerability{
Name: "CVE-NOPE", Name: "CVE-NOPE",
Description: "A vulnerability affecting nothing", Description: "A vulnerability affecting nothing",
Namespace: database.Namespace{
Name: "debian:7",
VersionFormat: dpkg.ParserName,
},
Severity: database.UnknownSeverity,
}
v2f, err := datastore.FindVulnerability("debian:7", "CVE-NOPE")
if assert.Nil(t, err) {
equalsVuln(t, &v2, &v2f)
}
}
func TestDeleteVulnerability(t *testing.T) {
datastore, err := openDatabaseForTest("InsertVulnerability", true)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Delete non-existing Vulnerability.
err = datastore.DeleteVulnerability("TestDeleteVulnerabilityNamespace1", "CVE-OPENSSL-1-DEB7")
assert.Equal(t, commonerr.ErrNotFound, err)
err = datastore.DeleteVulnerability("debian:7", "TestDeleteVulnerabilityVulnerability1")
assert.Equal(t, commonerr.ErrNotFound, err)
// Delete Vulnerability.
err = datastore.DeleteVulnerability("debian:7", "CVE-OPENSSL-1-DEB7")
if assert.Nil(t, err) {
_, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7")
assert.Equal(t, commonerr.ErrNotFound, err)
}
}
func TestInsertVulnerability(t *testing.T) {
datastore, err := openDatabaseForTest("InsertVulnerability", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Create some data.
n1 := database.Namespace{
Name: "TestInsertVulnerabilityNamespace1",
VersionFormat: dpkg.ParserName,
}
n2 := database.Namespace{
Name: "TestInsertVulnerabilityNamespace2",
VersionFormat: dpkg.ParserName,
}
f1 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion1",
Namespace: n1,
},
Version: "1.0",
}
f2 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion1",
Namespace: n2,
},
Version: "1.0",
}
f3 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion2",
},
Version: versionfmt.MaxVersion,
}
f4 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion2",
},
Version: "1.4",
}
f5 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion3",
},
Version: "1.5",
}
f6 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion4",
},
Version: "0.1",
}
f7 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion5",
},
Version: versionfmt.MaxVersion,
}
f8 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion5",
},
Version: versionfmt.MinVersion,
}
// Insert invalid vulnerabilities.
for _, vulnerability := range []database.Vulnerability{
{
Name: "",
Namespace: n1,
FixedIn: []database.FeatureVersion{f1},
Severity: database.UnknownSeverity, Severity: database.UnknownSeverity,
}, },
{
Name: "TestInsertVulnerability0",
Namespace: database.Namespace{},
FixedIn: []database.FeatureVersion{f1},
Severity: database.UnknownSeverity,
}, },
{
Name: "TestInsertVulnerability0-",
Namespace: database.Namespace{},
FixedIn: []database.FeatureVersion{f1},
},
{
Name: "TestInsertVulnerability0",
Namespace: n1,
FixedIn: []database.FeatureVersion{f2},
Severity: database.UnknownSeverity,
},
} {
err := datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true)
assert.Error(t, err)
} }
// Insert a simple vulnerability and find it. expectedExistingMap := map[database.VulnerabilityID]database.VulnerabilityWithAffected{}
v1meta := make(map[string]interface{}) for _, v := range expectedExisting {
v1meta["TestInsertVulnerabilityMetadata1"] = "TestInsertVulnerabilityMetadataValue1" expectedExistingMap[database.VulnerabilityID{Name: v.Name, Namespace: v.Namespace.Name}] = v
v1meta["TestInsertVulnerabilityMetadata2"] = struct {
Test string
}{
Test: "TestInsertVulnerabilityMetadataValue1",
} }
v1 := database.Vulnerability{ nonexisting := database.VulnerabilityWithAffected{
Name: "TestInsertVulnerability1", Vulnerability: database.Vulnerability{Name: "CVE-NOT HERE"},
Namespace: n1,
FixedIn: []database.FeatureVersion{f1, f3, f6, f7},
Severity: database.LowSeverity,
Description: "TestInsertVulnerabilityDescription1",
Link: "TestInsertVulnerabilityLink1",
Metadata: v1meta,
} }
err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true)
if assert.Nil(t, err) { if assert.Nil(t, err) {
v1f, err := datastore.FindVulnerability(n1.Name, v1.Name) for _, v := range vuln {
if v.Valid {
key := database.VulnerabilityID{
Name: v.Name,
Namespace: v.Namespace.Name,
}
expected, ok := expectedExistingMap[key]
if assert.True(t, ok, "vulnerability not found: "+key.Name+":"+key.Namespace) {
assertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected)
}
} else if !assert.Equal(t, nonexisting, v.VulnerabilityWithAffected) {
t.FailNow()
}
}
}
// same vulnerability
r, err := tx.FindVulnerabilities([]database.VulnerabilityID{
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
})
if assert.Nil(t, err) { if assert.Nil(t, err) {
equalsVuln(t, &v1, &v1f) for _, vuln := range r {
} if assert.True(t, vuln.Valid) {
} expected, _ := expectedExistingMap[database.VulnerabilityID{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}]
assertVulnerabilityWithAffectedEqual(t, expected, vuln.VulnerabilityWithAffected)
// Update vulnerability.
v1.Description = "TestInsertVulnerabilityLink2"
v1.Link = "TestInsertVulnerabilityLink2"
v1.Severity = database.HighSeverity
// Update f3 in f4, add fixed in f5, add fixed in f6 which already exists,
// removes fixed in f7 by adding f8 which is f7 but with MinVersion, and
// add fixed by f5 a second time (duplicated).
v1.FixedIn = []database.FeatureVersion{f4, f5, f6, f8, f5}
err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true)
if assert.Nil(t, err) {
v1f, err := datastore.FindVulnerability(n1.Name, v1.Name)
if assert.Nil(t, err) {
// Remove f8 from the struct for comparison as it was just here to cancel f7.
// Remove one of the f5 too as it was twice in the struct but the database
// implementation should have dedup'd it.
v1.FixedIn = v1.FixedIn[:len(v1.FixedIn)-2]
// We already had f1 before the update.
// Add it to the struct for comparison.
v1.FixedIn = append(v1.FixedIn, f1)
equalsVuln(t, &v1, &v1f)
}
}
}
func equalsVuln(t *testing.T, expected, actual *database.Vulnerability) {
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.Namespace.Name, actual.Namespace.Name)
assert.Equal(t, expected.Description, actual.Description)
assert.Equal(t, expected.Link, actual.Link)
assert.Equal(t, expected.Severity, actual.Severity)
assert.True(t, reflect.DeepEqual(castMetadata(expected.Metadata), actual.Metadata), "Got metadata %#v, expected %#v", actual.Metadata, castMetadata(expected.Metadata))
if assert.Len(t, actual.FixedIn, len(expected.FixedIn)) {
for _, actualFeatureVersion := range actual.FixedIn {
found := false
for _, expectedFeatureVersion := range expected.FixedIn {
if expectedFeatureVersion.Feature.Name == actualFeatureVersion.Feature.Name {
found = true
assert.Equal(t, expected.Namespace.Name, actualFeatureVersion.Feature.Namespace.Name)
assert.Equal(t, expectedFeatureVersion.Version, actualFeatureVersion.Version)
}
}
if !found {
t.Errorf("unexpected package %s in %s", actualFeatureVersion.Feature.Name, expected.Name)
} }
} }
} }
} }
func TestStringComparison(t *testing.T) { func TestDeleteVulnerabilities(t *testing.T) {
cmp := compareStringLists([]string{"a", "b", "b", "a"}, []string{"a", "c"}) datastore, tx := openSessionForTest(t, "DeleteVulnerabilities", true)
assert.Len(t, cmp, 1) defer closeTest(t, datastore, tx)
assert.NotContains(t, cmp, "a")
assert.Contains(t, cmp, "b")
cmp = compareStringListsInBoth([]string{"a", "a", "b", "c"}, []string{"a", "c", "c"}) remove := []database.VulnerabilityID{}
assert.Len(t, cmp, 2) // empty case
assert.NotContains(t, cmp, "b") assert.Nil(t, tx.DeleteVulnerabilities(remove))
assert.Contains(t, cmp, "a") // invalid case
assert.Contains(t, cmp, "c") remove = append(remove, database.VulnerabilityID{})
assert.NotNil(t, tx.DeleteVulnerabilities(remove))
// valid case
validRemove := []database.VulnerabilityID{
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
{Name: "CVE-NOPE", Namespace: "debian:7"},
}
assert.Nil(t, tx.DeleteVulnerabilities(validRemove))
vuln, err := tx.FindVulnerabilities(validRemove)
if assert.Nil(t, err) {
for _, v := range vuln {
assert.False(t, v.Valid)
}
}
}
func TestFindVulnerabilityIDs(t *testing.T) {
store, tx := openSessionForTest(t, "FindVulnerabilityIDs", true)
defer closeTest(t, store, tx)
ids, err := tx.findLatestDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}})
if assert.Nil(t, err) {
if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 3, ids[0].Int64)) {
assert.Fail(t, "")
}
}
ids, err = tx.findNotDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}})
if assert.Nil(t, err) {
if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 2, ids[0].Int64)) {
assert.Fail(t, "")
}
}
}
func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool {
return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected)
}
func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool {
if assert.Len(t, actual, len(expected)) {
has := map[database.AffectedFeature]bool{}
for _, i := range expected {
has[i] = false
}
for _, i := range actual {
if visited, ok := has[i]; !ok {
return false
} else if visited {
return false
}
has[i] = true
}
return true
}
return false
} }