From a5c6400065a873f6ae14d50b73550dc07239d7bf Mon Sep 17 00:00:00 2001 From: Sida Chen Date: Wed, 26 Jul 2017 16:23:54 -0700 Subject: [PATCH] database: postgres implementation with tests. --- database/pgsql/ancestry.go | 261 ++++++ database/pgsql/ancestry_test.go | 207 +++++ database/pgsql/complex_test.go | 223 +++-- database/pgsql/feature.go | 520 +++++++---- database/pgsql/feature_test.go | 311 +++++-- database/pgsql/keyvalue.go | 60 +- database/pgsql/keyvalue_test.go | 28 +- database/pgsql/layer.go | 641 +++++-------- database/pgsql/layer_test.go | 471 ++-------- database/pgsql/lock.go | 76 +- database/pgsql/lock_test.go | 54 +- .../pgsql/migrations/00001_change_migrator.go | 53 -- .../pgsql/migrations/00001_initial_schema.go | 192 ++++ .../pgsql/migrations/00002_initial_schema.go | 128 --- .../pgsql/migrations/00003_add_indexes.go | 35 - ...00004_add_index_notification_deleted_at.go | 29 - database/pgsql/migrations/00005_ldfv_index.go | 29 - .../migrations/00006_add_version_format.go | 31 - .../migrations/00007_expand_column_width.go | 31 - .../pgsql/migrations/00008_add_multiplens.go | 44 - database/pgsql/namespace.go | 91 +- database/pgsql/namespace_test.go | 93 +- database/pgsql/notification.go | 421 +++++---- database/pgsql/notification_test.go | 376 ++++---- database/pgsql/pgsql.go | 74 +- database/pgsql/pgsql_test.go | 224 ++++- database/pgsql/queries.go | 512 ++++++---- database/pgsql/testdata/data.sql | 148 +-- database/pgsql/vulnerability.go | 872 +++++++----------- database/pgsql/vulnerability_test.go | 523 ++++++----- 30 files changed, 3682 insertions(+), 3076 deletions(-) create mode 100644 database/pgsql/ancestry.go create mode 100644 database/pgsql/ancestry_test.go delete mode 100644 database/pgsql/migrations/00001_change_migrator.go create mode 100644 database/pgsql/migrations/00001_initial_schema.go delete mode 100644 database/pgsql/migrations/00002_initial_schema.go delete mode 100644 database/pgsql/migrations/00003_add_indexes.go delete mode 100644 database/pgsql/migrations/00004_add_index_notification_deleted_at.go delete mode 100644 database/pgsql/migrations/00005_ldfv_index.go delete mode 100644 database/pgsql/migrations/00006_add_version_format.go delete mode 100644 database/pgsql/migrations/00007_expand_column_width.go delete mode 100644 database/pgsql/migrations/00008_add_multiplens.go diff --git a/database/pgsql/ancestry.go b/database/pgsql/ancestry.go new file mode 100644 index 00000000..17033144 --- /dev/null +++ b/database/pgsql/ancestry.go @@ -0,0 +1,261 @@ +package pgsql + +import ( + "database/sql" + "errors" + "fmt" + "strings" + + "github.com/lib/pq" + log "github.com/sirupsen/logrus" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/pkg/commonerr" +) + +func (tx *pgSession) UpsertAncestry(ancestry database.Ancestry, features []database.NamespacedFeature, processedBy database.Processors) error { + if ancestry.Name == "" { + log.Warning("Empty ancestry name is not allowed") + return commonerr.NewBadRequestError("could not insert an ancestry with empty name") + } + + if len(ancestry.Layers) == 0 { + log.Warning("Empty ancestry is not allowed") + return commonerr.NewBadRequestError("could not insert an ancestry with 0 layers") + } + + err := tx.deleteAncestry(ancestry.Name) + if err != nil { + return err + } + + var ancestryID int64 + err = tx.QueryRow(insertAncestry, ancestry.Name).Scan(&ancestryID) + if err != nil { + if isErrUniqueViolation(err) { + return handleError("insertAncestry", errors.New("Other Go-routine is processing this ancestry (skip).")) + } + return handleError("insertAncestry", err) + } + + err = tx.insertAncestryLayers(ancestryID, ancestry.Layers) + if err != nil { + return err + } + + err = tx.insertAncestryFeatures(ancestryID, features) + if err != nil { + return err + } + + return tx.persistProcessors(persistAncestryLister, + "persistAncestryLister", + persistAncestryDetector, + "persistAncestryDetector", + ancestryID, processedBy) +} + +func (tx *pgSession) FindAncestry(name string) (database.Ancestry, database.Processors, bool, error) { + ancestry := database.Ancestry{Name: name} + processed := database.Processors{} + + var ancestryID int64 + err := tx.QueryRow(searchAncestry, name).Scan(&ancestryID) + if err != nil { + if err == sql.ErrNoRows { + return ancestry, processed, false, nil + } + return ancestry, processed, false, handleError("searchAncestry", err) + } + + ancestry.Layers, err = tx.findAncestryLayers(ancestryID) + if err != nil { + return ancestry, processed, false, err + } + + processed.Detectors, err = tx.findProcessors(searchAncestryDetectors, "searchAncestryDetectors", "detector", ancestryID) + if err != nil { + return ancestry, processed, false, err + } + + processed.Listers, err = tx.findProcessors(searchAncestryListers, "searchAncestryListers", "lister", ancestryID) + if err != nil { + return ancestry, processed, false, err + } + + return ancestry, processed, true, nil +} + +func (tx *pgSession) FindAncestryFeatures(name string) (database.AncestryWithFeatures, bool, error) { + var ( + awf database.AncestryWithFeatures + ok bool + err error + ) + awf.Ancestry, awf.ProcessedBy, ok, err = tx.FindAncestry(name) + if err != nil { + return awf, false, err + } + + if !ok { + return awf, false, nil + } + + rows, err := tx.Query(searchAncestryFeatures, name) + if err != nil { + return awf, false, handleError("searchAncestryFeatures", err) + } + + for rows.Next() { + nf := database.NamespacedFeature{} + err := rows.Scan(&nf.Namespace.Name, &nf.Namespace.VersionFormat, &nf.Feature.Name, &nf.Feature.Version) + if err != nil { + return awf, false, handleError("searchAncestryFeatures", err) + } + nf.Feature.VersionFormat = nf.Namespace.VersionFormat + awf.Features = append(awf.Features, nf) + } + + return awf, true, nil +} + +func (tx *pgSession) deleteAncestry(name string) error { + result, err := tx.Exec(removeAncestry, name) + if err != nil { + return handleError("removeAncestry", err) + } + + _, err = result.RowsAffected() + if err != nil { + return handleError("removeAncestry", err) + } + + return nil +} + +func (tx *pgSession) findProcessors(query, queryName, processorType string, id int64) ([]string, error) { + rows, err := tx.Query(query, id) + if err != nil { + if err == sql.ErrNoRows { + log.Warning("No " + processorType + " are used") + return nil, nil + } + return nil, handleError(queryName, err) + } + + var ( + processors []string + processor string + ) + + for rows.Next() { + err := rows.Scan(&processor) + if err != nil { + return nil, handleError(queryName, err) + } + processors = append(processors, processor) + } + + return processors, nil +} + +func (tx *pgSession) findAncestryLayers(ancestryID int64) ([]database.Layer, error) { + rows, err := tx.Query(searchAncestryLayer, ancestryID) + if err != nil { + return nil, handleError("searchAncestryLayer", err) + } + layers := []database.Layer{} + for rows.Next() { + var layer database.Layer + err := rows.Scan(&layer.Hash) + if err != nil { + return nil, handleError("searchAncestryLayer", err) + } + layers = append(layers, layer) + } + return layers, nil +} + +func (tx *pgSession) insertAncestryLayers(ancestryID int64, layers []database.Layer) error { + layerIDs := map[string]sql.NullInt64{} + for _, l := range layers { + layerIDs[l.Hash] = sql.NullInt64{} + } + + layerHashes := []string{} + for hash := range layerIDs { + layerHashes = append(layerHashes, hash) + } + + rows, err := tx.Query(searchLayerIDs, pq.Array(layerHashes)) + if err != nil { + return handleError("searchLayerIDs", err) + } + + for rows.Next() { + var ( + layerID sql.NullInt64 + layerName string + ) + err := rows.Scan(&layerID, &layerName) + if err != nil { + return handleError("searchLayerIDs", err) + } + layerIDs[layerName] = layerID + } + + notFound := []string{} + for hash, id := range layerIDs { + if !id.Valid { + notFound = append(notFound, hash) + } + } + + if len(notFound) > 0 { + return handleError("searchLayerIDs", fmt.Errorf("Layer %s is not found in database", strings.Join(notFound, ","))) + } + + //TODO(Sida): use bulk insert. + stmt, err := tx.Prepare(insertAncestryLayer) + if err != nil { + return handleError("insertAncestryLayer", err) + } + + defer stmt.Close() + for index, layer := range layers { + _, err := stmt.Exec(ancestryID, index, layerIDs[layer.Hash].Int64) + if err != nil { + return handleError("insertAncestryLayer", commonerr.CombineErrors(err, stmt.Close())) + } + } + + return nil +} + +func (tx *pgSession) insertAncestryFeatures(ancestryID int64, features []database.NamespacedFeature) error { + featureIDs, err := tx.findNamespacedFeatureIDs(features) + if err != nil { + return err + } + + //TODO(Sida): use bulk insert. + stmtFeatures, err := tx.Prepare(insertAncestryFeature) + if err != nil { + return handleError("insertAncestryFeature", err) + } + + defer stmtFeatures.Close() + + for _, id := range featureIDs { + if !id.Valid { + return errors.New("requested namespaced feature is not in database") + } + + _, err := stmtFeatures.Exec(ancestryID, id) + if err != nil { + return handleError("insertAncestryFeature", err) + } + } + + return nil +} diff --git a/database/pgsql/ancestry_test.go b/database/pgsql/ancestry_test.go new file mode 100644 index 00000000..7851163c --- /dev/null +++ b/database/pgsql/ancestry_test.go @@ -0,0 +1,207 @@ +// Copyright 2017 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pgsql + +import ( + "sort" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/coreos/clair/database" +) + +func TestUpsertAncestry(t *testing.T) { + store, tx := openSessionForTest(t, "UpsertAncestry", true) + defer closeTest(t, store, tx) + a1 := database.Ancestry{ + Name: "a1", + Layers: []database.Layer{ + {Hash: "layer-N"}, + }, + } + + a2 := database.Ancestry{} + + a3 := database.Ancestry{ + Name: "a", + Layers: []database.Layer{ + {Hash: "layer-0"}, + }, + } + + a4 := database.Ancestry{ + Name: "a", + Layers: []database.Layer{ + {Hash: "layer-1"}, + }, + } + + f1 := database.Feature{ + Name: "wechat", + Version: "0.5", + VersionFormat: "dpkg", + } + + // not in database + f2 := database.Feature{ + Name: "wechat", + Version: "0.6", + VersionFormat: "dpkg", + } + + n1 := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } + + p := database.Processors{ + Listers: []string{"dpkg", "non-existing"}, + Detectors: []string{"os-release", "non-existing"}, + } + + nsf1 := database.NamespacedFeature{ + Namespace: n1, + Feature: f1, + } + + // not in database + nsf2 := database.NamespacedFeature{ + Namespace: n1, + Feature: f2, + } + + // invalid case + assert.NotNil(t, tx.UpsertAncestry(a1, nil, database.Processors{})) + assert.NotNil(t, tx.UpsertAncestry(a2, nil, database.Processors{})) + // valid case + assert.Nil(t, tx.UpsertAncestry(a3, nil, database.Processors{})) + // replace invalid case + assert.NotNil(t, tx.UpsertAncestry(a4, []database.NamespacedFeature{nsf1, nsf2}, p)) + // replace valid case + assert.Nil(t, tx.UpsertAncestry(a4, []database.NamespacedFeature{nsf1}, p)) + // validate + ancestry, ok, err := tx.FindAncestryFeatures("a") + assert.Nil(t, err) + assert.True(t, ok) + assert.Equal(t, a4, ancestry.Ancestry) +} + +func assertProcessorsEqual(t *testing.T, expected database.Processors, actual database.Processors) bool { + sort.Strings(expected.Detectors) + sort.Strings(actual.Detectors) + sort.Strings(expected.Listers) + sort.Strings(actual.Listers) + return assert.Equal(t, expected.Detectors, actual.Detectors) && assert.Equal(t, expected.Listers, actual.Listers) +} + +func TestFindAncestry(t *testing.T) { + store, tx := openSessionForTest(t, "FindAncestry", true) + defer closeTest(t, store, tx) + + // not found + _, _, ok, err := tx.FindAncestry("ancestry-non") + assert.Nil(t, err) + assert.False(t, ok) + + expected := database.Ancestry{ + Name: "ancestry-1", + Layers: []database.Layer{ + {Hash: "layer-0"}, + {Hash: "layer-1"}, + {Hash: "layer-2"}, + {Hash: "layer-3a"}, + }, + } + + expectedProcessors := database.Processors{ + Detectors: []string{"os-release"}, + Listers: []string{"dpkg"}, + } + + // found + a, p, ok2, err := tx.FindAncestry("ancestry-1") + if assert.Nil(t, err) && assert.True(t, ok2) { + assertAncestryEqual(t, expected, a) + assertProcessorsEqual(t, expectedProcessors, p) + } +} + +func assertAncestryWithFeatureEqual(t *testing.T, expected database.AncestryWithFeatures, actual database.AncestryWithFeatures) bool { + return assertAncestryEqual(t, expected.Ancestry, actual.Ancestry) && + assertNamespacedFeatureEqual(t, expected.Features, actual.Features) && + assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy) +} +func assertAncestryEqual(t *testing.T, expected database.Ancestry, actual database.Ancestry) bool { + return assert.Equal(t, expected.Name, actual.Name) && assert.Equal(t, expected.Layers, actual.Layers) +} + +func TestFindAncestryFeatures(t *testing.T) { + store, tx := openSessionForTest(t, "FindAncestryFeatures", true) + defer closeTest(t, store, tx) + + // invalid + _, ok, err := tx.FindAncestryFeatures("ancestry-non") + if assert.Nil(t, err) { + assert.False(t, ok) + } + + expected := database.AncestryWithFeatures{ + Ancestry: database.Ancestry{ + Name: "ancestry-2", + Layers: []database.Layer{ + {Hash: "layer-0"}, + {Hash: "layer-1"}, + {Hash: "layer-2"}, + {Hash: "layer-3b"}, + }, + }, + ProcessedBy: database.Processors{ + Detectors: []string{"os-release"}, + Listers: []string{"dpkg"}, + }, + Features: []database.NamespacedFeature{ + { + Namespace: database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + }, + Feature: database.Feature{ + Name: "wechat", + Version: "0.5", + VersionFormat: "dpkg", + }, + }, + { + Namespace: database.Namespace{ + Name: "debian:8", + VersionFormat: "dpkg", + }, + Feature: database.Feature{ + Name: "openssl", + Version: "1.0", + VersionFormat: "dpkg", + }, + }, + }, + } + // valid + ancestry, ok, err := tx.FindAncestryFeatures("ancestry-2") + if assert.Nil(t, err) && assert.True(t, ok) { + assertAncestryEqual(t, expected.Ancestry, ancestry.Ancestry) + assertNamespacedFeatureEqual(t, expected.Features, ancestry.Features) + assertProcessorsEqual(t, expected.ProcessedBy, ancestry.ProcessedBy) + } +} diff --git a/database/pgsql/complex_test.go b/database/pgsql/complex_test.go index ed038b4e..07d6f55f 100644 --- a/database/pgsql/complex_test.go +++ b/database/pgsql/complex_test.go @@ -27,135 +27,200 @@ import ( "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" + "github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/ext/versionfmt/dpkg" + "github.com/coreos/clair/pkg/strutil" ) const ( numVulnerabilities = 100 - numFeatureVersions = 100 + numFeatures = 100 ) -func TestRaceAffects(t *testing.T) { - datastore, err := openDatabaseForTest("RaceAffects", false) - if err != nil { - t.Error(err) - return +func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database.Datastore) ([]database.NamespacedFeature, []database.VulnerabilityWithAffected) { + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() } - defer datastore.Close() - // Insert the Feature on which we'll work. - feature := database.Feature{ - Namespace: database.Namespace{ - Name: "TestRaceAffectsFeatureNamespace1", - VersionFormat: dpkg.ParserName, - }, - Name: "TestRaceAffecturesFeature1", + featureName := "TestFeature" + featureVersionFormat := dpkg.ParserName + // Insert the namespace on which we'll work. + namespace := database.Namespace{ + Name: "TestRaceAffectsFeatureNamespace1", + VersionFormat: dpkg.ParserName, } - _, err = datastore.insertFeature(feature) - if err != nil { - t.Error(err) - return + + if !assert.Nil(t, tx.PersistNamespaces([]database.Namespace{namespace})) { + t.FailNow() } // Initialize random generator and enforce max procs. rand.Seed(time.Now().UnixNano()) runtime.GOMAXPROCS(runtime.NumCPU()) - // Generate FeatureVersions. - featureVersions := make([]database.FeatureVersion, numFeatureVersions) - for i := 0; i < numFeatureVersions; i++ { - version := rand.Intn(numFeatureVersions) + // Generate Distinct random features + features := make([]database.Feature, numFeatures) + nsFeatures := make([]database.NamespacedFeature, numFeatures) + for i := 0; i < numFeatures; i++ { + version := rand.Intn(numFeatures) - featureVersions[i] = database.FeatureVersion{ - Feature: feature, - Version: strconv.Itoa(version), + features[i] = database.Feature{ + Name: featureName, + VersionFormat: featureVersionFormat, + Version: strconv.Itoa(version), } + + nsFeatures[i] = database.NamespacedFeature{ + Namespace: namespace, + Feature: features[i], + } + } + + // insert features + if !assert.Nil(t, tx.PersistFeatures(features)) { + t.FailNow() } // Generate vulnerabilities. - // They are mapped by fixed version, which will make verification really easy afterwards. - vulnerabilities := make(map[int][]database.Vulnerability) + vulnerabilities := []database.VulnerabilityWithAffected{} for i := 0; i < numVulnerabilities; i++ { - version := rand.Intn(numFeatureVersions) + 1 + // any version less than this is vulnerable + version := rand.Intn(numFeatures) + 1 - // if _, ok := vulnerabilities[version]; !ok { - // vulnerabilities[version] = make([]database.Vulnerability) - // } - - vulnerability := database.Vulnerability{ - Name: uuid.New(), - Namespace: feature.Namespace, - FixedIn: []database.FeatureVersion{ + vulnerability := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: uuid.New(), + Namespace: namespace, + Severity: database.UnknownSeverity, + }, + Affected: []database.AffectedFeature{ { - Feature: feature, - Version: strconv.Itoa(version), + Namespace: namespace, + FeatureName: featureName, + AffectedVersion: strconv.Itoa(version), + FixedInVersion: strconv.Itoa(version), }, }, - Severity: database.UnknownSeverity, } - vulnerabilities[version] = append(vulnerabilities[version], vulnerability) + vulnerabilities = append(vulnerabilities, vulnerability) } + tx.Commit() + + return nsFeatures, vulnerabilities +} + +func TestConcurrency(t *testing.T) { + store, err := openDatabaseForTest("Concurrency", false) + if !assert.Nil(t, err) { + t.FailNow() + } + defer store.Close() + + start := time.Now() + var wg sync.WaitGroup + wg.Add(100) + for i := 0; i < 100; i++ { + go func() { + defer wg.Done() + nsNamespaces := genRandomNamespaces(t, 100) + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() + } + assert.Nil(t, tx.PersistNamespaces(nsNamespaces)) + tx.Commit() + }() + } + wg.Wait() + fmt.Println("total", time.Since(start)) +} + +func genRandomNamespaces(t *testing.T, count int) []database.Namespace { + r := make([]database.Namespace, count) + for i := 0; i < count; i++ { + r[i] = database.Namespace{ + Name: uuid.New(), + VersionFormat: "dpkg", + } + } + return r +} + +func TestCaching(t *testing.T) { + store, err := openDatabaseForTest("Caching", false) + if !assert.Nil(t, err) { + t.FailNow() + } + defer store.Close() + + nsFeatures, vulnerabilities := testGenRandomVulnerabilityAndNamespacedFeature(t, store) + + fmt.Printf("%d features, %d vulnerabilities are generated", len(nsFeatures), len(vulnerabilities)) - // Insert featureversions and vulnerabilities in parallel. var wg sync.WaitGroup wg.Add(2) - go func() { defer wg.Done() - for _, vulnerabilitiesM := range vulnerabilities { - for _, vulnerability := range vulnerabilitiesM { - err = datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true) - assert.Nil(t, err) - } + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() } - fmt.Println("finished to insert vulnerabilities") + + assert.Nil(t, tx.PersistNamespacedFeatures(nsFeatures)) + fmt.Println("finished to insert namespaced features") + + tx.Commit() }() go func() { defer wg.Done() - for i := 0; i < len(featureVersions); i++ { - featureVersions[i].ID, err = datastore.insertFeatureVersion(featureVersions[i]) - assert.Nil(t, err) + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() } - fmt.Println("finished to insert featureVersions") + + assert.Nil(t, tx.InsertVulnerabilities(vulnerabilities)) + fmt.Println("finished to insert vulnerabilities") + tx.Commit() + }() wg.Wait() + tx, err := store.Begin() + if !assert.Nil(t, err) { + t.FailNow() + } + defer tx.Rollback() + // Verify consistency now. - var actualAffectedNames []string - var expectedAffectedNames []string + affected, err := tx.FindAffectedNamespacedFeatures(nsFeatures) + if !assert.Nil(t, err) { + t.FailNow() + } - for _, featureVersion := range featureVersions { - featureVersionVersion, _ := strconv.Atoi(featureVersion.Version) - - // Get actual affects. - rows, err := datastore.Query(searchComplexTestFeatureVersionAffects, - featureVersion.ID) - assert.Nil(t, err) - defer rows.Close() - - var vulnName string - for rows.Next() { - err = rows.Scan(&vulnName) - if !assert.Nil(t, err) { - continue - } - actualAffectedNames = append(actualAffectedNames, vulnName) - } - if assert.Nil(t, rows.Err()) { - rows.Close() + for _, ansf := range affected { + if !assert.True(t, ansf.Valid) { + t.FailNow() } - // Get expected affects. - for i := numVulnerabilities; i > featureVersionVersion; i-- { - for _, vulnerability := range vulnerabilities[i] { - expectedAffectedNames = append(expectedAffectedNames, vulnerability.Name) + expectedAffectedNames := []string{} + for _, vuln := range vulnerabilities { + if ok, err := versionfmt.InRange(dpkg.ParserName, ansf.Version, vuln.Affected[0].AffectedVersion); err == nil { + if ok { + expectedAffectedNames = append(expectedAffectedNames, vuln.Name) + } } } - assert.Len(t, compareStringLists(expectedAffectedNames, actualAffectedNames), 0) - assert.Len(t, compareStringLists(actualAffectedNames, expectedAffectedNames), 0) + actualAffectedNames := []string{} + for _, s := range ansf.AffectedBy { + actualAffectedNames = append(actualAffectedNames, s.Name) + } + + assert.Len(t, strutil.CompareStringLists(expectedAffectedNames, actualAffectedNames), 0) + assert.Len(t, strutil.CompareStringLists(actualAffectedNames, expectedAffectedNames), 0) } } diff --git a/database/pgsql/feature.go b/database/pgsql/feature.go index c39bd5b7..81ef857d 100644 --- a/database/pgsql/feature.go +++ b/database/pgsql/feature.go @@ -16,230 +16,366 @@ package pgsql import ( "database/sql" - "strings" - "time" + "errors" + "sort" + + "github.com/lib/pq" + log "github.com/sirupsen/logrus" "github.com/coreos/clair/database" "github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/pkg/commonerr" ) -func (pgSQL *pgSQL) insertFeature(feature database.Feature) (int, error) { - if feature.Name == "" { - return 0, commonerr.NewBadRequestError("could not find/insert invalid Feature") - } +var ( + errFeatureNotFound = errors.New("Feature not found") +) - // Do cache lookup. - if pgSQL.cache != nil { - promCacheQueriesTotal.WithLabelValues("feature").Inc() - id, found := pgSQL.cache.Get("feature:" + feature.Namespace.Name + ":" + feature.Name) - if found { - promCacheHitsTotal.WithLabelValues("feature").Inc() - return id.(int), nil - } - } - - // We do `defer observeQueryTime` here because we don't want to observe cached features. - defer observeQueryTime("insertFeature", "all", time.Now()) - - // Find or create Namespace. - namespaceID, err := pgSQL.insertNamespace(feature.Namespace) - if err != nil { - return 0, err - } - - // Find or create Feature. - var id int - err = pgSQL.QueryRow(soiFeature, feature.Name, namespaceID).Scan(&id) - if err != nil { - return 0, handleError("soiFeature", err) - } - - if pgSQL.cache != nil { - pgSQL.cache.Add("feature:"+feature.Namespace.Name+":"+feature.Name, id) - } - - return id, nil +type vulnerabilityAffecting struct { + vulnerabilityID int64 + addedByID int64 } -func (pgSQL *pgSQL) insertFeatureVersion(fv database.FeatureVersion) (id int, err error) { - err = versionfmt.Valid(fv.Feature.Namespace.VersionFormat, fv.Version) - if err != nil { - return 0, commonerr.NewBadRequestError("could not find/insert invalid FeatureVersion") +func (tx *pgSession) PersistFeatures(features []database.Feature) error { + if len(features) == 0 { + return nil } - // Do cache lookup. - cacheIndex := strings.Join([]string{"featureversion", fv.Feature.Namespace.Name, fv.Feature.Name, fv.Version}, ":") - if pgSQL.cache != nil { - promCacheQueriesTotal.WithLabelValues("featureversion").Inc() - id, found := pgSQL.cache.Get(cacheIndex) - if found { - promCacheHitsTotal.WithLabelValues("featureversion").Inc() - return id.(int), nil + // Sorting is needed before inserting into database to prevent deadlock. + sort.Slice(features, func(i, j int) bool { + return features[i].Name < features[j].Name || + features[i].Version < features[j].Version || + features[i].VersionFormat < features[j].VersionFormat + }) + + // TODO(Sida): A better interface for bulk insertion is needed. + keys := make([]interface{}, len(features)*3) + for i, f := range features { + keys[i*3] = f.Name + keys[i*3+1] = f.Version + keys[i*3+2] = f.VersionFormat + if f.Name == "" || f.Version == "" || f.VersionFormat == "" { + return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed") } } - // We do `defer observeQueryTime` here because we don't want to observe cached featureversions. - defer observeQueryTime("insertFeatureVersion", "all", time.Now()) - - // Find or create Feature first. - t := time.Now() - featureID, err := pgSQL.insertFeature(fv.Feature) - observeQueryTime("insertFeatureVersion", "insertFeature", t) - - if err != nil { - return 0, err - } - - fv.Feature.ID = featureID - - // Try to find the FeatureVersion. - // - // In a populated database, the likelihood of the FeatureVersion already being there is high. - // If we can find it here, we then avoid using a transaction and locking the database. - err = pgSQL.QueryRow(searchFeatureVersion, featureID, fv.Version).Scan(&fv.ID) - if err != nil && err != sql.ErrNoRows { - return 0, handleError("searchFeatureVersion", err) - } - if err == nil { - if pgSQL.cache != nil { - pgSQL.cache.Add(cacheIndex, fv.ID) - } - - return fv.ID, nil - } - - // Begin transaction. - tx, err := pgSQL.Begin() - if err != nil { - tx.Rollback() - return 0, handleError("insertFeatureVersion.Begin()", err) - } - - // Lock Vulnerability_Affects_FeatureVersion exclusively. - // We want to prevent InsertVulnerability to modify it. - promConcurrentLockVAFV.Inc() - defer promConcurrentLockVAFV.Dec() - t = time.Now() - _, err = tx.Exec(lockVulnerabilityAffects) - observeQueryTime("insertFeatureVersion", "lock", t) - - if err != nil { - tx.Rollback() - return 0, handleError("insertFeatureVersion.lockVulnerabilityAffects", err) - } - - // Find or create FeatureVersion. - var created bool - - t = time.Now() - err = tx.QueryRow(soiFeatureVersion, featureID, fv.Version).Scan(&created, &fv.ID) - observeQueryTime("insertFeatureVersion", "soiFeatureVersion", t) - - if err != nil { - tx.Rollback() - return 0, handleError("soiFeatureVersion", err) - } - - if !created { - // The featureVersion already existed, no need to link it to - // vulnerabilities. - tx.Commit() - - if pgSQL.cache != nil { - pgSQL.cache.Add(cacheIndex, fv.ID) - } - - return fv.ID, nil - } - - // Link the new FeatureVersion with every vulnerabilities that affect it, by inserting in - // Vulnerability_Affects_FeatureVersion. - t = time.Now() - err = linkFeatureVersionToVulnerabilities(tx, fv) - observeQueryTime("insertFeatureVersion", "linkFeatureVersionToVulnerabilities", t) - - if err != nil { - tx.Rollback() - return 0, err - } - - // Commit transaction. - err = tx.Commit() - if err != nil { - return 0, handleError("insertFeatureVersion.Commit()", err) - } - - if pgSQL.cache != nil { - pgSQL.cache.Add(cacheIndex, fv.ID) - } - - return fv.ID, nil + _, err := tx.Exec(queryPersistFeature(len(features)), keys...) + return handleError("queryPersistFeature", err) } -// TODO(Quentin-M): Batch me -func (pgSQL *pgSQL) insertFeatureVersions(featureVersions []database.FeatureVersion) ([]int, error) { - IDs := make([]int, 0, len(featureVersions)) +type namespacedFeatureWithID struct { + database.NamespacedFeature - for i := 0; i < len(featureVersions); i++ { - id, err := pgSQL.insertFeatureVersion(featureVersions[i]) - if err != nil { - return IDs, err - } - IDs = append(IDs, id) + ID int64 +} + +type vulnerabilityCache struct { + nsFeatureID int64 + vulnID int64 + vulnAffectingID int64 +} + +func (tx *pgSession) searchAffectingVulnerabilities(features []database.NamespacedFeature) ([]vulnerabilityCache, error) { + if len(features) == 0 { + return nil, nil } - return IDs, nil -} - -type vulnerabilityAffectsFeatureVersion struct { - vulnerabilityID int - fixedInID int - fixedInVersion string -} - -func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.FeatureVersion) error { - // Select every vulnerability and the fixed version that affect this Feature. - // TODO(Quentin-M): LIMIT - rows, err := tx.Query(searchVulnerabilityFixedInFeature, featureVersion.Feature.ID) + ids, err := tx.findNamespacedFeatureIDs(features) if err != nil { - return handleError("searchVulnerabilityFixedInFeature", err) + return nil, err } + + fMap := map[int64]database.NamespacedFeature{} + for i, f := range features { + if !ids[i].Valid { + return nil, errFeatureNotFound + } + fMap[ids[i].Int64] = f + } + + cacheTable := []vulnerabilityCache{} + rows, err := tx.Query(searchPotentialAffectingVulneraibilities, pq.Array(ids)) + if err != nil { + return nil, handleError("searchPotentialAffectingVulneraibilities", err) + } + defer rows.Close() - - var affects []vulnerabilityAffectsFeatureVersion for rows.Next() { - var affect vulnerabilityAffectsFeatureVersion + var ( + cache vulnerabilityCache + affected string + ) - err := rows.Scan(&affect.fixedInID, &affect.vulnerabilityID, &affect.fixedInVersion) + err := rows.Scan(&cache.nsFeatureID, &cache.vulnID, &affected, &cache.vulnAffectingID) if err != nil { - return handleError("searchVulnerabilityFixedInFeature.Scan()", err) + return nil, err } - cmp, err := versionfmt.Compare(featureVersion.Feature.Namespace.VersionFormat, featureVersion.Version, affect.fixedInVersion) - if err != nil { - return err - } - if cmp < 0 { - // The version of the FeatureVersion we are inserting is lower than the fixed version on this - // Vulnerability, thus, this FeatureVersion is affected by it. - affects = append(affects, affect) + if ok, err := versionfmt.InRange(fMap[cache.nsFeatureID].VersionFormat, fMap[cache.nsFeatureID].Version, affected); err != nil { + return nil, err + } else if ok { + cacheTable = append(cacheTable, cache) } } - if err = rows.Err(); err != nil { - return handleError("searchVulnerabilityFixedInFeature.Rows()", err) - } - rows.Close() - // Insert into Vulnerability_Affects_FeatureVersion. - for _, affect := range affects { - // TODO(Quentin-M): Batch me. - _, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, affect.vulnerabilityID, - featureVersion.ID, affect.fixedInID) - if err != nil { - return handleError("insertVulnerabilityAffectsFeatureVersion", err) + return cacheTable, nil +} + +func (tx *pgSession) CacheAffectedNamespacedFeatures(features []database.NamespacedFeature) error { + if len(features) == 0 { + return nil + } + + _, err := tx.Exec(lockVulnerabilityAffects) + if err != nil { + return handleError("lockVulnerabilityAffects", err) + } + + cache, err := tx.searchAffectingVulnerabilities(features) + + keys := make([]interface{}, len(cache)*3) + for i, c := range cache { + keys[i*3] = c.vulnID + keys[i*3+1] = c.nsFeatureID + keys[i*3+2] = c.vulnAffectingID + } + + if len(cache) == 0 { + return nil + } + + affected, err := tx.Exec(queryPersistVulnerabilityAffectedNamespacedFeature(len(cache)), keys...) + if err != nil { + return handleError("persistVulnerabilityAffectedNamespacedFeature", err) + } + if count, err := affected.RowsAffected(); err != nil { + log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", count) + } + return nil +} + +func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFeature) error { + if len(features) == 0 { + return nil + } + + nsIDs := map[database.Namespace]sql.NullInt64{} + fIDs := map[database.Feature]sql.NullInt64{} + for _, f := range features { + nsIDs[f.Namespace] = sql.NullInt64{} + fIDs[f.Feature] = sql.NullInt64{} + } + + fToFind := []database.Feature{} + for f := range fIDs { + fToFind = append(fToFind, f) + } + + sort.Slice(fToFind, func(i, j int) bool { + return fToFind[i].Name < fToFind[j].Name || + fToFind[i].Version < fToFind[j].Version || + fToFind[i].VersionFormat < fToFind[j].VersionFormat + }) + + if ids, err := tx.findFeatureIDs(fToFind); err == nil { + for i, id := range ids { + if !id.Valid { + return errFeatureNotFound + } + fIDs[fToFind[i]] = id } + } else { + return err + } + + nsToFind := []database.Namespace{} + for ns := range nsIDs { + nsToFind = append(nsToFind, ns) + } + + if ids, err := tx.findNamespaceIDs(nsToFind); err == nil { + for i, id := range ids { + if !id.Valid { + return errNamespaceNotFound + } + nsIDs[nsToFind[i]] = id + } + } else { + return err + } + + keys := make([]interface{}, len(features)*2) + for i, f := range features { + keys[i*2] = fIDs[f.Feature] + keys[i*2+1] = nsIDs[f.Namespace] + } + + _, err := tx.Exec(queryPersistNamespacedFeature(len(features)), keys...) + if err != nil { + return err } return nil } + +// FindAffectedNamespacedFeatures looks up cache table and retrieves all +// vulnerabilities associated with the features. +func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) { + if len(features) == 0 { + return nil, nil + } + + returnFeatures := make([]database.NullableAffectedNamespacedFeature, len(features)) + + // featureMap is used to keep track of duplicated features. + featureMap := map[database.NamespacedFeature][]*database.NullableAffectedNamespacedFeature{} + // initialize return value and generate unique feature request queries. + for i, f := range features { + returnFeatures[i] = database.NullableAffectedNamespacedFeature{ + AffectedNamespacedFeature: database.AffectedNamespacedFeature{ + NamespacedFeature: f, + }, + } + + featureMap[f] = append(featureMap[f], &returnFeatures[i]) + } + + // query unique namespaced features + distinctFeatures := []database.NamespacedFeature{} + for f := range featureMap { + distinctFeatures = append(distinctFeatures, f) + } + + nsFeatureIDs, err := tx.findNamespacedFeatureIDs(distinctFeatures) + if err != nil { + return nil, err + } + + toQuery := []int64{} + featureIDMap := map[int64][]*database.NullableAffectedNamespacedFeature{} + for i, id := range nsFeatureIDs { + if id.Valid { + toQuery = append(toQuery, id.Int64) + for _, f := range featureMap[distinctFeatures[i]] { + f.Valid = id.Valid + featureIDMap[id.Int64] = append(featureIDMap[id.Int64], f) + } + } + } + + rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(toQuery)) + if err != nil { + return nil, handleError("searchNamespacedFeaturesVulnerabilities", err) + } + defer rows.Close() + + for rows.Next() { + var ( + featureID int64 + vuln database.VulnerabilityWithFixedIn + ) + err := rows.Scan(&featureID, + &vuln.Name, + &vuln.Description, + &vuln.Link, + &vuln.Severity, + &vuln.Metadata, + &vuln.FixedInVersion, + &vuln.Namespace.Name, + &vuln.Namespace.VersionFormat, + ) + if err != nil { + return nil, handleError("searchNamespacedFeaturesVulnerabilities", err) + } + + for _, f := range featureIDMap[featureID] { + f.AffectedBy = append(f.AffectedBy, vuln) + } + } + + return returnFeatures, nil +} + +func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) ([]sql.NullInt64, error) { + if len(nfs) == 0 { + return nil, nil + } + + nfsMap := map[database.NamespacedFeature]sql.NullInt64{} + keys := make([]interface{}, len(nfs)*4) + for i, nf := range nfs { + keys[i*4] = nfs[i].Name + keys[i*4+1] = nfs[i].Version + keys[i*4+2] = nfs[i].VersionFormat + keys[i*4+3] = nfs[i].Namespace.Name + nfsMap[nf] = sql.NullInt64{} + } + + rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...) + if err != nil { + return nil, handleError("searchNamespacedFeature", err) + } + + defer rows.Close() + var ( + id sql.NullInt64 + nf database.NamespacedFeature + ) + + for rows.Next() { + err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Namespace.Name) + nf.Namespace.VersionFormat = nf.VersionFormat + if err != nil { + return nil, handleError("searchNamespacedFeature", err) + } + nfsMap[nf] = id + } + + ids := make([]sql.NullInt64, len(nfs)) + for i, nf := range nfs { + ids[i] = nfsMap[nf] + } + + return ids, nil +} + +func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, error) { + if len(fs) == 0 { + return nil, nil + } + + fMap := map[database.Feature]sql.NullInt64{} + + keys := make([]interface{}, len(fs)*3) + for i, f := range fs { + keys[i*3] = f.Name + keys[i*3+1] = f.Version + keys[i*3+2] = f.VersionFormat + fMap[f] = sql.NullInt64{} + } + + rows, err := tx.Query(querySearchFeatureID(len(fs)), keys...) + if err != nil { + return nil, handleError("querySearchFeatureID", err) + } + defer rows.Close() + + var ( + id sql.NullInt64 + f database.Feature + ) + for rows.Next() { + err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat) + if err != nil { + return nil, handleError("querySearchFeatureID", err) + } + fMap[f] = id + } + + ids := make([]sql.NullInt64, len(fs)) + for i, f := range fs { + ids[i] = fMap[f] + } + + return ids, nil +} diff --git a/database/pgsql/feature_test.go b/database/pgsql/feature_test.go index 5b7f8078..934b8cc1 100644 --- a/database/pgsql/feature_test.go +++ b/database/pgsql/feature_test.go @@ -20,96 +20,237 @@ import ( "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt/dpkg" + + // register dpkg feature lister for testing + _ "github.com/coreos/clair/ext/featurefmt/dpkg" ) -func TestInsertFeature(t *testing.T) { - datastore, err := openDatabaseForTest("InsertFeature", false) +func TestPersistFeatures(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistFeatures", false) + defer closeTest(t, datastore, tx) + + f1 := database.Feature{} + f2 := database.Feature{Name: "n", Version: "v", VersionFormat: "vf"} + + // empty + assert.Nil(t, tx.PersistFeatures([]database.Feature{})) + // invalid + assert.NotNil(t, tx.PersistFeatures([]database.Feature{f1})) + // duplicated + assert.Nil(t, tx.PersistFeatures([]database.Feature{f2, f2})) + // existing + assert.Nil(t, tx.PersistFeatures([]database.Feature{f2})) + + fs := listFeatures(t, tx) + assert.Len(t, fs, 1) + assert.Equal(t, f2, fs[0]) +} + +func TestPersistNamespacedFeatures(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + + // existing features + f1 := database.Feature{ + Name: "wechat", + Version: "0.5", + VersionFormat: "dpkg", + } + + // non-existing features + f2 := database.Feature{ + Name: "fake!", + } + + f3 := database.Feature{ + Name: "openssl", + Version: "2.0", + VersionFormat: "dpkg", + } + + // exising namespace + n1 := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } + + n3 := database.Namespace{ + Name: "debian:8", + VersionFormat: "dpkg", + } + + // non-existing namespace + n2 := database.Namespace{ + Name: "debian:non", + VersionFormat: "dpkg", + } + + // existing namespaced feature + nf1 := database.NamespacedFeature{ + Namespace: n1, + Feature: f1, + } + + // invalid namespaced feature + nf2 := database.NamespacedFeature{ + Namespace: n2, + Feature: f2, + } + + // new namespaced feature affected by vulnerability + nf3 := database.NamespacedFeature{ + Namespace: n3, + Feature: f3, + } + + // namespaced features with namespaces or features not in the database will + // generate error. + assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{})) + + assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf2})) + // valid case: insert nf3 + assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf3})) + + all := listNamespacedFeatures(t, tx) + assert.Contains(t, all, nf1) + assert.Contains(t, all, nf3) +} + +func TestVulnerableFeature(t *testing.T) { + datastore, tx := openSessionForTest(t, "VulnerableFeature", true) + defer closeTest(t, datastore, tx) + + f1 := database.Feature{ + Name: "openssl", + Version: "1.3", + VersionFormat: "dpkg", + } + + n1 := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } + + nf1 := database.NamespacedFeature{ + Namespace: n1, + Feature: f1, + } + assert.Nil(t, tx.PersistFeatures([]database.Feature{f1})) + assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1})) + assert.Nil(t, tx.CacheAffectedNamespacedFeatures([]database.NamespacedFeature{nf1})) + // ensure the namespaced feature is affected correctly + anf, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{nf1}) + if assert.Nil(t, err) && + assert.Len(t, anf, 1) && + assert.True(t, anf[0].Valid) && + assert.Len(t, anf[0].AffectedBy, 1) { + assert.Equal(t, "CVE-OPENSSL-1-DEB7", anf[0].AffectedBy[0].Name) + } +} + +func TestFindAffectedNamespacedFeatures(t *testing.T) { + datastore, tx := openSessionForTest(t, "FindAffectedNamespacedFeatures", true) + defer closeTest(t, datastore, tx) + ns := database.NamespacedFeature{ + Feature: database.Feature{ + Name: "openssl", + Version: "1.0", + VersionFormat: "dpkg", + }, + Namespace: database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + }, + } + + ans, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{ns}) + if assert.Nil(t, err) && + assert.Len(t, ans, 1) && + assert.True(t, ans[0].Valid) && + assert.Len(t, ans[0].AffectedBy, 1) { + assert.Equal(t, "CVE-OPENSSL-1-DEB7", ans[0].AffectedBy[0].Name) + } +} + +func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFeature { + rows, err := tx.Query(`SELECT f.name, f.version, f.version_format, n.name, n.version_format + FROM feature AS f, namespace AS n, namespaced_feature AS nf + WHERE nf.feature_id = f.id AND nf.namespace_id = n.id`) if err != nil { t.Error(err) - return - } - defer datastore.Close() - - // Invalid Feature. - id0, err := datastore.insertFeature(database.Feature{}) - assert.NotNil(t, err) - assert.Zero(t, id0) - - id0, err = datastore.insertFeature(database.Feature{ - Namespace: database.Namespace{}, - Name: "TestInsertFeature0", - }) - assert.NotNil(t, err) - assert.Zero(t, id0) - - // Insert Feature and ensure we can find it. - feature := database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertFeatureNamespace1", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertFeature1", - } - id1, err := datastore.insertFeature(feature) - assert.Nil(t, err) - id2, err := datastore.insertFeature(feature) - assert.Nil(t, err) - assert.Equal(t, id1, id2) - - // Insert invalid FeatureVersion. - for _, invalidFeatureVersion := range []database.FeatureVersion{ - { - Feature: database.Feature{}, - Version: "1.0", - }, - { - Feature: database.Feature{ - Namespace: database.Namespace{}, - Name: "TestInsertFeature2", - }, - Version: "1.0", - }, - { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertFeatureNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertFeature2", - }, - Version: "", - }, - { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertFeatureNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertFeature2", - }, - Version: "bad version", - }, - } { - id3, err := datastore.insertFeatureVersion(invalidFeatureVersion) - assert.Error(t, err) - assert.Zero(t, id3) + t.FailNow() } - // Insert FeatureVersion and ensure we can find it. - featureVersion := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertFeatureNamespace1", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertFeature1", - }, - Version: "2:3.0-imba", + nf := []database.NamespacedFeature{} + for rows.Next() { + f := database.NamespacedFeature{} + err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &f.Namespace.Name, &f.Namespace.VersionFormat) + if err != nil { + t.Error(err) + t.FailNow() + } + nf = append(nf, f) } - id4, err := datastore.insertFeatureVersion(featureVersion) - assert.Nil(t, err) - id5, err := datastore.insertFeatureVersion(featureVersion) - assert.Nil(t, err) - assert.Equal(t, id4, id5) + + return nf +} + +func listFeatures(t *testing.T, tx *pgSession) []database.Feature { + rows, err := tx.Query("SELECT name, version, version_format FROM feature") + if err != nil { + t.FailNow() + } + + fs := []database.Feature{} + for rows.Next() { + f := database.Feature{} + err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat) + if err != nil { + t.FailNow() + } + fs = append(fs, f) + } + return fs +} + +func assertFeaturesEqual(t *testing.T, expected []database.Feature, actual []database.Feature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.Feature]bool{} + for _, nf := range expected { + has[nf] = false + } + + for _, nf := range actual { + has[nf] = true + } + + for nf, visited := range has { + if !assert.True(t, visited, nf.Name+" is expected") { + return false + } + return true + } + } + return false +} + +func assertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.NamespacedFeature]bool{} + for _, nf := range expected { + has[nf] = false + } + + for _, nf := range actual { + has[nf] = true + } + + for nf, visited := range has { + if !assert.True(t, visited, nf.Namespace.Name+":"+nf.Name+" is expected") { + return false + } + } + return true + } + return false } diff --git a/database/pgsql/keyvalue.go b/database/pgsql/keyvalue.go index ab599588..1f85fab5 100644 --- a/database/pgsql/keyvalue.go +++ b/database/pgsql/keyvalue.go @@ -23,63 +23,35 @@ import ( "github.com/coreos/clair/pkg/commonerr" ) -// InsertKeyValue stores (or updates) a single key / value tuple. -func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) { +func (tx *pgSession) UpdateKeyValue(key, value string) (err error) { if key == "" || value == "" { log.Warning("could not insert a flag which has an empty name or value") return commonerr.NewBadRequestError("could not insert a flag which has an empty name or value") } - defer observeQueryTime("InsertKeyValue", "all", time.Now()) + defer observeQueryTime("PersistKeyValue", "all", time.Now()) - // Upsert. - // - // Note: UPSERT works only on >= PostgreSQL 9.5 which is not yet supported by AWS RDS. - // The best solution is currently the use of http://dba.stackexchange.com/a/13477 - // but the key/value storage doesn't need to be super-efficient and super-safe at the - // moment so we can just use a client-side solution with transactions, based on - // http://postgresql.org/docs/current/static/plpgsql-control-structures.html. - // TODO(Quentin-M): Enable Upsert as soon as 9.5 is stable. - - for { - // First, try to update. - r, err := pgSQL.Exec(updateKeyValue, value, key) - if err != nil { - return handleError("updateKeyValue", err) - } - if n, _ := r.RowsAffected(); n > 0 { - // Updated successfully. - return nil - } - - // Try to insert the key. - // If someone else inserts the same key concurrently, we could get a unique-key violation error. - _, err = pgSQL.Exec(insertKeyValue, key, value) - if err != nil { - if isErrUniqueViolation(err) { - // Got unique constraint violation, retry. - continue - } - return handleError("insertKeyValue", err) - } - - return nil + _, err = tx.Exec(upsertKeyValue, key, value) + if err != nil { + return handleError("insertKeyValue", err) } + + return nil } -// GetValue reads a single key / value tuple and returns an empty string if the key doesn't exist. -func (pgSQL *pgSQL) GetKeyValue(key string) (string, error) { - defer observeQueryTime("GetKeyValue", "all", time.Now()) +func (tx *pgSession) FindKeyValue(key string) (string, bool, error) { + defer observeQueryTime("FindKeyValue", "all", time.Now()) var value string - err := pgSQL.QueryRow(searchKeyValue, key).Scan(&value) + err := tx.QueryRow(searchKeyValue, key).Scan(&value) if err == sql.ErrNoRows { - return "", nil - } - if err != nil { - return "", handleError("searchKeyValue", err) + return "", false, nil } - return value, nil + if err != nil { + return "", false, handleError("searchKeyValue", err) + } + + return value, true, nil } diff --git a/database/pgsql/keyvalue_test.go b/database/pgsql/keyvalue_test.go index 4a8b6593..9991bf48 100644 --- a/database/pgsql/keyvalue_test.go +++ b/database/pgsql/keyvalue_test.go @@ -21,32 +21,30 @@ import ( ) func TestKeyValue(t *testing.T) { - datastore, err := openDatabaseForTest("KeyValue", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() + datastore, tx := openSessionForTest(t, "KeyValue", true) + defer closeTest(t, datastore, tx) // Get non-existing key/value - f, err := datastore.GetKeyValue("test") + f, ok, err := tx.FindKeyValue("test") assert.Nil(t, err) - assert.Empty(t, "", f) + assert.False(t, ok) // Try to insert invalid key/value. - assert.Error(t, datastore.InsertKeyValue("test", "")) - assert.Error(t, datastore.InsertKeyValue("", "test")) - assert.Error(t, datastore.InsertKeyValue("", "")) + assert.Error(t, tx.UpdateKeyValue("test", "")) + assert.Error(t, tx.UpdateKeyValue("", "test")) + assert.Error(t, tx.UpdateKeyValue("", "")) // Insert and verify. - assert.Nil(t, datastore.InsertKeyValue("test", "test1")) - f, err = datastore.GetKeyValue("test") + assert.Nil(t, tx.UpdateKeyValue("test", "test1")) + f, ok, err = tx.FindKeyValue("test") assert.Nil(t, err) + assert.True(t, ok) assert.Equal(t, "test1", f) // Update and verify. - assert.Nil(t, datastore.InsertKeyValue("test", "test2")) - f, err = datastore.GetKeyValue("test") + assert.Nil(t, tx.UpdateKeyValue("test", "test2")) + f, ok, err = tx.FindKeyValue("test") assert.Nil(t, err) + assert.True(t, ok) assert.Equal(t, "test2", f) } diff --git a/database/pgsql/layer.go b/database/pgsql/layer.go index 64e9a475..c7cd5ce2 100644 --- a/database/pgsql/layer.go +++ b/database/pgsql/layer.go @@ -16,464 +16,293 @@ package pgsql import ( "database/sql" - "strings" - "time" - - "github.com/guregu/null/zero" - log "github.com/sirupsen/logrus" + "sort" "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" ) -func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities bool) (database.Layer, error) { - subquery := "all" - if withFeatures { - subquery += "/features" - } else if withVulnerabilities { - subquery += "/features+vulnerabilities" - } - defer observeQueryTime("FindLayer", subquery, time.Now()) +func (tx *pgSession) FindLayer(hash string) (database.Layer, database.Processors, bool, error) { + l, p, _, ok, err := tx.findLayer(hash) + return l, p, ok, err +} - // Find the layer +func (tx *pgSession) FindLayerWithContent(hash string) (database.LayerWithContent, bool, error) { var ( - layer database.Layer - parentID zero.Int - parentName zero.String - nsID zero.Int - nsName sql.NullString - nsVersionFormat sql.NullString + layer database.LayerWithContent + layerID int64 + ok bool + err error ) - t := time.Now() - err := pgSQL.QueryRow(searchLayer, name).Scan( - &layer.ID, - &layer.Name, - &layer.EngineVersion, - &parentID, - &parentName, - ) - observeQueryTime("FindLayer", "searchLayer", t) - + layer.Layer, layer.ProcessedBy, layerID, ok, err = tx.findLayer(hash) if err != nil { - return layer, handleError("searchLayer", err) + return layer, false, err } - if !parentID.IsZero() { - layer.Parent = &database.Layer{ - Model: database.Model{ID: int(parentID.Int64)}, - Name: parentName.String, - } + if !ok { + return layer, false, nil } - rows, err := pgSQL.Query(searchLayerNamespace, layer.ID) - defer rows.Close() - if err != nil { - return layer, handleError("searchLayerNamespace", err) - } - for rows.Next() { - err = rows.Scan(&nsID, &nsName, &nsVersionFormat) - if err != nil { - return layer, handleError("searchLayerNamespace", err) - } - if !nsID.IsZero() { - layer.Namespaces = append(layer.Namespaces, database.Namespace{ - Model: database.Model{ID: int(nsID.Int64)}, - Name: nsName.String, - VersionFormat: nsVersionFormat.String, - }) - } - } - - // Find its features - if withFeatures || withVulnerabilities { - // Create a transaction to disable hash/merge joins as our experiments have shown that - // PostgreSQL 9.4 makes bad planning decisions about: - // - joining the layer tree to feature versions and feature - // - joining the feature versions to affected/fixed feature version and vulnerabilities - // It would for instance do a merge join between affected feature versions (300 rows, estimated - // 3000 rows) and fixed in feature version (100k rows). In this case, it is much more - // preferred to use a nested loop. - tx, err := pgSQL.Begin() - if err != nil { - return layer, handleError("FindLayer.Begin()", err) - } - defer tx.Commit() - - _, err = tx.Exec(disableHashJoin) - if err != nil { - log.WithError(err).Warningf("FindLayer: could not disable hash join") - } - _, err = tx.Exec(disableMergeJoin) - if err != nil { - log.WithError(err).Warningf("FindLayer: could not disable merge join") - } - - t = time.Now() - featureVersions, err := getLayerFeatureVersions(tx, layer.ID) - observeQueryTime("FindLayer", "getLayerFeatureVersions", t) - - if err != nil { - return layer, err - } - - layer.Features = featureVersions - - if withVulnerabilities { - // Load the vulnerabilities that affect the FeatureVersions. - t = time.Now() - err := loadAffectedBy(tx, layer.Features) - observeQueryTime("FindLayer", "loadAffectedBy", t) - - if err != nil { - return layer, err - } - } - } - - return layer, nil + layer.Features, err = tx.findLayerFeatures(layerID) + layer.Namespaces, err = tx.findLayerNamespaces(layerID) + return layer, true, nil } -// getLayerFeatureVersions returns list of database.FeatureVersion that a database.Layer has. -func getLayerFeatureVersions(tx *sql.Tx, layerID int) ([]database.FeatureVersion, error) { - var featureVersions []database.FeatureVersion +func (tx *pgSession) PersistLayer(layer database.Layer) error { + if layer.Hash == "" { + return commonerr.NewBadRequestError("Empty Layer Hash is not allowed") + } - // Query. - rows, err := tx.Query(searchLayerFeatureVersion, layerID) + _, err := tx.Exec(queryPersistLayer(1), layer.Hash) if err != nil { - return featureVersions, handleError("searchLayerFeatureVersion", err) - } - defer rows.Close() - - // Scan query. - var modification string - mapFeatureVersions := make(map[int]database.FeatureVersion) - for rows.Next() { - var fv database.FeatureVersion - err = rows.Scan( - &fv.ID, - &modification, - &fv.Feature.Namespace.ID, - &fv.Feature.Namespace.Name, - &fv.Feature.Namespace.VersionFormat, - &fv.Feature.ID, - &fv.Feature.Name, - &fv.ID, - &fv.Version, - &fv.AddedBy.ID, - &fv.AddedBy.Name, - ) - if err != nil { - return featureVersions, handleError("searchLayerFeatureVersion.Scan()", err) - } - - // Do transitive closure. - switch modification { - case "add": - mapFeatureVersions[fv.ID] = fv - case "del": - delete(mapFeatureVersions, fv.ID) - default: - log.WithField("modification", modification).Warning("unknown Layer_diff_FeatureVersion's modification") - return featureVersions, database.ErrInconsistent - } - } - if err = rows.Err(); err != nil { - return featureVersions, handleError("searchLayerFeatureVersion.Rows()", err) + return handleError("queryPersistLayer", err) } - // Build result by converting our map to a slice. - for _, featureVersion := range mapFeatureVersions { - featureVersions = append(featureVersions, featureVersion) - } - - return featureVersions, nil + return nil } -// loadAffectedBy returns the list of database.Vulnerability that affect the given -// FeatureVersion. -func loadAffectedBy(tx *sql.Tx, featureVersions []database.FeatureVersion) error { - if len(featureVersions) == 0 { +// PersistLayerContent relates layer identified by hash with namespaces, +// features and processors provided. If the layer, namespaces, features are not +// in database, the function returns an error. +func (tx *pgSession) PersistLayerContent(hash string, namespaces []database.Namespace, features []database.Feature, processedBy database.Processors) error { + if hash == "" { + return commonerr.NewBadRequestError("Empty layer hash is not allowed") + } + + var layerID int64 + err := tx.QueryRow(searchLayer, hash).Scan(&layerID) + if err != nil { + return err + } + + if err = tx.persistLayerNamespace(layerID, namespaces); err != nil { + return err + } + + if err = tx.persistLayerFeatures(layerID, features); err != nil { + return err + } + + if err = tx.persistLayerDetectors(layerID, processedBy.Detectors); err != nil { + return err + } + + if err = tx.persistLayerListers(layerID, processedBy.Listers); err != nil { + return err + } + + return nil +} + +func (tx *pgSession) persistLayerDetectors(id int64, detectors []string) error { + if len(detectors) == 0 { return nil } - // Construct list of FeatureVersion IDs, we will do a single query - featureVersionIDs := make([]int, 0, len(featureVersions)) - for i := 0; i < len(featureVersions); i++ { - featureVersionIDs = append(featureVersionIDs, featureVersions[i].ID) + // Sorting is needed before inserting into database to prevent deadlock. + sort.Strings(detectors) + keys := make([]interface{}, len(detectors)*2) + for i, d := range detectors { + keys[i*2] = id + keys[i*2+1] = d + } + _, err := tx.Exec(queryPersistLayerDetectors(len(detectors)), keys...) + if err != nil { + return handleError("queryPersistLayerDetectors", err) + } + return nil +} + +func (tx *pgSession) persistLayerListers(id int64, listers []string) error { + if len(listers) == 0 { + return nil } - rows, err := tx.Query(searchFeatureVersionVulnerability, - buildInputArray(featureVersionIDs)) - if err != nil && err != sql.ErrNoRows { - return handleError("searchFeatureVersionVulnerability", err) + sort.Strings(listers) + keys := make([]interface{}, len(listers)*2) + for i, d := range listers { + keys[i*2] = id + keys[i*2+1] = d + } + + _, err := tx.Exec(queryPersistLayerListers(len(listers)), keys...) + if err != nil { + return handleError("queryPersistLayerDetectors", err) + } + return nil +} + +func (tx *pgSession) persistLayerFeatures(id int64, features []database.Feature) error { + if len(features) == 0 { + return nil + } + + fIDs, err := tx.findFeatureIDs(features) + if err != nil { + return err + } + + ids := make([]int, len(fIDs)) + for i, fID := range fIDs { + if !fID.Valid { + return errNamespaceNotFound + } + ids[i] = int(fID.Int64) + } + + sort.IntSlice(ids).Sort() + keys := make([]interface{}, len(features)*2) + for i, fID := range ids { + keys[i*2] = id + keys[i*2+1] = fID + } + + _, err = tx.Exec(queryPersistLayerFeature(len(features)), keys...) + if err != nil { + return handleError("queryPersistLayerFeature", err) + } + return nil +} + +func (tx *pgSession) persistLayerNamespace(id int64, namespaces []database.Namespace) error { + if len(namespaces) == 0 { + return nil + } + + nsIDs, err := tx.findNamespaceIDs(namespaces) + if err != nil { + return err + } + + // for every bulk persist operation, the input data should be sorted. + ids := make([]int, len(nsIDs)) + for i, nsID := range nsIDs { + if !nsID.Valid { + panic(errNamespaceNotFound) + } + ids[i] = int(nsID.Int64) + } + + sort.IntSlice(ids).Sort() + + keys := make([]interface{}, len(namespaces)*2) + for i, nsID := range ids { + keys[i*2] = id + keys[i*2+1] = nsID + } + + _, err = tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...) + if err != nil { + return handleError("queryPersistLayerNamespace", err) + } + return nil +} + +func (tx *pgSession) persistProcessors(listerQuery, listerQueryName, detectorQuery, detectorQueryName string, id int64, processors database.Processors) error { + stmt, err := tx.Prepare(listerQuery) + if err != nil { + return handleError(listerQueryName, err) + } + + for _, l := range processors.Listers { + _, err := stmt.Exec(id, l) + if err != nil { + stmt.Close() + return handleError(listerQueryName, err) + } + } + + if err := stmt.Close(); err != nil { + return handleError(listerQueryName, err) + } + + stmt, err = tx.Prepare(detectorQuery) + if err != nil { + return handleError(detectorQueryName, err) + } + + for _, d := range processors.Detectors { + _, err := stmt.Exec(id, d) + if err != nil { + stmt.Close() + return handleError(detectorQueryName, err) + } + } + + if err := stmt.Close(); err != nil { + return handleError(detectorQueryName, err) + } + + return nil +} + +func (tx *pgSession) findLayerNamespaces(layerID int64) ([]database.Namespace, error) { + var namespaces []database.Namespace + + rows, err := tx.Query(searchLayerNamespaces, layerID) + if err != nil { + return nil, handleError("searchLayerFeatures", err) } - defer rows.Close() - vulnerabilities := make(map[int][]database.Vulnerability, len(featureVersions)) - var featureversionID int for rows.Next() { - var vulnerability database.Vulnerability - err := rows.Scan( - &featureversionID, - &vulnerability.ID, - &vulnerability.Name, - &vulnerability.Description, - &vulnerability.Link, - &vulnerability.Severity, - &vulnerability.Metadata, - &vulnerability.Namespace.Name, - &vulnerability.Namespace.VersionFormat, - &vulnerability.FixedBy, - ) + ns := database.Namespace{} + err := rows.Scan(&ns.Name, &ns.VersionFormat) if err != nil { - return handleError("searchFeatureVersionVulnerability.Scan()", err) + return nil, err } - vulnerabilities[featureversionID] = append(vulnerabilities[featureversionID], vulnerability) + namespaces = append(namespaces, ns) } - if err = rows.Err(); err != nil { - return handleError("searchFeatureVersionVulnerability.Rows()", err) - } - - // Assign vulnerabilities to every FeatureVersions - for i := 0; i < len(featureVersions); i++ { - featureVersions[i].AffectedBy = vulnerabilities[featureVersions[i].ID] - } - - return nil + return namespaces, nil } -// Internally, only Feature additions/removals are stored for each layer. If a layer has a parent, -// the Feature list will be compared to the parent's Feature list and the difference will be stored. -// Note that when the Namespace of a layer differs from its parent, it is expected that several -// Feature that were already included a parent will have their Namespace updated as well -// (happens when Feature detectors relies on the detected layer Namespace). However, if the listed -// Feature has the same Name/Version as its parent, InsertLayer considers that the Feature hasn't -// been modified. -func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { - tf := time.Now() +func (tx *pgSession) findLayerFeatures(layerID int64) ([]database.Feature, error) { + var features []database.Feature - // Verify parameters - if layer.Name == "" { - log.Warning("could not insert a layer which has an empty Name") - return commonerr.NewBadRequestError("could not insert a layer which has an empty Name") - } - - // Get a potentially existing layer. - existingLayer, err := pgSQL.FindLayer(layer.Name, true, false) - if err != nil && err != commonerr.ErrNotFound { - return err - } else if err == nil { - if existingLayer.EngineVersion >= layer.EngineVersion { - // The layer exists and has an equal or higher engine version, do nothing. - return nil - } - - layer.ID = existingLayer.ID - } - - // We do `defer observeQueryTime` here because we don't want to observe existing layers. - defer observeQueryTime("InsertLayer", "all", tf) - - // Get parent ID. - var parentID zero.Int - if layer.Parent != nil { - if layer.Parent.ID == 0 { - log.Warning("Parent is expected to be retrieved from database when inserting a layer.") - return commonerr.NewBadRequestError("Parent is expected to be retrieved from database when inserting a layer.") - } - - parentID = zero.IntFrom(int64(layer.Parent.ID)) - } - - // namespaceIDs will contain inherited and new namespaces - namespaceIDs := make(map[int]struct{}) - - // try to insert the new namespaces - for _, ns := range layer.Namespaces { - n, err := pgSQL.insertNamespace(ns) - if err != nil { - return handleError("pgSQL.insertNamespace", err) - } - namespaceIDs[n] = struct{}{} - } - - // inherit namespaces from parent layer - if layer.Parent != nil { - for _, ns := range layer.Parent.Namespaces { - namespaceIDs[ns.ID] = struct{}{} - } - } - - // Begin transaction. - tx, err := pgSQL.Begin() + rows, err := tx.Query(searchLayerFeatures, layerID) if err != nil { - tx.Rollback() - return handleError("InsertLayer.Begin()", err) + return nil, handleError("searchLayerFeatures", err) } - if layer.ID == 0 { - // Insert a new layer. - err = tx.QueryRow(insertLayer, layer.Name, layer.EngineVersion, parentID). - Scan(&layer.ID) + for rows.Next() { + f := database.Feature{} + err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat) if err != nil { - tx.Rollback() - - if isErrUniqueViolation(err) { - // Ignore this error, another process collided. - log.Debug("Attempted to insert duplicate layer.") - return nil - } - return handleError("insertLayer", err) - } - } else { - // Update an existing layer. - _, err = tx.Exec(updateLayer, layer.ID, layer.EngineVersion) - if err != nil { - tx.Rollback() - return handleError("updateLayer", err) - } - - // replace the old namespace in the database - _, err := tx.Exec(removeLayerNamespace, layer.ID) - if err != nil { - tx.Rollback() - return handleError("removeLayerNamespace", err) - } - // Remove all existing Layer_diff_FeatureVersion. - _, err = tx.Exec(removeLayerDiffFeatureVersion, layer.ID) - if err != nil { - tx.Rollback() - return handleError("removeLayerDiffFeatureVersion", err) + return nil, err } + features = append(features, f) } - - // insert the layer's namespaces - stmt, err := tx.Prepare(insertLayerNamespace) - - if err != nil { - tx.Rollback() - return handleError("failed to prepare statement", err) - } - - defer func() { - err = stmt.Close() - if err != nil { - tx.Rollback() - log.WithError(err).Error("failed to close prepared statement") - } - }() - - for nsid := range namespaceIDs { - _, err := stmt.Exec(layer.ID, nsid) - if err != nil { - tx.Rollback() - return handleError("insertLayerNamespace", err) - } - } - - // Update Layer_diff_FeatureVersion now. - err = pgSQL.updateDiffFeatureVersions(tx, &layer, &existingLayer) - if err != nil { - tx.Rollback() - return err - } - - // Commit transaction. - err = tx.Commit() - if err != nil { - tx.Rollback() - return handleError("InsertLayer.Commit()", err) - } - - return nil + return features, nil } -func (pgSQL *pgSQL) updateDiffFeatureVersions(tx *sql.Tx, layer, existingLayer *database.Layer) error { - // add and del are the FeatureVersion diff we should insert. - var add []database.FeatureVersion - var del []database.FeatureVersion +func (tx *pgSession) findLayer(hash string) (database.Layer, database.Processors, int64, bool, error) { + var ( + layerID int64 + layer = database.Layer{Hash: hash} + processors database.Processors + ) - if layer.Parent == nil { - // There is no parent, every Features are added. - add = append(add, layer.Features...) - } else if layer.Parent != nil { - // There is a parent, we need to diff the Features with it. - - // Build name:version structures. - layerFeaturesMapNV, layerFeaturesNV := createNV(layer.Features) - parentLayerFeaturesMapNV, parentLayerFeaturesNV := createNV(layer.Parent.Features) - - // Calculate the added and deleted FeatureVersions name:version. - addNV := compareStringLists(layerFeaturesNV, parentLayerFeaturesNV) - delNV := compareStringLists(parentLayerFeaturesNV, layerFeaturesNV) - - // Fill the structures containing the added and deleted FeatureVersions. - for _, nv := range addNV { - add = append(add, *layerFeaturesMapNV[nv]) - } - for _, nv := range delNV { - del = append(del, *parentLayerFeaturesMapNV[nv]) - } + if hash == "" { + return layer, processors, layerID, false, commonerr.NewBadRequestError("Empty Layer Hash is not allowed") } - // Insert FeatureVersions in the database. - addIDs, err := pgSQL.insertFeatureVersions(add) + err := tx.QueryRow(searchLayer, hash).Scan(&layerID) if err != nil { - return err + if err == sql.ErrNoRows { + return layer, processors, layerID, false, nil + } + return layer, processors, layerID, false, err } - delIDs, err := pgSQL.insertFeatureVersions(del) + + processors.Detectors, err = tx.findProcessors(searchLayerDetectors, "searchLayerDetectors", "detector", layerID) if err != nil { - return err + return layer, processors, layerID, false, err } - // Insert diff in the database. - if len(addIDs) > 0 { - _, err = tx.Exec(insertLayerDiffFeatureVersion, layer.ID, "add", buildInputArray(addIDs)) - if err != nil { - return handleError("insertLayerDiffFeatureVersion.Add", err) - } - } - if len(delIDs) > 0 { - _, err = tx.Exec(insertLayerDiffFeatureVersion, layer.ID, "del", buildInputArray(delIDs)) - if err != nil { - return handleError("insertLayerDiffFeatureVersion.Del", err) - } + processors.Listers, err = tx.findProcessors(searchLayerListers, "searchLayerListers", "lister", layerID) + if err != nil { + return layer, processors, layerID, false, err } - return nil -} - -func createNV(features []database.FeatureVersion) (map[string]*database.FeatureVersion, []string) { - mapNV := make(map[string]*database.FeatureVersion, 0) - sliceNV := make([]string, 0, len(features)) - - for i := 0; i < len(features); i++ { - fv := &features[i] - nv := strings.Join([]string{fv.Feature.Namespace.Name, fv.Feature.Name, fv.Version}, ":") - mapNV[nv] = fv - sliceNV = append(sliceNV, nv) - } - - return mapNV, sliceNV -} - -func (pgSQL *pgSQL) DeleteLayer(name string) error { - defer observeQueryTime("DeleteLayer", "all", time.Now()) - - result, err := pgSQL.Exec(removeLayer, name) - if err != nil { - return handleError("removeLayer", err) - } - - affected, err := result.RowsAffected() - if err != nil { - return handleError("removeLayer.RowsAffected()", err) - } - - if affected <= 0 { - return commonerr.ErrNotFound - } - - return nil + return layer, processors, layerID, true, nil } diff --git a/database/pgsql/layer_test.go b/database/pgsql/layer_test.go index 6f35bbde..e823a048 100644 --- a/database/pgsql/layer_test.go +++ b/database/pgsql/layer_test.go @@ -15,423 +15,100 @@ package pgsql import ( - "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt/dpkg" - "github.com/coreos/clair/pkg/commonerr" ) +func TestPersistLayer(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistLayer", false) + defer closeTest(t, datastore, tx) + + l1 := database.Layer{} + l2 := database.Layer{Hash: "HESOYAM"} + + // invalid + assert.NotNil(t, tx.PersistLayer(l1)) + // valid + assert.Nil(t, tx.PersistLayer(l2)) + // duplicated + assert.Nil(t, tx.PersistLayer(l2)) +} + +func TestPersistLayerProcessors(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistLayerProcessors", true) + defer closeTest(t, datastore, tx) + + // invalid + assert.NotNil(t, tx.PersistLayerContent("hash", []database.Namespace{}, []database.Feature{}, database.Processors{})) + // valid + assert.Nil(t, tx.PersistLayerContent("layer-4", []database.Namespace{}, []database.Feature{}, database.Processors{Detectors: []string{"new detector!"}})) +} + func TestFindLayer(t *testing.T) { - datastore, err := openDatabaseForTest("FindLayer", true) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() + datastore, tx := openSessionForTest(t, "FindLayer", true) + defer closeTest(t, datastore, tx) - // Layer-0: no parent, no namespace, no feature, no vulnerability - layer, err := datastore.FindLayer("layer-0", false, false) - if assert.Nil(t, err) && assert.NotNil(t, layer) { - assert.Equal(t, "layer-0", layer.Name) - assert.Len(t, layer.Namespaces, 0) - assert.Nil(t, layer.Parent) - assert.Equal(t, 1, layer.EngineVersion) - assert.Len(t, layer.Features, 0) + expected := database.Layer{Hash: "layer-4"} + expectedProcessors := database.Processors{ + Detectors: []string{"os-release", "apt-sources"}, + Listers: []string{"dpkg", "rpm"}, } - layer, err = datastore.FindLayer("layer-0", true, false) - if assert.Nil(t, err) && assert.NotNil(t, layer) { - assert.Len(t, layer.Features, 0) - } - - // Layer-1: one parent, adds two features, one vulnerability - layer, err = datastore.FindLayer("layer-1", false, false) - if assert.Nil(t, err) && assert.NotNil(t, layer) { - assert.Equal(t, layer.Name, "layer-1") - assertExpectedNamespaceName(t, &layer, []string{"debian:7"}) - if assert.NotNil(t, layer.Parent) { - assert.Equal(t, "layer-0", layer.Parent.Name) - } - assert.Equal(t, 1, layer.EngineVersion) - assert.Len(t, layer.Features, 0) - } - - layer, err = datastore.FindLayer("layer-1", true, false) - if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) { - for _, featureVersion := range layer.Features { - assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name) - - switch featureVersion.Feature.Name { - case "wechat": - assert.Equal(t, "0.5", featureVersion.Version) - case "openssl": - assert.Equal(t, "1.0", featureVersion.Version) - default: - t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name) - } - } - } - - layer, err = datastore.FindLayer("layer-1", true, true) - if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) { - for _, featureVersion := range layer.Features { - assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name) - - switch featureVersion.Feature.Name { - case "wechat": - assert.Equal(t, "0.5", featureVersion.Version) - case "openssl": - assert.Equal(t, "1.0", featureVersion.Version) - - if assert.Len(t, featureVersion.AffectedBy, 1) { - assert.Equal(t, "debian:7", featureVersion.AffectedBy[0].Namespace.Name) - assert.Equal(t, "CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Name) - assert.Equal(t, database.HighSeverity, featureVersion.AffectedBy[0].Severity) - assert.Equal(t, "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", featureVersion.AffectedBy[0].Description) - assert.Equal(t, "http://google.com/#q=CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Link) - assert.Equal(t, "2.0", featureVersion.AffectedBy[0].FixedBy) - } - default: - t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name) - } - } - } - - // Testing Multiple namespaces layer-3b has debian:7 and debian:8 namespaces - layer, err = datastore.FindLayer("layer-3b", true, true) - - if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) { - assert.Equal(t, "layer-3b", layer.Name) - // validate the namespace - assertExpectedNamespaceName(t, &layer, []string{"debian:7", "debian:8"}) - for _, featureVersion := range layer.Features { - switch featureVersion.Feature.Namespace.Name { - case "debian:7": - assert.Equal(t, "wechat", featureVersion.Feature.Name) - assert.Equal(t, "0.5", featureVersion.Version) - case "debian:8": - assert.Equal(t, "openssl", featureVersion.Feature.Name) - assert.Equal(t, "1.0", featureVersion.Version) - default: - t.Errorf("unexpected package %s for layer-3b", featureVersion.Feature.Name) - } - } - } -} - -func TestInsertLayer(t *testing.T) { - datastore, err := openDatabaseForTest("InsertLayer", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() - - // Insert invalid layer. - testInsertLayerInvalid(t, datastore) - - // Insert a layer tree. - testInsertLayerTree(t, datastore) - - // Update layer. - testInsertLayerUpdate(t, datastore) - - // Delete layer. - testInsertLayerDelete(t, datastore) -} - -func testInsertLayerInvalid(t *testing.T, datastore database.Datastore) { - invalidLayers := []database.Layer{ - {}, - {Name: "layer0", Parent: &database.Layer{}}, - {Name: "layer0", Parent: &database.Layer{Name: "UnknownLayer"}}, - } - - for _, invalidLayer := range invalidLayers { - err := datastore.InsertLayer(invalidLayer) - assert.Error(t, err) - } -} - -func testInsertLayerTree(t *testing.T, datastore database.Datastore) { - f1 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature1", - }, - Version: "1.0", - } - f2 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature2", - }, - Version: "0.34", - } - f3 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace2", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature3", - }, - Version: "0.56", - } - f4 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature2", - }, - Version: "0.34", - } - f5 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature3", - }, - Version: "0.56", - } - f6 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature4", - }, - Version: "0.666", - } - - layers := []database.Layer{ - { - Name: "TestInsertLayer1", - }, - { - Name: "TestInsertLayer2", - Parent: &database.Layer{Name: "TestInsertLayer1"}, - Namespaces: []database.Namespace{database.Namespace{ - Name: "TestInsertLayerNamespace1", - VersionFormat: dpkg.ParserName, - }}, - }, - // This layer changes the namespace and adds Features. - { - Name: "TestInsertLayer3", - Parent: &database.Layer{Name: "TestInsertLayer2"}, - Namespaces: []database.Namespace{database.Namespace{ - Name: "TestInsertLayerNamespace2", - VersionFormat: dpkg.ParserName, - }}, - Features: []database.FeatureVersion{f1, f2, f3}, - }, - // This layer covers the case where the last layer doesn't provide any new Feature. - { - Name: "TestInsertLayer4a", - Parent: &database.Layer{Name: "TestInsertLayer3"}, - Features: []database.FeatureVersion{f1, f2, f3}, - }, - // This layer covers the case where the last layer provides Features. - // It also modifies the Namespace ("upgrade") but keeps some Features not upgraded, their - // Namespaces should then remain unchanged. - { - Name: "TestInsertLayer4b", - Parent: &database.Layer{Name: "TestInsertLayer3"}, - Namespaces: []database.Namespace{database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }}, - Features: []database.FeatureVersion{ - // Deletes TestInsertLayerFeature1. - // Keep TestInsertLayerFeature2 (old Namespace should be kept): - f4, - // Upgrades TestInsertLayerFeature3 (with new Namespace): - f5, - // Adds TestInsertLayerFeature4: - f6, - }, - }, - } - - var err error - retrievedLayers := make(map[string]database.Layer) - for _, layer := range layers { - if layer.Parent != nil { - // Retrieve from database its parent and assign. - parent := retrievedLayers[layer.Parent.Name] - layer.Parent = &parent - } - - err = datastore.InsertLayer(layer) - assert.Nil(t, err) - - retrievedLayers[layer.Name], err = datastore.FindLayer(layer.Name, true, false) - assert.Nil(t, err) - } - - // layer inherits all namespaces from its ancestries - l4a := retrievedLayers["TestInsertLayer4a"] - assertExpectedNamespaceName(t, &l4a, []string{"TestInsertLayerNamespace2", "TestInsertLayerNamespace1"}) - assert.Len(t, l4a.Features, 3) - for _, featureVersion := range l4a.Features { - if cmpFV(featureVersion, f1) && cmpFV(featureVersion, f2) && cmpFV(featureVersion, f3) { - assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f1, f2, f3)) - } - } - - l4b := retrievedLayers["TestInsertLayer4b"] - assertExpectedNamespaceName(t, &l4b, []string{"TestInsertLayerNamespace1", "TestInsertLayerNamespace2", "TestInsertLayerNamespace3"}) - assert.Len(t, l4b.Features, 3) - for _, featureVersion := range l4b.Features { - if cmpFV(featureVersion, f2) && cmpFV(featureVersion, f5) && cmpFV(featureVersion, f6) { - assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f2, f4, f6)) - } - } -} - -func testInsertLayerUpdate(t *testing.T, datastore database.Datastore) { - f7 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "TestInsertLayerNamespace3", - VersionFormat: dpkg.ParserName, - }, - Name: "TestInsertLayerFeature7", - }, - Version: "0.01", - } - - l3, _ := datastore.FindLayer("TestInsertLayer3", true, false) - l3u := database.Layer{ - Name: l3.Name, - Parent: l3.Parent, - Namespaces: []database.Namespace{database.Namespace{ - Name: "TestInsertLayerNamespaceUpdated1", - VersionFormat: dpkg.ParserName, - }}, - Features: []database.FeatureVersion{f7}, - } - - l4u := database.Layer{ - Name: "TestInsertLayer4", - Parent: &database.Layer{Name: "TestInsertLayer3"}, - Features: []database.FeatureVersion{f7}, - EngineVersion: 2, - } - - // Try to re-insert without increasing the EngineVersion. - err := datastore.InsertLayer(l3u) + // invalid + _, _, _, err := tx.FindLayer("") + assert.NotNil(t, err) + _, _, ok, err := tx.FindLayer("layer-non") assert.Nil(t, err) + assert.False(t, ok) - l3uf, err := datastore.FindLayer(l3u.Name, true, false) - if assert.Nil(t, err) { - assertSameNamespaceName(t, &l3, &l3uf) - assert.Equal(t, l3.EngineVersion, l3uf.EngineVersion) - assert.Len(t, l3uf.Features, len(l3.Features)) + // valid + layer, processors, ok2, err := tx.FindLayer("layer-4") + if assert.Nil(t, err) && assert.True(t, ok2) { + assert.Equal(t, expected, layer) + assertProcessorsEqual(t, expectedProcessors, processors) } +} - // Update layer l3. - // Verify that the Namespace, EngineVersion and FeatureVersions got updated. - l3u.EngineVersion = 2 - err = datastore.InsertLayer(l3u) +func TestFindLayerWithContent(t *testing.T) { + datastore, tx := openSessionForTest(t, "FindLayerWithContent", true) + defer closeTest(t, datastore, tx) + + _, _, err := tx.FindLayerWithContent("") + assert.NotNil(t, err) + _, ok, err := tx.FindLayerWithContent("layer-non") assert.Nil(t, err) + assert.False(t, ok) - l3uf, err = datastore.FindLayer(l3u.Name, true, false) - if assert.Nil(t, err) { - assertSameNamespaceName(t, &l3u, &l3uf) - assert.Equal(t, l3u.EngineVersion, l3uf.EngineVersion) - if assert.Len(t, l3uf.Features, 1) { - assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l3uf.Features[0]) - } + expectedL := database.LayerWithContent{ + Layer: database.Layer{ + Hash: "layer-4", + }, + Features: []database.Feature{ + {Name: "fake", Version: "2.0", VersionFormat: "rpm"}, + {Name: "openssl", Version: "2.0", VersionFormat: "dpkg"}, + }, + Namespaces: []database.Namespace{ + {Name: "debian:7", VersionFormat: "dpkg"}, + {Name: "fake:1.0", VersionFormat: "rpm"}, + }, + ProcessedBy: database.Processors{ + Detectors: []string{"os-release", "apt-sources"}, + Listers: []string{"dpkg", "rpm"}, + }, } - // Update layer l4. - // Verify that the Namespace got updated from its new Parent's, and also verify the - // EnginVersion and FeatureVersions. - l4u.Parent = &l3uf - err = datastore.InsertLayer(l4u) - assert.Nil(t, err) - - l4uf, err := datastore.FindLayer(l3u.Name, true, false) - if assert.Nil(t, err) { - assertSameNamespaceName(t, &l3u, &l4uf) - assert.Equal(t, l4u.EngineVersion, l4uf.EngineVersion) - if assert.Len(t, l4uf.Features, 1) { - assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l4uf.Features[0]) - } + layer, ok2, err := tx.FindLayerWithContent("layer-4") + if assert.Nil(t, err) && assert.True(t, ok2) { + assertLayerWithContentEqual(t, expectedL, layer) } } -func assertSameNamespaceName(t *testing.T, layer1 *database.Layer, layer2 *database.Layer) { - assert.Len(t, compareStringLists(extractNamespaceName(layer1), extractNamespaceName(layer2)), 0) -} - -func assertExpectedNamespaceName(t *testing.T, layer *database.Layer, expectedNames []string) { - assert.Len(t, compareStringLists(extractNamespaceName(layer), expectedNames), 0) -} - -func extractNamespaceName(layer *database.Layer) []string { - slist := make([]string, 0, len(layer.Namespaces)) - for _, ns := range layer.Namespaces { - slist = append(slist, ns.Name) - } - return slist -} - -func testInsertLayerDelete(t *testing.T, datastore database.Datastore) { - err := datastore.DeleteLayer("TestInsertLayerX") - assert.Equal(t, commonerr.ErrNotFound, err) - - // ensure layer_namespace table is cleaned up once a layer is removed - layer3, err := datastore.FindLayer("TestInsertLayer3", false, false) - layer4a, err := datastore.FindLayer("TestInsertLayer4a", false, false) - layer4b, err := datastore.FindLayer("TestInsertLayer4b", false, false) - - err = datastore.DeleteLayer("TestInsertLayer3") - assert.Nil(t, err) - - _, err = datastore.FindLayer("TestInsertLayer3", false, false) - assert.Equal(t, commonerr.ErrNotFound, err) - assertNotInLayerNamespace(t, layer3.ID, datastore) - _, err = datastore.FindLayer("TestInsertLayer4a", false, false) - assert.Equal(t, commonerr.ErrNotFound, err) - assertNotInLayerNamespace(t, layer4a.ID, datastore) - _, err = datastore.FindLayer("TestInsertLayer4b", true, false) - assert.Equal(t, commonerr.ErrNotFound, err) - assertNotInLayerNamespace(t, layer4b.ID, datastore) -} - -func assertNotInLayerNamespace(t *testing.T, layerID int, datastore database.Datastore) { - pg, ok := datastore.(*pgSQL) - if !assert.True(t, ok) { - return - } - tx, err := pg.Begin() - if !assert.Nil(t, err) { - return - } - rows, err := tx.Query(searchLayerNamespace, layerID) - assert.False(t, rows.Next()) -} - -func cmpFV(a, b database.FeatureVersion) bool { - return a.Feature.Name == b.Feature.Name && - a.Feature.Namespace.Name == b.Feature.Namespace.Name && - a.Version == b.Version +func assertLayerWithContentEqual(t *testing.T, expected database.LayerWithContent, actual database.LayerWithContent) bool { + return assert.Equal(t, expected.Layer, actual.Layer) && + assertFeaturesEqual(t, expected.Features, actual.Features) && + assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy) && + assertNamespacesEqual(t, expected.Namespaces, actual.Namespaces) } diff --git a/database/pgsql/lock.go b/database/pgsql/lock.go index d3521b75..c8918ebc 100644 --- a/database/pgsql/lock.go +++ b/database/pgsql/lock.go @@ -15,6 +15,7 @@ package pgsql import ( + "errors" "time" log "github.com/sirupsen/logrus" @@ -22,86 +23,91 @@ import ( "github.com/coreos/clair/pkg/commonerr" ) +var ( + errLockNotFound = errors.New("lock is not in database") +) + // Lock tries to set a temporary lock in the database. // // Lock does not block, instead, it returns true and its expiration time -// is the lock has been successfully acquired or false otherwise -func (pgSQL *pgSQL) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time) { +// is the lock has been successfully acquired or false otherwise. +func (tx *pgSession) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time, error) { if name == "" || owner == "" || duration == 0 { log.Warning("could not create an invalid lock") - return false, time.Time{} + return false, time.Time{}, commonerr.NewBadRequestError("Invalid Lock Parameters") } - defer observeQueryTime("Lock", "all", time.Now()) - - // Compute expiration. until := time.Now().Add(duration) - if renew { + defer observeQueryTime("Lock", "update", time.Now()) // Renew lock. - r, err := pgSQL.Exec(updateLock, name, owner, until) + r, err := tx.Exec(updateLock, name, owner, until) if err != nil { - handleError("updateLock", err) - return false, until + return false, until, handleError("updateLock", err) } - if n, _ := r.RowsAffected(); n > 0 { - // Updated successfully. - return true, until + + if n, err := r.RowsAffected(); err == nil { + return n > 0, until, nil } - } else { - // Prune locks. - pgSQL.pruneLocks() + return false, until, handleError("updateLock", err) + } else if err := tx.pruneLocks(); err != nil { + return false, until, err } // Lock. - _, err := pgSQL.Exec(insertLock, name, owner, until) + defer observeQueryTime("Lock", "soiLock", time.Now()) + _, err := tx.Exec(soiLock, name, owner, until) if err != nil { - if !isErrUniqueViolation(err) { - handleError("insertLock", err) + if isErrUniqueViolation(err) { + return false, until, nil } - return false, until + return false, until, handleError("insertLock", err) } - - return true, until + return true, until, nil } // Unlock unlocks a lock specified by its name if I own it -func (pgSQL *pgSQL) Unlock(name, owner string) { +func (tx *pgSession) Unlock(name, owner string) error { if name == "" || owner == "" { - log.Warning("could not delete an invalid lock") - return + return commonerr.NewBadRequestError("Invalid Lock Parameters") } defer observeQueryTime("Unlock", "all", time.Now()) - pgSQL.Exec(removeLock, name, owner) + _, err := tx.Exec(removeLock, name, owner) + return err } // FindLock returns the owner of a lock specified by its name and its // expiration time. -func (pgSQL *pgSQL) FindLock(name string) (string, time.Time, error) { +func (tx *pgSession) FindLock(name string) (string, time.Time, bool, error) { if name == "" { - log.Warning("could not find an invalid lock") - return "", time.Time{}, commonerr.NewBadRequestError("could not find an invalid lock") + return "", time.Time{}, false, commonerr.NewBadRequestError("could not find an invalid lock") } defer observeQueryTime("FindLock", "all", time.Now()) var owner string var until time.Time - err := pgSQL.QueryRow(searchLock, name).Scan(&owner, &until) + err := tx.QueryRow(searchLock, name).Scan(&owner, &until) if err != nil { - return owner, until, handleError("searchLock", err) + return owner, until, false, handleError("searchLock", err) } - return owner, until, nil + return owner, until, true, nil } // pruneLocks removes every expired locks from the database -func (pgSQL *pgSQL) pruneLocks() { +func (tx *pgSession) pruneLocks() error { defer observeQueryTime("pruneLocks", "all", time.Now()) - if _, err := pgSQL.Exec(removeLockExpired); err != nil { - handleError("removeLockExpired", err) + if r, err := tx.Exec(removeLockExpired); err != nil { + return handleError("removeLockExpired", err) + } else if affected, err := r.RowsAffected(); err != nil { + return handleError("removeLockExpired", err) + } else { + log.Debugf("Pruned %d Locks", affected) } + + return nil } diff --git a/database/pgsql/lock_test.go b/database/pgsql/lock_test.go index cbd2d999..19a5a934 100644 --- a/database/pgsql/lock_test.go +++ b/database/pgsql/lock_test.go @@ -22,48 +22,72 @@ import ( ) func TestLock(t *testing.T) { - datastore, err := openDatabaseForTest("InsertNamespace", false) - if err != nil { - t.Error(err) - return - } + datastore, tx := openSessionForTest(t, "Lock", true) defer datastore.Close() var l bool var et time.Time // Create a first lock. - l, _ = datastore.Lock("test1", "owner1", time.Minute, false) + l, _, err := tx.Lock("test1", "owner1", time.Minute, false) + assert.Nil(t, err) assert.True(t, l) + tx = restartSession(t, datastore, tx, true) - // Try to lock the same lock with another owner. - l, _ = datastore.Lock("test1", "owner2", time.Minute, true) + // lock again by itself, the previous lock is not expired yet. + l, _, err = tx.Lock("test1", "owner1", time.Minute, false) + assert.Nil(t, err) assert.False(t, l) + tx = restartSession(t, datastore, tx, false) - l, _ = datastore.Lock("test1", "owner2", time.Minute, false) + // Try to renew the same lock with another owner. + l, _, err = tx.Lock("test1", "owner2", time.Minute, true) + assert.Nil(t, err) assert.False(t, l) + tx = restartSession(t, datastore, tx, false) + + l, _, err = tx.Lock("test1", "owner2", time.Minute, false) + assert.Nil(t, err) + assert.False(t, l) + tx = restartSession(t, datastore, tx, false) // Renew the lock. - l, _ = datastore.Lock("test1", "owner1", 2*time.Minute, true) + l, _, err = tx.Lock("test1", "owner1", 2*time.Minute, true) + assert.Nil(t, err) assert.True(t, l) + tx = restartSession(t, datastore, tx, true) // Unlock and then relock by someone else. - datastore.Unlock("test1", "owner1") + err = tx.Unlock("test1", "owner1") + assert.Nil(t, err) + tx = restartSession(t, datastore, tx, true) - l, et = datastore.Lock("test1", "owner2", time.Minute, false) + l, et, err = tx.Lock("test1", "owner2", time.Minute, false) + assert.Nil(t, err) assert.True(t, l) + tx = restartSession(t, datastore, tx, true) // LockInfo - o, et2, err := datastore.FindLock("test1") + o, et2, ok, err := tx.FindLock("test1") + assert.True(t, ok) assert.Nil(t, err) assert.Equal(t, "owner2", o) assert.Equal(t, et.Second(), et2.Second()) + tx = restartSession(t, datastore, tx, true) // Create a second lock which is actually already expired ... - l, _ = datastore.Lock("test2", "owner1", -time.Minute, false) + l, _, err = tx.Lock("test2", "owner1", -time.Minute, false) + assert.Nil(t, err) assert.True(t, l) + tx = restartSession(t, datastore, tx, true) // Take over the lock - l, _ = datastore.Lock("test2", "owner2", time.Minute, false) + l, _, err = tx.Lock("test2", "owner2", time.Minute, false) + assert.Nil(t, err) assert.True(t, l) + tx = restartSession(t, datastore, tx, true) + + if !assert.Nil(t, tx.Rollback()) { + t.FailNow() + } } diff --git a/database/pgsql/migrations/00001_change_migrator.go b/database/pgsql/migrations/00001_change_migrator.go deleted file mode 100644 index 8fef9ea0..00000000 --- a/database/pgsql/migrations/00001_change_migrator.go +++ /dev/null @@ -1,53 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import ( - "database/sql" - - "github.com/remind101/migrate" -) - -func init() { - // This migration removes the data maintained by the previous migration tool - // (liamstask/goose), and if it was present, mark the 00002_initial_schema - // migration as done. - RegisterMigration(migrate.Migration{ - ID: 1, - Up: func(tx *sql.Tx) error { - // Verify that goose was in use before, otherwise skip this migration. - var e bool - err := tx.QueryRow("SELECT true FROM pg_class WHERE relname = $1", "goose_db_version").Scan(&e) - if err == sql.ErrNoRows { - return nil - } - if err != nil { - return err - } - - // Delete goose's data. - _, err = tx.Exec("DROP TABLE goose_db_version CASCADE") - if err != nil { - return err - } - - // Mark the '00002_initial_schema' as done. - _, err = tx.Exec("INSERT INTO schema_migrations (version) VALUES (2)") - - return err - }, - Down: migrate.Queries([]string{}), - }) -} diff --git a/database/pgsql/migrations/00001_initial_schema.go b/database/pgsql/migrations/00001_initial_schema.go new file mode 100644 index 00000000..14fff7d4 --- /dev/null +++ b/database/pgsql/migrations/00001_initial_schema.go @@ -0,0 +1,192 @@ +// Copyright 2016 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package migrations + +import "github.com/remind101/migrate" + +func init() { + RegisterMigration(migrate.Migration{ + ID: 1, + Up: migrate.Queries([]string{ + // namespaces + `CREATE TABLE IF NOT EXISTS namespace ( + id SERIAL PRIMARY KEY, + name TEXT NULL, + version_format TEXT, + UNIQUE (name, version_format));`, + `CREATE INDEX ON namespace(name);`, + + // features + `CREATE TABLE IF NOT EXISTS feature ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + version TEXT NOT NULL, + version_format TEXT NOT NULL, + UNIQUE (name, version, version_format));`, + `CREATE INDEX ON feature(name);`, + + `CREATE TABLE IF NOT EXISTS namespaced_feature ( + id SERIAL PRIMARY KEY, + namespace_id INT REFERENCES namespace, + feature_id INT REFERENCES feature, + UNIQUE (namespace_id, feature_id));`, + + // layers + `CREATE TABLE IF NOT EXISTS layer( + id SERIAL PRIMARY KEY, + hash TEXT NOT NULL UNIQUE);`, + + `CREATE TABLE IF NOT EXISTS layer_feature ( + id SERIAL PRIMARY KEY, + layer_id INT REFERENCES layer ON DELETE CASCADE, + feature_id INT REFERENCES feature ON DELETE CASCADE, + UNIQUE (layer_id, feature_id));`, + `CREATE INDEX ON layer_feature(layer_id);`, + + `CREATE TABLE IF NOT EXISTS layer_lister ( + id SERIAL PRIMARY KEY, + layer_id INT REFERENCES layer ON DELETE CASCADE, + lister TEXT NOT NULL, + UNIQUE (layer_id, lister));`, + `CREATE INDEX ON layer_lister(layer_id);`, + + `CREATE TABLE IF NOT EXISTS layer_detector ( + id SERIAL PRIMARY KEY, + layer_id INT REFERENCES layer ON DELETE CASCADE, + detector TEXT, + UNIQUE (layer_id, detector));`, + `CREATE INDEX ON layer_detector(layer_id);`, + + `CREATE TABLE IF NOT EXISTS layer_namespace ( + id SERIAL PRIMARY KEY, + layer_id INT REFERENCES layer ON DELETE CASCADE, + namespace_id INT REFERENCES namespace ON DELETE CASCADE, + UNIQUE (layer_id, namespace_id));`, + `CREATE INDEX ON layer_namespace(layer_id);`, + + // ancestry + `CREATE TABLE IF NOT EXISTS ancestry ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL UNIQUE);`, + + `CREATE TABLE IF NOT EXISTS ancestry_layer ( + id SERIAL PRIMARY KEY, + ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, + ancestry_index INT NOT NULL, + layer_id INT REFERENCES layer ON DELETE RESTRICT, + UNIQUE (ancestry_id, ancestry_index));`, + `CREATE INDEX ON ancestry_layer(ancestry_id);`, + + `CREATE TABLE IF NOT EXISTS ancestry_feature ( + id SERIAL PRIMARY KEY, + ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, + namespaced_feature_id INT REFERENCES namespaced_feature ON DELETE CASCADE, + UNIQUE (ancestry_id, namespaced_feature_id));`, + + `CREATE TABLE IF NOT EXISTS ancestry_lister ( + id SERIAL PRIMARY KEY, + ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, + lister TEXT, + UNIQUE (ancestry_id, lister));`, + `CREATE INDEX ON ancestry_lister(ancestry_id);`, + + `CREATE TABLE IF NOT EXISTS ancestry_detector ( + id SERIAL PRIMARY KEY, + ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, + detector TEXT, + UNIQUE (ancestry_id, detector));`, + `CREATE INDEX ON ancestry_detector(ancestry_id);`, + + `CREATE TYPE severity AS ENUM ('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1');`, + + // vulnerability + `CREATE TABLE IF NOT EXISTS vulnerability ( + id SERIAL PRIMARY KEY, + namespace_id INT NOT NULL REFERENCES Namespace, + name TEXT NOT NULL, + description TEXT NULL, + link TEXT NULL, + severity severity NOT NULL, + metadata TEXT NULL, + created_at TIMESTAMP WITH TIME ZONE, + deleted_at TIMESTAMP WITH TIME ZONE NULL);`, + `CREATE INDEX ON vulnerability(namespace_id, name);`, + `CREATE INDEX ON vulnerability(namespace_id);`, + + `CREATE TABLE IF NOT EXISTS vulnerability_affected_feature ( + id SERIAL PRIMARY KEY, + vulnerability_id INT NOT NULL REFERENCES vulnerability ON DELETE CASCADE, + feature_name TEXT NOT NULL, + affected_version TEXT, + fixedin TEXT);`, + `CREATE INDEX ON vulnerability_affected_feature(vulnerability_id, feature_name);`, + + `CREATE TABLE IF NOT EXISTS vulnerability_affected_namespaced_feature( + id SERIAL PRIMARY KEY, + vulnerability_id INT NOT NULL REFERENCES vulnerability ON DELETE CASCADE, + namespaced_feature_id INT NOT NULL REFERENCES namespaced_feature ON DELETE CASCADE, + added_by INT NOT NULL REFERENCES vulnerability_affected_feature ON DELETE CASCADE, + UNIQUE (vulnerability_id, namespaced_feature_id));`, + `CREATE INDEX ON vulnerability_affected_namespaced_feature(namespaced_feature_id);`, + + `CREATE TABLE IF NOT EXISTS KeyValue ( + id SERIAL PRIMARY KEY, + key TEXT NOT NULL UNIQUE, + value TEXT);`, + + `CREATE TABLE IF NOT EXISTS Lock ( + id SERIAL PRIMARY KEY, + name VARCHAR(64) NOT NULL UNIQUE, + owner VARCHAR(64) NOT NULL, + until TIMESTAMP WITH TIME ZONE);`, + `CREATE INDEX ON Lock (owner);`, + + // Notification + `CREATE TABLE IF NOT EXISTS Vulnerability_Notification ( + id SERIAL PRIMARY KEY, + name VARCHAR(64) NOT NULL UNIQUE, + created_at TIMESTAMP WITH TIME ZONE, + notified_at TIMESTAMP WITH TIME ZONE NULL, + deleted_at TIMESTAMP WITH TIME ZONE NULL, + old_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE, + new_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE);`, + `CREATE INDEX ON Vulnerability_Notification (notified_at);`, + }), + Down: migrate.Queries([]string{ + `DROP TABLE IF EXISTS + ancestry, + ancestry_layer, + ancestry_feature, + ancestry_detector, + ancestry_lister, + feature, + namespaced_feature, + keyvalue, + layer, + layer_detector, + layer_feature, + layer_lister, + layer_namespace, + lock, + namespace, + vulnerability, + vulnerability_affected_feature, + vulnerability_affected_namespaced_feature, + vulnerability_notification + CASCADE;`, + `DROP TYPE IF EXISTS severity;`, + }), + }) +} diff --git a/database/pgsql/migrations/00002_initial_schema.go b/database/pgsql/migrations/00002_initial_schema.go deleted file mode 100644 index f7cc17e6..00000000 --- a/database/pgsql/migrations/00002_initial_schema.go +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - // This migration creates the initial Clair's schema. - RegisterMigration(migrate.Migration{ - ID: 2, - Up: migrate.Queries([]string{ - `CREATE TABLE IF NOT EXISTS Namespace ( - id SERIAL PRIMARY KEY, - name VARCHAR(128) NULL);`, - - `CREATE TABLE IF NOT EXISTS Layer ( - id SERIAL PRIMARY KEY, - name VARCHAR(128) NOT NULL UNIQUE, - engineversion SMALLINT NOT NULL, - parent_id INT NULL REFERENCES Layer ON DELETE CASCADE, - namespace_id INT NULL REFERENCES Namespace, - created_at TIMESTAMP WITH TIME ZONE);`, - `CREATE INDEX ON Layer (parent_id);`, - `CREATE INDEX ON Layer (namespace_id);`, - - `CREATE TABLE IF NOT EXISTS Feature ( - id SERIAL PRIMARY KEY, - namespace_id INT NOT NULL REFERENCES Namespace, - name VARCHAR(128) NOT NULL, - UNIQUE (namespace_id, name));`, - - `CREATE TABLE IF NOT EXISTS FeatureVersion ( - id SERIAL PRIMARY KEY, - feature_id INT NOT NULL REFERENCES Feature, - version VARCHAR(128) NOT NULL);`, - `CREATE INDEX ON FeatureVersion (feature_id);`, - - `CREATE TYPE modification AS ENUM ('add', 'del');`, - `CREATE TABLE IF NOT EXISTS Layer_diff_FeatureVersion ( - id SERIAL PRIMARY KEY, - layer_id INT NOT NULL REFERENCES Layer ON DELETE CASCADE, - featureversion_id INT NOT NULL REFERENCES FeatureVersion, - modification modification NOT NULL, - UNIQUE (layer_id, featureversion_id));`, - `CREATE INDEX ON Layer_diff_FeatureVersion (layer_id);`, - `CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id);`, - `CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id, layer_id);`, - - `CREATE TYPE severity AS ENUM ('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1');`, - `CREATE TABLE IF NOT EXISTS Vulnerability ( - id SERIAL PRIMARY KEY, - namespace_id INT NOT NULL REFERENCES Namespace, - name VARCHAR(128) NOT NULL, - description TEXT NULL, - link VARCHAR(128) NULL, - severity severity NOT NULL, - metadata TEXT NULL, - created_at TIMESTAMP WITH TIME ZONE, - deleted_at TIMESTAMP WITH TIME ZONE NULL);`, - - `CREATE TABLE IF NOT EXISTS Vulnerability_FixedIn_Feature ( - id SERIAL PRIMARY KEY, - vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE, - feature_id INT NOT NULL REFERENCES Feature, - version VARCHAR(128) NOT NULL, - UNIQUE (vulnerability_id, feature_id));`, - `CREATE INDEX ON Vulnerability_FixedIn_Feature (feature_id, vulnerability_id);`, - - `CREATE TABLE IF NOT EXISTS Vulnerability_Affects_FeatureVersion ( - id SERIAL PRIMARY KEY, - vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE, - featureversion_id INT NOT NULL REFERENCES FeatureVersion, - fixedin_id INT NOT NULL REFERENCES Vulnerability_FixedIn_Feature ON DELETE CASCADE, - UNIQUE (vulnerability_id, featureversion_id));`, - `CREATE INDEX ON Vulnerability_Affects_FeatureVersion (fixedin_id);`, - `CREATE INDEX ON Vulnerability_Affects_FeatureVersion (featureversion_id, vulnerability_id);`, - - `CREATE TABLE IF NOT EXISTS KeyValue ( - id SERIAL PRIMARY KEY, - key VARCHAR(128) NOT NULL UNIQUE, - value TEXT);`, - - `CREATE TABLE IF NOT EXISTS Lock ( - id SERIAL PRIMARY KEY, - name VARCHAR(64) NOT NULL UNIQUE, - owner VARCHAR(64) NOT NULL, - until TIMESTAMP WITH TIME ZONE);`, - `CREATE INDEX ON Lock (owner);`, - - `CREATE TABLE IF NOT EXISTS Vulnerability_Notification ( - id SERIAL PRIMARY KEY, - name VARCHAR(64) NOT NULL UNIQUE, - created_at TIMESTAMP WITH TIME ZONE, - notified_at TIMESTAMP WITH TIME ZONE NULL, - deleted_at TIMESTAMP WITH TIME ZONE NULL, - old_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE, - new_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE);`, - `CREATE INDEX ON Vulnerability_Notification (notified_at);`, - }), - Down: migrate.Queries([]string{ - `DROP TABLE IF EXISTS - Namespace, - Layer, - Feature, - FeatureVersion, - Layer_diff_FeatureVersion, - Vulnerability, - Vulnerability_FixedIn_Feature, - Vulnerability_Affects_FeatureVersion, - Vulnerability_Notification, - KeyValue, - Lock - CASCADE;`, - }), - }) -} diff --git a/database/pgsql/migrations/00003_add_indexes.go b/database/pgsql/migrations/00003_add_indexes.go deleted file mode 100644 index 78ccaba2..00000000 --- a/database/pgsql/migrations/00003_add_indexes.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 3, - Up: migrate.Queries([]string{ - `CREATE UNIQUE INDEX namespace_name_key ON Namespace (name);`, - `CREATE INDEX vulnerability_name_idx ON Vulnerability (name);`, - `CREATE INDEX vulnerability_namespace_id_name_idx ON Vulnerability (namespace_id, name);`, - `CREATE UNIQUE INDEX featureversion_feature_id_version_key ON FeatureVersion (feature_id, version);`, - }), - Down: migrate.Queries([]string{ - `DROP INDEX namespace_name_key;`, - `DROP INDEX vulnerability_name_idx;`, - `DROP INDEX vulnerability_namespace_id_name_idx;`, - `DROP INDEX featureversion_feature_id_version_key;`, - }), - }) -} diff --git a/database/pgsql/migrations/00004_add_index_notification_deleted_at.go b/database/pgsql/migrations/00004_add_index_notification_deleted_at.go deleted file mode 100644 index 12f38ab2..00000000 --- a/database/pgsql/migrations/00004_add_index_notification_deleted_at.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 4, - Up: migrate.Queries([]string{ - `CREATE INDEX vulnerability_notification_deleted_at_idx ON Vulnerability_Notification (deleted_at);`, - }), - Down: migrate.Queries([]string{ - `DROP INDEX vulnerability_notification_deleted_at_idx;`, - }), - }) -} diff --git a/database/pgsql/migrations/00005_ldfv_index.go b/database/pgsql/migrations/00005_ldfv_index.go deleted file mode 100644 index ec8e7137..00000000 --- a/database/pgsql/migrations/00005_ldfv_index.go +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 5, - Up: migrate.Queries([]string{ - `CREATE INDEX layer_diff_featureversion_layer_id_modification_idx ON Layer_diff_FeatureVersion (layer_id, modification);`, - }), - Down: migrate.Queries([]string{ - `DROP INDEX layer_diff_featureversion_layer_id_modification_idx;`, - }), - }) -} diff --git a/database/pgsql/migrations/00006_add_version_format.go b/database/pgsql/migrations/00006_add_version_format.go deleted file mode 100644 index 3a08f6f0..00000000 --- a/database/pgsql/migrations/00006_add_version_format.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 6, - Up: migrate.Queries([]string{ - `ALTER TABLE Namespace ADD COLUMN version_format varchar(128);`, - `UPDATE Namespace SET version_format = 'rpm' WHERE name LIKE 'rhel%' OR name LIKE 'centos%' OR name LIKE 'fedora%' OR name LIKE 'amzn%' OR name LIKE 'scientific%' OR name LIKE 'ol%' OR name LIKE 'oracle%';`, - `UPDATE Namespace SET version_format = 'dpkg' WHERE version_format is NULL;`, - }), - Down: migrate.Queries([]string{ - `ALTER TABLE Namespace DROP COLUMN version_format;`, - }), - }) -} diff --git a/database/pgsql/migrations/00007_expand_column_width.go b/database/pgsql/migrations/00007_expand_column_width.go deleted file mode 100644 index 8bfdaaab..00000000 --- a/database/pgsql/migrations/00007_expand_column_width.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright 2017 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 7, - Up: migrate.Queries([]string{ - `ALTER TABLE Namespace ALTER COLUMN version_format SET DATA TYPE varchar(256);`, - `ALTER TABLE Layer ALTER COLUMN name SET DATA TYPE varchar(256);`, - }), - Down: migrate.Queries([]string{ - `ALTER TABLE Namespace ALTER COLUMN version_format SET DATA TYPE varchar(128);`, - `ALTER TABLE Layer ALTER COLUMN name SET DATA TYPE varchar(128);`, - }), - }) -} diff --git a/database/pgsql/migrations/00008_add_multiplens.go b/database/pgsql/migrations/00008_add_multiplens.go deleted file mode 100644 index ecfb4762..00000000 --- a/database/pgsql/migrations/00008_add_multiplens.go +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2016 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package migrations - -import "github.com/remind101/migrate" - -func init() { - RegisterMigration(migrate.Migration{ - ID: 8, - Up: migrate.Queries([]string{ - // set on deletion, remove the corresponding rows in database - `CREATE TABLE IF NOT EXISTS Layer_Namespace( - id SERIAL PRIMARY KEY, - layer_id INT REFERENCES Layer(id) ON DELETE CASCADE, - namespace_id INT REFERENCES Namespace(id) ON DELETE CASCADE, - unique(layer_id, namespace_id) - );`, - `CREATE INDEX ON Layer_Namespace (namespace_id);`, - `CREATE INDEX ON Layer_Namespace (layer_id);`, - // move the namespace_id to the table - `INSERT INTO Layer_Namespace (layer_id, namespace_id) SELECT id, namespace_id FROM Layer;`, - // alter the Layer table to remove the column - `ALTER TABLE IF EXISTS Layer DROP namespace_id;`, - }), - Down: migrate.Queries([]string{ - `ALTER TABLE IF EXISTS Layer ADD namespace_id INT NULL REFERENCES Namespace;`, - `CREATE INDEX ON Layer (namespace_id);`, - `UPDATE IF EXISTS Layer SET namespace_id = (SELECT lns.namespace_id FROM Layer_Namespace lns WHERE Layer.id = lns.layer_id LIMIT 1);`, - `DROP TABLE IF EXISTS Layer_Namespace;`, - }), - }) -} diff --git a/database/pgsql/namespace.go b/database/pgsql/namespace.go index 8d4b304b..1a78f837 100644 --- a/database/pgsql/namespace.go +++ b/database/pgsql/namespace.go @@ -15,61 +15,82 @@ package pgsql import ( - "time" + "database/sql" + "errors" + "sort" "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" ) -func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (int, error) { - if namespace.Name == "" { - return 0, commonerr.NewBadRequestError("could not find/insert invalid Namespace") +var ( + errNamespaceNotFound = errors.New("Requested Namespace is not in database") +) + +// PersistNamespaces soi namespaces into database. +func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error { + if len(namespaces) == 0 { + return nil } - if pgSQL.cache != nil { - promCacheQueriesTotal.WithLabelValues("namespace").Inc() - if id, found := pgSQL.cache.Get("namespace:" + namespace.Name); found { - promCacheHitsTotal.WithLabelValues("namespace").Inc() - return id.(int), nil + // Sorting is needed before inserting into database to prevent deadlock. + sort.Slice(namespaces, func(i, j int) bool { + return namespaces[i].Name < namespaces[j].Name && + namespaces[i].VersionFormat < namespaces[j].VersionFormat + }) + + keys := make([]interface{}, len(namespaces)*2) + for i, ns := range namespaces { + if ns.Name == "" || ns.VersionFormat == "" { + return commonerr.NewBadRequestError("Empty namespace name or version format is not allowed") } + keys[i*2] = ns.Name + keys[i*2+1] = ns.VersionFormat } - // We do `defer observeQueryTime` here because we don't want to observe cached namespaces. - defer observeQueryTime("insertNamespace", "all", time.Now()) - - var id int - err := pgSQL.QueryRow(soiNamespace, namespace.Name, namespace.VersionFormat).Scan(&id) + _, err := tx.Exec(queryPersistNamespace(len(namespaces)), keys...) if err != nil { - return 0, handleError("soiNamespace", err) + return handleError("queryPersistNamespace", err) } - - if pgSQL.cache != nil { - pgSQL.cache.Add("namespace:"+namespace.Name, id) - } - - return id, nil + return nil } -func (pgSQL *pgSQL) ListNamespaces() (namespaces []database.Namespace, err error) { - rows, err := pgSQL.Query(listNamespace) - if err != nil { - return namespaces, handleError("listNamespace", err) +func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.NullInt64, error) { + if len(namespaces) == 0 { + return nil, nil } + + keys := make([]interface{}, len(namespaces)*2) + nsMap := map[database.Namespace]sql.NullInt64{} + for i, n := range namespaces { + keys[i*2] = n.Name + keys[i*2+1] = n.VersionFormat + nsMap[n] = sql.NullInt64{} + } + + rows, err := tx.Query(querySearchNamespace(len(namespaces)), keys...) + if err != nil { + return nil, handleError("searchNamespace", err) + } + defer rows.Close() + var ( + id sql.NullInt64 + ns database.Namespace + ) for rows.Next() { - var ns database.Namespace - - err = rows.Scan(&ns.ID, &ns.Name, &ns.VersionFormat) + err := rows.Scan(&id, &ns.Name, &ns.VersionFormat) if err != nil { - return namespaces, handleError("listNamespace.Scan()", err) + return nil, handleError("searchNamespace", err) } - - namespaces = append(namespaces, ns) - } - if err = rows.Err(); err != nil { - return namespaces, handleError("listNamespace.Rows()", err) + nsMap[ns] = id } - return namespaces, err + ids := make([]sql.NullInt64, len(namespaces)) + for i, ns := range namespaces { + ids[i] = nsMap[ns] + } + + return ids, nil } diff --git a/database/pgsql/namespace_test.go b/database/pgsql/namespace_test.go index 0990b6f4..27ceefef 100644 --- a/database/pgsql/namespace_test.go +++ b/database/pgsql/namespace_test.go @@ -15,60 +15,69 @@ package pgsql import ( - "fmt" "testing" "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt/dpkg" ) -func TestInsertNamespace(t *testing.T) { - datastore, err := openDatabaseForTest("InsertNamespace", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() +func TestPersistNamespaces(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistNamespaces", false) + defer closeTest(t, datastore, tx) - // Invalid Namespace. - id0, err := datastore.insertNamespace(database.Namespace{}) - assert.NotNil(t, err) - assert.Zero(t, id0) + ns1 := database.Namespace{} + ns2 := database.Namespace{Name: "t", VersionFormat: "b"} - // Insert Namespace and ensure we can find it. - id1, err := datastore.insertNamespace(database.Namespace{ - Name: "TestInsertNamespace1", - VersionFormat: dpkg.ParserName, - }) - assert.Nil(t, err) - id2, err := datastore.insertNamespace(database.Namespace{ - Name: "TestInsertNamespace1", - VersionFormat: dpkg.ParserName, - }) - assert.Nil(t, err) - assert.Equal(t, id1, id2) + // Empty Case + assert.Nil(t, tx.PersistNamespaces([]database.Namespace{})) + // Invalid Case + assert.NotNil(t, tx.PersistNamespaces([]database.Namespace{ns1})) + // Duplicated Case + assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2, ns2})) + // Existing Case + assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2})) + + nsList := listNamespaces(t, tx) + assert.Len(t, nsList, 1) + assert.Equal(t, ns2, nsList[0]) } -func TestListNamespace(t *testing.T) { - datastore, err := openDatabaseForTest("ListNamespaces", true) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() - - namespaces, err := datastore.ListNamespaces() - assert.Nil(t, err) - if assert.Len(t, namespaces, 2) { - for _, namespace := range namespaces { - switch namespace.Name { - case "debian:7", "debian:8": - continue - default: - assert.Error(t, fmt.Errorf("ListNamespaces should not have returned '%s'", namespace.Name)) +func assertNamespacesEqual(t *testing.T, expected []database.Namespace, actual []database.Namespace) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.Namespace]bool{} + for _, i := range expected { + has[i] = false + } + for _, i := range actual { + has[i] = true + } + for key, v := range has { + if !assert.True(t, v, key.Name+"is expected") { + return false } } + return true } + return false +} + +func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace { + rows, err := tx.Query("SELECT name, version_format FROM namespace") + if err != nil { + t.FailNow() + } + defer rows.Close() + + namespaces := []database.Namespace{} + for rows.Next() { + var ns database.Namespace + err := rows.Scan(&ns.Name, &ns.VersionFormat) + if err != nil { + t.FailNow() + } + namespaces = append(namespaces, ns) + } + + return namespaces } diff --git a/database/pgsql/notification.go b/database/pgsql/notification.go index f8c6960d..ebc346d3 100644 --- a/database/pgsql/notification.go +++ b/database/pgsql/notification.go @@ -16,235 +16,320 @@ package pgsql import ( "database/sql" + "errors" "time" "github.com/guregu/null/zero" - "github.com/pborman/uuid" - log "github.com/sirupsen/logrus" "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" ) -// do it in tx so we won't insert/update a vuln without notification and vice-versa. -// name and created doesn't matter. -func createNotification(tx *sql.Tx, oldVulnerabilityID, newVulnerabilityID int) error { - defer observeQueryTime("createNotification", "all", time.Now()) +var ( + errNotificationNotFound = errors.New("requested notification is not found") +) - // Insert Notification. - oldVulnerabilityNullableID := sql.NullInt64{Int64: int64(oldVulnerabilityID), Valid: oldVulnerabilityID != 0} - newVulnerabilityNullableID := sql.NullInt64{Int64: int64(newVulnerabilityID), Valid: newVulnerabilityID != 0} - _, err := tx.Exec(insertNotification, uuid.New(), oldVulnerabilityNullableID, newVulnerabilityNullableID) +func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.VulnerabilityNotification) error { + if len(notifications) == 0 { + return nil + } + + var ( + newVulnIDMap = make(map[database.VulnerabilityID]sql.NullInt64) + oldVulnIDMap = make(map[database.VulnerabilityID]sql.NullInt64) + ) + + invalidCreationTime := time.Time{} + for _, noti := range notifications { + if noti.Name == "" { + return commonerr.NewBadRequestError("notification should not have empty name") + } + if noti.Created == invalidCreationTime { + return commonerr.NewBadRequestError("notification should not have empty created time") + } + + if noti.New != nil { + key := database.VulnerabilityID{ + Name: noti.New.Name, + Namespace: noti.New.Namespace.Name, + } + newVulnIDMap[key] = sql.NullInt64{} + } + + if noti.Old != nil { + key := database.VulnerabilityID{ + Name: noti.Old.Name, + Namespace: noti.Old.Namespace.Name, + } + oldVulnIDMap[key] = sql.NullInt64{} + } + } + + var ( + newVulnIDs = make([]database.VulnerabilityID, 0, len(newVulnIDMap)) + oldVulnIDs = make([]database.VulnerabilityID, 0, len(oldVulnIDMap)) + ) + + for vulnID := range newVulnIDMap { + newVulnIDs = append(newVulnIDs, vulnID) + } + + for vulnID := range oldVulnIDMap { + oldVulnIDs = append(oldVulnIDs, vulnID) + } + + ids, err := tx.findNotDeletedVulnerabilityIDs(newVulnIDs) if err != nil { - tx.Rollback() - return handleError("insertNotification", err) + return err + } + + for i, id := range ids { + if !id.Valid { + return handleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound) + } + newVulnIDMap[newVulnIDs[i]] = id + } + + ids, err = tx.findLatestDeletedVulnerabilityIDs(oldVulnIDs) + if err != nil { + return err + } + + for i, id := range ids { + if !id.Valid { + return handleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound) + } + oldVulnIDMap[oldVulnIDs[i]] = id + } + + var ( + newVulnID sql.NullInt64 + oldVulnID sql.NullInt64 + ) + + keys := make([]interface{}, len(notifications)*4) + for i, noti := range notifications { + if noti.New != nil { + newVulnID = newVulnIDMap[database.VulnerabilityID{ + Name: noti.New.Name, + Namespace: noti.New.Namespace.Name, + }] + } + + if noti.Old != nil { + oldVulnID = oldVulnIDMap[database.VulnerabilityID{ + Name: noti.Old.Name, + Namespace: noti.Old.Namespace.Name, + }] + } + + keys[4*i] = noti.Name + keys[4*i+1] = noti.Created + keys[4*i+2] = oldVulnID + keys[4*i+3] = newVulnID + } + + // NOTE(Sida): The data is not sorted before inserting into database under + // the fact that there's only one updater running at a time. If there are + // multiple updaters, deadlock may happen. + _, err = tx.Exec(queryInsertNotifications(len(notifications)), keys...) + if err != nil { + return handleError("queryInsertNotifications", err) } return nil } -// Get one available notification name (!locked && !deleted && (!notified || notified_but_timed-out)). -// Does not fill new/old vuln. -func (pgSQL *pgSQL) GetAvailableNotification(renotifyInterval time.Duration) (database.VulnerabilityNotification, error) { - defer observeQueryTime("GetAvailableNotification", "all", time.Now()) - - before := time.Now().Add(-renotifyInterval) - row := pgSQL.QueryRow(searchNotificationAvailable, before) - notification, err := pgSQL.scanNotification(row, false) - - return notification, handleError("searchNotificationAvailable", err) -} - -func (pgSQL *pgSQL) GetNotification(name string, limit int, page database.VulnerabilityNotificationPageNumber) (database.VulnerabilityNotification, database.VulnerabilityNotificationPageNumber, error) { - defer observeQueryTime("GetNotification", "all", time.Now()) - - // Get Notification. - notification, err := pgSQL.scanNotification(pgSQL.QueryRow(searchNotification, name), true) - if err != nil { - return notification, page, handleError("searchNotification", err) - } - - // Load vulnerabilities' LayersIntroducingVulnerability. - page.OldVulnerability, err = pgSQL.loadLayerIntroducingVulnerability( - notification.OldVulnerability, - limit, - page.OldVulnerability, +func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.NotificationHook, bool, error) { + var ( + notification database.NotificationHook + created zero.Time + notified zero.Time + deleted zero.Time ) + err := tx.QueryRow(searchNotificationAvailable, notifiedBefore).Scan(¬ification.Name, &created, ¬ified, &deleted) if err != nil { - return notification, page, err - } - - page.NewVulnerability, err = pgSQL.loadLayerIntroducingVulnerability( - notification.NewVulnerability, - limit, - page.NewVulnerability, - ) - - if err != nil { - return notification, page, err - } - - return notification, page, nil -} - -func (pgSQL *pgSQL) scanNotification(row *sql.Row, hasVulns bool) (database.VulnerabilityNotification, error) { - var notification database.VulnerabilityNotification - var created zero.Time - var notified zero.Time - var deleted zero.Time - var oldVulnerabilityNullableID sql.NullInt64 - var newVulnerabilityNullableID sql.NullInt64 - - // Scan notification. - if hasVulns { - err := row.Scan( - ¬ification.ID, - ¬ification.Name, - &created, - ¬ified, - &deleted, - &oldVulnerabilityNullableID, - &newVulnerabilityNullableID, - ) - - if err != nil { - return notification, err - } - } else { - err := row.Scan(¬ification.ID, ¬ification.Name, &created, ¬ified, &deleted) - - if err != nil { - return notification, err + if err == sql.ErrNoRows { + return notification, false, nil } + return notification, false, handleError("searchNotificationAvailable", err) } notification.Created = created.Time notification.Notified = notified.Time notification.Deleted = deleted.Time - if hasVulns { - if oldVulnerabilityNullableID.Valid { - vulnerability, err := pgSQL.findVulnerabilityByIDWithDeleted(int(oldVulnerabilityNullableID.Int64)) - if err != nil { - return notification, err - } - - notification.OldVulnerability = &vulnerability - } - - if newVulnerabilityNullableID.Valid { - vulnerability, err := pgSQL.findVulnerabilityByIDWithDeleted(int(newVulnerabilityNullableID.Int64)) - if err != nil { - return notification, err - } - - notification.NewVulnerability = &vulnerability - } - } - - return notification, nil + return notification, true, nil } -// Fills Vulnerability.LayersIntroducingVulnerability. -// limit -1: won't do anything -// limit 0: will just get the startID of the second page -func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vulnerability, limit, startID int) (int, error) { - tf := time.Now() - - if vulnerability == nil { - return -1, nil +func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentPage database.PageNumber) (database.PagedVulnerableAncestries, error) { + vulnPage := database.PagedVulnerableAncestries{Limit: limit} + current := idPageNumber{0} + if currentPage != "" { + var err error + current, err = decryptPage(currentPage, tx.paginationKey) + if err != nil { + return vulnPage, err + } } - // A startID equals to -1 means that we reached the end already. - if startID == -1 || limit == -1 { - return -1, nil - } - - // Create a transaction to disable hash joins as our experience shows that - // PostgreSQL plans in certain cases a sequential scan and a hash on - // Layer_diff_FeatureVersion for the condition `ldfv.layer_id >= $2 AND - // ldfv.modification = 'add'` before realizing a hash inner join with - // Vulnerability_Affects_FeatureVersion. By disabling explictly hash joins, - // we force PostgreSQL to perform a bitmap index scan with - // `ldfv.featureversion_id = fv.id` on Layer_diff_FeatureVersion, followed by - // a bitmap heap scan on `ldfv.layer_id >= $2 AND ldfv.modification = 'add'`, - // thus avoiding a sequential scan on the biggest database table and - // allowing a small nested loop join instead. - tx, err := pgSQL.Begin() + err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan( + &vulnPage.Name, + &vulnPage.Description, + &vulnPage.Link, + &vulnPage.Severity, + &vulnPage.Metadata, + &vulnPage.Namespace.Name, + &vulnPage.Namespace.VersionFormat, + ) if err != nil { - return -1, handleError("searchNotificationLayerIntroducingVulnerability.Begin()", err) - } - defer tx.Commit() - - _, err = tx.Exec(disableHashJoin) - if err != nil { - log.WithError(err).Warning("searchNotificationLayerIntroducingVulnerability: could not disable hash join") + return vulnPage, handleError("searchVulnerabilityByID", err) } - // We do `defer observeQueryTime` here because we don't want to observe invalid calls. - defer observeQueryTime("loadLayerIntroducingVulnerability", "all", tf) - - // Query with limit + 1, the last item will be used to know the next starting ID. - rows, err := tx.Query(searchNotificationLayerIntroducingVulnerability, - vulnerability.ID, startID, limit+1) + // the last result is used for the next page's startID + rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, current.StartID, limit+1) if err != nil { - return 0, handleError("searchNotificationLayerIntroducingVulnerability", err) + return vulnPage, handleError("searchNotificationVulnerableAncestry", err) } defer rows.Close() - var layers []database.Layer + ancestries := []affectedAncestry{} for rows.Next() { - var layer database.Layer - - if err := rows.Scan(&layer.ID, &layer.Name); err != nil { - return -1, handleError("searchNotificationLayerIntroducingVulnerability.Scan()", err) + var ancestry affectedAncestry + err := rows.Scan(&ancestry.id, &ancestry.name) + if err != nil { + return vulnPage, handleError("searchNotificationVulnerableAncestry", err) } - - layers = append(layers, layer) - } - if err = rows.Err(); err != nil { - return -1, handleError("searchNotificationLayerIntroducingVulnerability.Rows()", err) + ancestries = append(ancestries, ancestry) } - size := limit - if len(layers) < limit { - size = len(layers) - } - vulnerability.LayersIntroducingVulnerability = layers[:size] + lastIndex := 0 + if len(ancestries)-1 < limit { + lastIndex = len(ancestries) + vulnPage.End = true + } else { + // Use the last ancestry's ID as the next PageNumber. + lastIndex = len(ancestries) - 1 + vulnPage.Next, err = encryptPage( + idPageNumber{ + ancestries[len(ancestries)-1].id, + }, tx.paginationKey) - nextID := -1 - if len(layers) > limit { - nextID = layers[limit].ID + if err != nil { + return vulnPage, err + } } - return nextID, nil + vulnPage.Affected = map[int]string{} + for _, ancestry := range ancestries[0:lastIndex] { + vulnPage.Affected[int(ancestry.id)] = ancestry.name + } + + vulnPage.Current, err = encryptPage(current, tx.paginationKey) + if err != nil { + return vulnPage, err + } + + return vulnPage, nil } -func (pgSQL *pgSQL) SetNotificationNotified(name string) error { - defer observeQueryTime("SetNotificationNotified", "all", time.Now()) +func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPage database.PageNumber, newPage database.PageNumber) ( + database.VulnerabilityNotificationWithVulnerable, bool, error) { + var ( + noti database.VulnerabilityNotificationWithVulnerable + oldVulnID sql.NullInt64 + newVulnID sql.NullInt64 + created zero.Time + notified zero.Time + deleted zero.Time + ) - if _, err := pgSQL.Exec(updatedNotificationNotified, name); err != nil { + if name == "" { + return noti, false, commonerr.NewBadRequestError("Empty notification name is not allowed") + } + + noti.Name = name + + err := tx.QueryRow(searchNotification, name).Scan(&created, ¬ified, + &deleted, &oldVulnID, &newVulnID) + + if err != nil { + if err == sql.ErrNoRows { + return noti, false, nil + } + return noti, false, handleError("searchNotification", err) + } + + if created.Valid { + noti.Created = created.Time + } + + if notified.Valid { + noti.Notified = notified.Time + } + + if deleted.Valid { + noti.Deleted = deleted.Time + } + + if oldVulnID.Valid { + page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPage) + if err != nil { + return noti, false, err + } + noti.Old = &page + } + + if newVulnID.Valid { + page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPage) + if err != nil { + return noti, false, err + } + noti.New = &page + } + + return noti, true, nil +} + +func (tx *pgSession) MarkNotificationNotified(name string) error { + if name == "" { + return commonerr.NewBadRequestError("Empty notification name is not allowed") + } + + r, err := tx.Exec(updatedNotificationNotified, name) + if err != nil { return handleError("updatedNotificationNotified", err) } + + affected, err := r.RowsAffected() + if err != nil { + return handleError("updatedNotificationNotified", err) + } + + if affected <= 0 { + return handleError("updatedNotificationNotified", errNotificationNotFound) + } return nil } -func (pgSQL *pgSQL) DeleteNotification(name string) error { - defer observeQueryTime("DeleteNotification", "all", time.Now()) +func (tx *pgSession) DeleteNotification(name string) error { + if name == "" { + return commonerr.NewBadRequestError("Empty notification name is not allowed") + } - result, err := pgSQL.Exec(removeNotification, name) + result, err := tx.Exec(removeNotification, name) if err != nil { return handleError("removeNotification", err) } affected, err := result.RowsAffected() if err != nil { - return handleError("removeNotification.RowsAffected()", err) + return handleError("removeNotification", err) } if affected <= 0 { - return commonerr.ErrNotFound + return handleError("removeNotification", commonerr.ErrNotFound) } return nil diff --git a/database/pgsql/notification_test.go b/database/pgsql/notification_test.go index 24e79246..0d930d08 100644 --- a/database/pgsql/notification_test.go +++ b/database/pgsql/notification_test.go @@ -21,211 +21,225 @@ import ( "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt" - "github.com/coreos/clair/ext/versionfmt/dpkg" - "github.com/coreos/clair/pkg/commonerr" ) -func TestNotification(t *testing.T) { - datastore, err := openDatabaseForTest("Notification", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() +func TestPagination(t *testing.T) { + datastore, tx := openSessionForTest(t, "Pagination", true) + defer closeTest(t, datastore, tx) - // Try to get a notification when there is none. - _, err = datastore.GetAvailableNotification(time.Second) - assert.Equal(t, commonerr.ErrNotFound, err) - - // Create some data. - f1 := database.Feature{ - Name: "TestNotificationFeature1", - Namespace: database.Namespace{ - Name: "TestNotificationNamespace1", - VersionFormat: dpkg.ParserName, - }, + ns := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", } - f2 := database.Feature{ - Name: "TestNotificationFeature2", - Namespace: database.Namespace{ - Name: "TestNotificationNamespace1", - VersionFormat: dpkg.ParserName, - }, + vNew := database.Vulnerability{ + Namespace: ns, + Name: "CVE-OPENSSL-1-DEB7", + Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", + Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", + Severity: database.HighSeverity, } - l1 := database.Layer{ - Name: "TestNotificationLayer1", - Features: []database.FeatureVersion{ - { - Feature: f1, - Version: "0.1", - }, - }, + vOld := database.Vulnerability{ + Namespace: ns, + Name: "CVE-NOPE", + Description: "A vulnerability affecting nothing", + Severity: database.UnknownSeverity, } - l2 := database.Layer{ - Name: "TestNotificationLayer2", - Features: []database.FeatureVersion{ - { - Feature: f1, - Version: "0.2", - }, - }, + noti, ok, err := tx.FindVulnerabilityNotification("test", 1, "", "") + oldPage := database.PagedVulnerableAncestries{ + Vulnerability: vOld, + Limit: 1, + Affected: make(map[int]string), + End: true, } - l3 := database.Layer{ - Name: "TestNotificationLayer3", - Features: []database.FeatureVersion{ - { - Feature: f1, - Version: "0.3", - }, - }, + newPage1 := database.PagedVulnerableAncestries{ + Vulnerability: vNew, + Limit: 1, + Affected: map[int]string{3: "ancestry-3"}, + End: false, } - l4 := database.Layer{ - Name: "TestNotificationLayer4", - Features: []database.FeatureVersion{ - { - Feature: f2, - Version: "0.1", - }, - }, + newPage2 := database.PagedVulnerableAncestries{ + Vulnerability: vNew, + Limit: 1, + Affected: map[int]string{4: "ancestry-4"}, + Next: "", + End: true, } - if !assert.Nil(t, datastore.InsertLayer(l1)) || - !assert.Nil(t, datastore.InsertLayer(l2)) || - !assert.Nil(t, datastore.InsertLayer(l3)) || - !assert.Nil(t, datastore.InsertLayer(l4)) { - return - } - - // Insert a new vulnerability that is introduced by three layers. - v1 := database.Vulnerability{ - Name: "TestNotificationVulnerability1", - Namespace: f1.Namespace, - Description: "TestNotificationDescription1", - Link: "TestNotificationLink1", - Severity: "Unknown", - FixedIn: []database.FeatureVersion{ - { - Feature: f1, - Version: "1.0", - }, - }, - } - assert.Nil(t, datastore.insertVulnerability(v1, false, true)) - - // Get the notification associated to the previously inserted vulnerability. - notification, err := datastore.GetAvailableNotification(time.Second) - - if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) { - // Verify the renotify behaviour. - if assert.Nil(t, datastore.SetNotificationNotified(notification.Name)) { - _, err := datastore.GetAvailableNotification(time.Second) - assert.Equal(t, commonerr.ErrNotFound, err) - - time.Sleep(50 * time.Millisecond) - notificationB, err := datastore.GetAvailableNotification(20 * time.Millisecond) - assert.Nil(t, err) - assert.Equal(t, notification.Name, notificationB.Name) - - datastore.SetNotificationNotified(notification.Name) - } - - // Get notification. - filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage) - if assert.Nil(t, err) { - assert.NotEqual(t, database.NoVulnerabilityNotificationPage, nextPage) - assert.Nil(t, filledNotification.OldVulnerability) - - if assert.NotNil(t, filledNotification.NewVulnerability) { - assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name) - assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 2) - } - } - - // Get second page. - filledNotification, nextPage, err = datastore.GetNotification(notification.Name, 2, nextPage) - if assert.Nil(t, err) { - assert.Equal(t, database.NoVulnerabilityNotificationPage, nextPage) - assert.Nil(t, filledNotification.OldVulnerability) - - if assert.NotNil(t, filledNotification.NewVulnerability) { - assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name) - assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1) - } - } - - // Delete notification. - assert.Nil(t, datastore.DeleteNotification(notification.Name)) - - _, err = datastore.GetAvailableNotification(time.Millisecond) - assert.Equal(t, commonerr.ErrNotFound, err) - } - - // Update a vulnerability and ensure that the old/new vulnerabilities are correct. - v1b := v1 - v1b.Severity = database.HighSeverity - v1b.FixedIn = []database.FeatureVersion{ - { - Feature: f1, - Version: versionfmt.MinVersion, - }, - { - Feature: f2, - Version: versionfmt.MaxVersion, - }, - } - - if assert.Nil(t, datastore.insertVulnerability(v1b, false, true)) { - notification, err = datastore.GetAvailableNotification(time.Second) - assert.Nil(t, err) - assert.NotEmpty(t, notification.Name) - - if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) { - filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage) - if assert.Nil(t, err) { - if assert.NotNil(t, filledNotification.OldVulnerability) { - assert.Equal(t, v1.Name, filledNotification.OldVulnerability.Name) - assert.Equal(t, v1.Severity, filledNotification.OldVulnerability.Severity) - assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 2) - } - - if assert.NotNil(t, filledNotification.NewVulnerability) { - assert.Equal(t, v1b.Name, filledNotification.NewVulnerability.Name) - assert.Equal(t, v1b.Severity, filledNotification.NewVulnerability.Severity) - assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1) - } - - assert.Equal(t, -1, nextPage.NewVulnerability) + if assert.Nil(t, err) && assert.True(t, ok) { + assert.Equal(t, "test", noti.Name) + if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { + oldPageNum, err := decryptPage(noti.Old.Current, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") } - assert.Nil(t, datastore.DeleteNotification(notification.Name)) + assert.Equal(t, int64(0), oldPageNum.StartID) + newPageNum, err := decryptPage(noti.New.Current, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") + } + newPageNextNum, err := decryptPage(noti.New.Next, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") + } + assert.Equal(t, int64(0), newPageNum.StartID) + assert.Equal(t, int64(4), newPageNextNum.StartID) + + noti.Old.Current = "" + noti.New.Current = "" + noti.New.Next = "" + assert.Equal(t, oldPage, *noti.Old) + assert.Equal(t, newPage1, *noti.New) } } - // Delete a vulnerability and verify the notification. - if assert.Nil(t, datastore.DeleteVulnerability(v1b.Namespace.Name, v1b.Name)) { - notification, err = datastore.GetAvailableNotification(time.Second) - assert.Nil(t, err) - assert.NotEmpty(t, notification.Name) + page1, err := encryptPage(idPageNumber{0}, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") + } - if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) { - filledNotification, _, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage) - if assert.Nil(t, err) { - assert.Nil(t, filledNotification.NewVulnerability) + page2, err := encryptPage(idPageNumber{4}, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") + } - if assert.NotNil(t, filledNotification.OldVulnerability) { - assert.Equal(t, v1b.Name, filledNotification.OldVulnerability.Name) - assert.Equal(t, v1b.Severity, filledNotification.OldVulnerability.Severity) - assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 1) - } + noti, ok, err = tx.FindVulnerabilityNotification("test", 1, page1, page2) + if assert.Nil(t, err) && assert.True(t, ok) { + assert.Equal(t, "test", noti.Name) + if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { + oldCurrentPage, err := decryptPage(noti.Old.Current, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") } - assert.Nil(t, datastore.DeleteNotification(notification.Name)) + newCurrentPage, err := decryptPage(noti.New.Current, tx.paginationKey) + if !assert.Nil(t, err) { + assert.FailNow(t, "") + } + + assert.Equal(t, int64(0), oldCurrentPage.StartID) + assert.Equal(t, int64(4), newCurrentPage.StartID) + noti.Old.Current = "" + noti.New.Current = "" + assert.Equal(t, oldPage, *noti.Old) + assert.Equal(t, newPage2, *noti.New) } } } + +func TestInsertVulnerabilityNotifications(t *testing.T) { + datastore, tx := openSessionForTest(t, "InsertVulnerabilityNotifications", true) + + n1 := database.VulnerabilityNotification{} + n3 := database.VulnerabilityNotification{ + NotificationHook: database.NotificationHook{ + Name: "random name", + Created: time.Now(), + }, + Old: nil, + New: &database.Vulnerability{}, + } + n4 := database.VulnerabilityNotification{ + NotificationHook: database.NotificationHook{ + Name: "random name", + Created: time.Now(), + }, + Old: nil, + New: &database.Vulnerability{ + Name: "CVE-OPENSSL-1-DEB7", + Namespace: database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + }, + }, + } + + // invalid case + err := tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n1}) + assert.NotNil(t, err) + + // invalid case: unknown vulnerability + err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n3}) + assert.NotNil(t, err) + + // invalid case: duplicated input notification + err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4, n4}) + assert.NotNil(t, err) + tx = restartSession(t, datastore, tx, false) + + // valid case + err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4}) + assert.Nil(t, err) + // invalid case: notification is already in database + err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4}) + assert.NotNil(t, err) + + closeTest(t, datastore, tx) +} + +func TestFindNewNotification(t *testing.T) { + datastore, tx := openSessionForTest(t, "FindNewNotification", true) + defer closeTest(t, datastore, tx) + + noti, ok, err := tx.FindNewNotification(time.Now()) + if assert.Nil(t, err) && assert.True(t, ok) { + assert.Equal(t, "test", noti.Name) + assert.Equal(t, time.Time{}, noti.Notified) + assert.Equal(t, time.Time{}, noti.Created) + assert.Equal(t, time.Time{}, noti.Deleted) + } + + // can't find the notified + assert.Nil(t, tx.MarkNotificationNotified("test")) + // if the notified time is before + noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(10 * time.Second))) + assert.Nil(t, err) + assert.False(t, ok) + // can find the notified after a period of time + noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000))) + if assert.Nil(t, err) && assert.True(t, ok) { + assert.Equal(t, "test", noti.Name) + assert.NotEqual(t, time.Time{}, noti.Notified) + assert.Equal(t, time.Time{}, noti.Created) + assert.Equal(t, time.Time{}, noti.Deleted) + } + + assert.Nil(t, tx.DeleteNotification("test")) + // can't find in any time + noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(1000))) + assert.Nil(t, err) + assert.False(t, ok) + + noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000))) + assert.Nil(t, err) + assert.False(t, ok) +} + +func TestMarkNotificationNotified(t *testing.T) { + datastore, tx := openSessionForTest(t, "MarkNotificationNotified", true) + defer closeTest(t, datastore, tx) + + // invalid case: notification doesn't exist + assert.NotNil(t, tx.MarkNotificationNotified("non-existing")) + // valid case + assert.Nil(t, tx.MarkNotificationNotified("test")) + // valid case + assert.Nil(t, tx.MarkNotificationNotified("test")) +} + +func TestDeleteNotification(t *testing.T) { + datastore, tx := openSessionForTest(t, "DeleteNotification", true) + defer closeTest(t, datastore, tx) + + // invalid case: notification doesn't exist + assert.NotNil(t, tx.DeleteNotification("non-existing")) + // valid case + assert.Nil(t, tx.DeleteNotification("test")) + // invalid case: notification is already deleted + assert.NotNil(t, tx.DeleteNotification("test")) +} diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index 34504a9a..8815fabb 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -31,6 +31,7 @@ import ( "github.com/remind101/migrate" log "github.com/sirupsen/logrus" + "github.com/coreos/clair/api/token" "github.com/coreos/clair/database" "github.com/coreos/clair/database/pgsql/migrations" "github.com/coreos/clair/pkg/commonerr" @@ -59,7 +60,7 @@ var ( promConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{ Name: "clair_pgsql_concurrent_lock_vafv_total", - Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_FeatureVersion lock.", + Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_Feature lock.", }) ) @@ -73,17 +74,65 @@ func init() { database.Register("pgsql", openDatabase) } -type Queryer interface { - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row +// pgSessionCache is the session's cache, which holds the pgSQL's cache and the +// individual session's cache. Only when session.Commit is called, all the +// changes to pgSQL cache will be applied. +type pgSessionCache struct { + c *lru.ARCCache } type pgSQL struct { *sql.DB + cache *lru.ARCCache config Config } +type pgSession struct { + *sql.Tx + + paginationKey string +} + +type idPageNumber struct { + // StartID is an implementation detail for paginating by an ID required to + // be unique to every ancestry and always increasing. + // + // StartID is used to search for ancestry with ID >= StartID + StartID int64 +} + +func encryptPage(page idPageNumber, paginationKey string) (result database.PageNumber, err error) { + resultBytes, err := token.Marshal(page, paginationKey) + if err != nil { + return result, err + } + result = database.PageNumber(resultBytes) + return result, nil +} + +func decryptPage(page database.PageNumber, paginationKey string) (result idPageNumber, err error) { + err = token.Unmarshal(string(page), paginationKey, &result) + return +} + +// Begin initiates a transaction to database. The expected transaction isolation +// level in this implementation is "Read Committed". +func (pgSQL *pgSQL) Begin() (database.Session, error) { + tx, err := pgSQL.DB.Begin() + if err != nil { + return nil, err + } + return &pgSession{ + Tx: tx, + paginationKey: pgSQL.config.PaginationKey, + }, nil +} + +func (tx *pgSession) Commit() error { + return tx.Tx.Commit() +} + // Close closes the database and destroys if ManageDatabaseLifecycle has been specified in // the configuration. func (pgSQL *pgSQL) Close() { @@ -109,6 +158,7 @@ type Config struct { ManageDatabaseLifecycle bool FixturePath string + PaginationKey string } // openDatabase opens a PostgresSQL-backed Datastore using the given @@ -134,6 +184,10 @@ func openDatabase(registrableComponentConfig database.RegistrableComponentConfig return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) } + if pg.config.PaginationKey == "" { + panic("pagination key should be given") + } + dbName, pgSourceURL, err := parseConnectionString(pg.config.Source) if err != nil { return nil, err @@ -179,7 +233,7 @@ func openDatabase(registrableComponentConfig database.RegistrableComponentConfig _, err = pg.DB.Exec(string(d)) if err != nil { pg.Close() - return nil, fmt.Errorf("pgsql: an error occured while importing fixtures: %v", err) + return nil, fmt.Errorf("pgsql: an error occurred while importing fixtures: %v", err) } } @@ -217,7 +271,7 @@ func migrateDatabase(db *sql.DB) error { err := migrate.NewPostgresMigrator(db).Exec(migrate.Up, migrations.Migrations...) if err != nil { - return fmt.Errorf("pgsql: an error occured while running migrations: %v", err) + return fmt.Errorf("pgsql: an error occurred while running migrations: %v", err) } log.Info("database migration ran successfully") @@ -271,7 +325,8 @@ func dropDatabase(source, dbName string) error { } // handleError logs an error with an extra description and masks the error if it's an SQL one. -// This ensures we never return plain SQL errors and leak anything. +// The function ensures we never return plain SQL errors and leak anything. +// The function should be used for every database query error. func handleError(desc string, err error) error { if err == nil { return nil @@ -297,6 +352,11 @@ func isErrUniqueViolation(err error) bool { return ok && pqErr.Code == "23505" } +// observeQueryTime computes the time elapsed since `start` to represent the +// query time. +// 1. `query` is a pgSession function name. +// 2. `subquery` is a specific query or a batched query. +// 3. `start` is the time right before query is executed. func observeQueryTime(query, subquery string, start time.Time) { promQueryDurationMilliseconds. WithLabelValues(query, subquery). diff --git a/database/pgsql/pgsql_test.go b/database/pgsql/pgsql_test.go index 93f53144..96241666 100644 --- a/database/pgsql/pgsql_test.go +++ b/database/pgsql/pgsql_test.go @@ -15,27 +15,193 @@ package pgsql import ( + "database/sql" "fmt" + "io/ioutil" "os" "path/filepath" "runtime" "strings" + "testing" + fernet "github.com/fernet/fernet-go" "github.com/pborman/uuid" + log "github.com/sirupsen/logrus" + "github.com/stretchr/testify/assert" + yaml "gopkg.in/yaml.v2" "github.com/coreos/clair/database" ) -func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) { - ds, err := openDatabase(generateTestConfig(testName, loadFixture)) +var ( + withFixtureName, withoutFixtureName string +) + +func genTemplateDatabase(name string, loadFixture bool) (sourceURL string, dbName string) { + config := generateTestConfig(name, loadFixture, false) + source := config.Options["source"].(string) + name, url, err := parseConnectionString(source) + if err != nil { + panic(err) + } + + fixturePath := config.Options["fixturepath"].(string) + + if err := createDatabase(url, name); err != nil { + panic(err) + } + + // migration and fixture + db, err := sql.Open("postgres", source) + if err != nil { + panic(err) + } + + // Verify database state. + if err := db.Ping(); err != nil { + panic(err) + } + + // Run migrations. + if err := migrateDatabase(db); err != nil { + panic(err) + } + + if loadFixture { + log.Info("pgsql: loading fixtures") + + d, err := ioutil.ReadFile(fixturePath) + if err != nil { + panic(err) + } + + _, err = db.Exec(string(d)) + if err != nil { + panic(err) + } + } + + db.Exec("UPDATE pg_database SET datistemplate=True WHERE datname=$1", name) + db.Close() + + log.Info("Generated Template database ", name) + return url, name +} + +func dropTemplateDatabase(url string, name string) { + db, err := sql.Open("postgres", url) + if err != nil { + panic(err) + } + + if _, err := db.Exec("UPDATE pg_database SET datistemplate=False WHERE datname=$1", name); err != nil { + panic(err) + } + + if err := db.Close(); err != nil { + panic(err) + } + + if err := dropDatabase(url, name); err != nil { + panic(err) + } + +} +func TestMain(m *testing.M) { + fURL, fName := genTemplateDatabase("fixture", true) + nfURL, nfName := genTemplateDatabase("nonfixture", false) + + withFixtureName = fName + withoutFixtureName = nfName + + m.Run() + + dropTemplateDatabase(fURL, fName) + dropTemplateDatabase(nfURL, nfName) +} + +func openCopiedDatabase(testConfig database.RegistrableComponentConfig, fixture bool) (database.Datastore, error) { + var fixtureName string + if fixture { + fixtureName = withFixtureName + } else { + fixtureName = withoutFixtureName + } + + // copy the database into new database + var pg pgSQL + // Parse configuration. + pg.config = Config{ + CacheSize: 16384, + } + + bytes, err := yaml.Marshal(testConfig.Options) + if err != nil { + return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) + } + err = yaml.Unmarshal(bytes, &pg.config) + if err != nil { + return nil, fmt.Errorf("pgsql: could not load configuration: %v", err) + } + + dbName, pgSourceURL, err := parseConnectionString(pg.config.Source) if err != nil { return nil, err } - datastore := ds.(*pgSQL) + + // Create database. + if pg.config.ManageDatabaseLifecycle { + if err = copyDatabase(pgSourceURL, dbName, fixtureName); err != nil { + return nil, err + } + } + + // Open database. + pg.DB, err = sql.Open("postgres", pg.config.Source) + fmt.Println("database", pg.config.Source) + if err != nil { + pg.Close() + return nil, fmt.Errorf("pgsql: could not open database: %v", err) + } + + return &pg, nil +} + +// copyDatabase creates a new database with +func copyDatabase(url, name string, templateName string) error { + // Open database. + db, err := sql.Open("postgres", url) + if err != nil { + return fmt.Errorf("pgsql: could not open 'postgres' database for creation: %v", err) + } + defer db.Close() + + // Create database with copy + _, err = db.Exec("CREATE DATABASE " + name + " WITH TEMPLATE " + templateName) + if err != nil { + return fmt.Errorf("pgsql: could not create database: %v", err) + } + + return nil +} + +func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) { + var ( + db database.Datastore + err error + testConfig = generateTestConfig(testName, loadFixture, true) + ) + + db, err = openCopiedDatabase(testConfig, loadFixture) + + if err != nil { + return nil, err + } + datastore := db.(*pgSQL) return datastore, nil } -func generateTestConfig(testName string, loadFixture bool) database.RegistrableComponentConfig { +func generateTestConfig(testName string, loadFixture bool, manageLife bool) database.RegistrableComponentConfig { dbName := "test_" + strings.ToLower(testName) + "_" + strings.Replace(uuid.New(), "-", "_", -1) var fixturePath string @@ -49,12 +215,60 @@ func generateTestConfig(testName string, loadFixture bool) database.RegistrableC source = fmt.Sprintf(sourceEnv, dbName) } + var key fernet.Key + if err := key.Generate(); err != nil { + panic("failed to generate pagination key" + err.Error()) + } + return database.RegistrableComponentConfig{ Options: map[string]interface{}{ "source": source, "cachesize": 0, - "managedatabaselifecycle": true, + "managedatabaselifecycle": manageLife, "fixturepath": fixturePath, + "paginationkey": key.Encode(), }, } } + +func closeTest(t *testing.T, store database.Datastore, session database.Session) { + err := session.Rollback() + if err != nil { + t.Error(err) + t.FailNow() + } + + store.Close() +} + +func openSessionForTest(t *testing.T, name string, loadFixture bool) (*pgSQL, *pgSession) { + store, err := openDatabaseForTest(name, loadFixture) + if err != nil { + t.Error(err) + t.FailNow() + } + tx, err := store.Begin() + if err != nil { + t.Error(err) + t.FailNow() + } + return store, tx.(*pgSession) +} + +func restartSession(t *testing.T, datastore *pgSQL, tx *pgSession, commit bool) *pgSession { + var err error + if !commit { + err = tx.Rollback() + } else { + err = tx.Commit() + } + + if assert.Nil(t, err) { + session, err := datastore.Begin() + if assert.Nil(t, err) { + return session.(*pgSession) + } + } + t.FailNow() + return nil +} diff --git a/database/pgsql/queries.go b/database/pgsql/queries.go index 3fedf8d0..64e5b20a 100644 --- a/database/pgsql/queries.go +++ b/database/pgsql/queries.go @@ -14,185 +14,159 @@ package pgsql -import "strconv" +import ( + "fmt" + "strings" + + "github.com/lib/pq" +) const ( - lockVulnerabilityAffects = `LOCK Vulnerability_Affects_FeatureVersion IN SHARE ROW EXCLUSIVE MODE` - disableHashJoin = `SET LOCAL enable_hashjoin = off` - disableMergeJoin = `SET LOCAL enable_mergejoin = off` + lockVulnerabilityAffects = `LOCK vulnerability_affected_namespaced_feature IN SHARE ROW EXCLUSIVE MODE` // keyvalue.go - updateKeyValue = `UPDATE KeyValue SET value = $1 WHERE key = $2` - insertKeyValue = `INSERT INTO KeyValue(key, value) VALUES($1, $2)` searchKeyValue = `SELECT value FROM KeyValue WHERE key = $1` + upsertKeyValue = ` + INSERT INTO KeyValue(key, value) + VALUES ($1, $2) + ON CONFLICT ON CONSTRAINT keyvalue_key_key + DO UPDATE SET key=$1, value=$2` // namespace.go - soiNamespace = ` - WITH new_namespace AS ( - INSERT INTO Namespace(name, version_format) - SELECT CAST($1 AS VARCHAR), CAST($2 AS VARCHAR) - WHERE NOT EXISTS (SELECT name FROM Namespace WHERE name = $1) - RETURNING id - ) - SELECT id FROM Namespace WHERE name = $1 - UNION - SELECT id FROM new_namespace` - searchNamespace = `SELECT id FROM Namespace WHERE name = $1` - listNamespace = `SELECT id, name, version_format FROM Namespace` + searchNamespaceID = `SELECT id FROM Namespace WHERE name = $1 AND version_format = $2` // feature.go - soiFeature = ` - WITH new_feature AS ( - INSERT INTO Feature(name, namespace_id) - SELECT CAST($1 AS VARCHAR), CAST($2 AS INTEGER) - WHERE NOT EXISTS (SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2) + soiNamespacedFeature = ` + WITH new_feature_ns AS ( + INSERT INTO namespaced_feature(feature_id, namespace_id) + SELECT CAST ($1 AS INTEGER), CAST ($2 AS INTEGER) + WHERE NOT EXISTS ( SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2) RETURNING id ) - SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2 + SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2 UNION - SELECT id FROM new_feature` + SELECT id FROM new_feature_ns` - searchFeatureVersion = ` - SELECT id FROM FeatureVersion WHERE feature_id = $1 AND version = $2` + searchPotentialAffectingVulneraibilities = ` + SELECT nf.id, v.id, vaf.affected_version, vaf.id + FROM vulnerability_affected_feature AS vaf, vulnerability AS v, + namespaced_feature AS nf, feature AS f + WHERE nf.id = ANY($1) + AND nf.feature_id = f.id + AND nf.namespace_id = v.namespace_id + AND vaf.feature_name = f.name + AND vaf.vulnerability_id = v.id + AND v.deleted_at IS NULL` - soiFeatureVersion = ` - WITH new_featureversion AS ( - INSERT INTO FeatureVersion(feature_id, version) - SELECT CAST($1 AS INTEGER), CAST($2 AS VARCHAR) - WHERE NOT EXISTS (SELECT id FROM FeatureVersion WHERE feature_id = $1 AND version = $2) - RETURNING id - ) - SELECT false, id FROM FeatureVersion WHERE feature_id = $1 AND version = $2 - UNION - SELECT true, id FROM new_featureversion` - - searchVulnerabilityFixedInFeature = ` - SELECT id, vulnerability_id, version FROM Vulnerability_FixedIn_Feature - WHERE feature_id = $1` - - insertVulnerabilityAffectsFeatureVersion = ` - INSERT INTO Vulnerability_Affects_FeatureVersion(vulnerability_id, featureversion_id, fixedin_id) - VALUES($1, $2, $3)` + searchNamespacedFeaturesVulnerabilities = ` + SELECT vanf.namespaced_feature_id, v.name, v.description, v.link, + v.severity, v.metadata, vaf.fixedin, n.name, n.version_format + FROM vulnerability_affected_namespaced_feature AS vanf, + Vulnerability AS v, + vulnerability_affected_feature AS vaf, + namespace AS n + WHERE vanf.namespaced_feature_id = ANY($1) + AND vaf.id = vanf.added_by + AND v.id = vanf.vulnerability_id + AND n.id = v.namespace_id + AND v.deleted_at IS NULL` // layer.go - searchLayer = ` - SELECT l.id, l.name, l.engineversion, p.id, p.name - FROM Layer l - LEFT JOIN Layer p ON l.parent_id = p.id - WHERE l.name = $1;` + searchLayerIDs = `SELECT id, hash FROM layer WHERE hash = ANY($1);` - searchLayerNamespace = ` - SELECT n.id, n.name, n.version_format - FROM Namespace n - JOIN Layer_Namespace lns ON lns.namespace_id = n.id - WHERE lns.layer_id = $1` + searchLayerFeatures = ` + SELECT feature.Name, feature.Version, feature.version_format + FROM feature, layer_feature + WHERE layer_feature.layer_id = $1 + AND layer_feature.feature_id = feature.id` - searchLayerFeatureVersion = ` - WITH RECURSIVE layer_tree(id, name, parent_id, depth, path, cycle) AS( - SELECT l.id, l.name, l.parent_id, 1, ARRAY[l.id], false - FROM Layer l - WHERE l.id = $1 - UNION ALL - SELECT l.id, l.name, l.parent_id, lt.depth + 1, path || l.id, l.id = ANY(path) - FROM Layer l, layer_tree lt - WHERE l.id = lt.parent_id - ) - SELECT ldf.featureversion_id, ldf.modification, fn.id, fn.name, fn.version_format, f.id, f.name, fv.id, fv.version, ltree.id, ltree.name - FROM Layer_diff_FeatureVersion ldf - JOIN ( - SELECT row_number() over (ORDER BY depth DESC), id, name FROM layer_tree - ) AS ltree (ordering, id, name) ON ldf.layer_id = ltree.id, FeatureVersion fv, Feature f, Namespace fn - WHERE ldf.featureversion_id = fv.id AND fv.feature_id = f.id AND f.namespace_id = fn.id - ORDER BY ltree.ordering` + searchLayerNamespaces = ` + SELECT namespace.Name, namespace.version_format + FROM namespace, layer_namespace + WHERE layer_namespace.layer_id = $1 + AND layer_namespace.namespace_id = namespace.id` - searchFeatureVersionVulnerability = ` - SELECT vafv.featureversion_id, v.id, v.name, v.description, v.link, v.severity, v.metadata, - vn.name, vn.version_format, vfif.version - FROM Vulnerability_Affects_FeatureVersion vafv, Vulnerability v, - Namespace vn, Vulnerability_FixedIn_Feature vfif - WHERE vafv.featureversion_id = ANY($1::integer[]) - AND vfif.vulnerability_id = v.id - AND vafv.fixedin_id = vfif.id - AND v.namespace_id = vn.id - AND v.deleted_at IS NULL` - - insertLayer = ` - INSERT INTO Layer(name, engineversion, parent_id, created_at) - VALUES($1, $2, $3, CURRENT_TIMESTAMP) - RETURNING id` - - insertLayerNamespace = `INSERT INTO Layer_Namespace(layer_id, namespace_id) VALUES($1, $2)` - removeLayerNamespace = `DELETE FROM Layer_Namespace WHERE layer_id = $1` - - updateLayer = `UPDATE LAYER SET engineversion = $2 WHERE id = $1` - - removeLayerDiffFeatureVersion = ` - DELETE FROM Layer_diff_FeatureVersion - WHERE layer_id = $1` - - insertLayerDiffFeatureVersion = ` - INSERT INTO Layer_diff_FeatureVersion(layer_id, featureversion_id, modification) - SELECT $1, fv.id, $2 - FROM FeatureVersion fv - WHERE fv.id = ANY($3::integer[])` - - removeLayer = `DELETE FROM Layer WHERE name = $1` + searchLayer = `SELECT id FROM layer WHERE hash = $1` + searchLayerDetectors = `SELECT detector FROM layer_detector WHERE layer_id = $1` + searchLayerListers = `SELECT lister FROM layer_lister WHERE layer_id = $1` // lock.go - insertLock = `INSERT INTO Lock(name, owner, until) VALUES($1, $2, $3)` + soiLock = `INSERT INTO lock(name, owner, until) VALUES ($1, $2, $3)` + searchLock = `SELECT owner, until FROM Lock WHERE name = $1` updateLock = `UPDATE Lock SET until = $3 WHERE name = $1 AND owner = $2` removeLock = `DELETE FROM Lock WHERE name = $1 AND owner = $2` removeLockExpired = `DELETE FROM LOCK WHERE until < CURRENT_TIMESTAMP` // vulnerability.go - searchVulnerabilityBase = ` - SELECT v.id, v.name, n.id, n.name, n.version_format, v.description, v.link, v.severity, v.metadata - FROM Vulnerability v JOIN Namespace n ON v.namespace_id = n.id` - searchVulnerabilityForUpdate = ` FOR UPDATE OF v` - searchVulnerabilityByNamespaceAndName = ` WHERE n.name = $1 AND v.name = $2 AND v.deleted_at IS NULL` - searchVulnerabilityByID = ` WHERE v.id = $1` - searchVulnerabilityByNamespace = ` WHERE n.name = $1 AND v.deleted_at IS NULL - AND v.id >= $2 - ORDER BY v.id - LIMIT $3` + searchVulnerability = ` + SELECT v.id, v.description, v.link, v.severity, v.metadata, n.version_format + FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id + AND v.name = $1 + AND n.name = $2 + AND v.deleted_at IS NULL + ` - searchVulnerabilityFixedIn = ` - SELECT vfif.version, f.id, f.Name - FROM Vulnerability_FixedIn_Feature vfif JOIN Feature f ON vfif.feature_id = f.id - WHERE vfif.vulnerability_id = $1` + insertVulnerabilityAffected = ` + INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, fixedin) + VALUES ($1, $2, $3, $4) + RETURNING ID + ` + + searchVulnerabilityAffected = ` + SELECT vulnerability_id, feature_name, affected_version, fixedin + FROM vulnerability_affected_feature + WHERE vulnerability_id = ANY($1) + ` + + searchVulnerabilityByID = ` + SELECT v.name, v.description, v.link, v.severity, v.metadata, n.name, n.version_format + FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id + AND v.id = $1` + + searchVulnerabilityPotentialAffected = ` + WITH req AS ( + SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, v.id AS vulnerability_id + FROM vulnerability_affected_feature AS vaf, + vulnerability AS v, + namespace AS n + WHERE vaf.vulnerability_id = ANY($1) + AND v.id = vaf.vulnerability_id + AND n.id = v.namespace_id + ) + SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by + FROM feature AS f, namespaced_feature AS nf, req + WHERE f.name = req.name + AND nf.namespace_id = req.n_id + AND nf.feature_id = f.id` + + insertVulnerabilityAffectedNamespacedFeature = ` + INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by) + VALUES ($1, $2, $3)` insertVulnerability = ` - INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at) - VALUES($1, $2, $3, $4, $5, $6, CURRENT_TIMESTAMP) - RETURNING id` - - soiVulnerabilityFixedInFeature = ` - WITH new_fixedinfeature AS ( - INSERT INTO Vulnerability_FixedIn_Feature(vulnerability_id, feature_id, version) - SELECT CAST($1 AS INTEGER), CAST($2 AS INTEGER), CAST($3 AS VARCHAR) - WHERE NOT EXISTS (SELECT id FROM Vulnerability_FixedIn_Feature WHERE vulnerability_id = $1 AND feature_id = $2) - RETURNING id + WITH ns AS ( + SELECT id FROM namespace WHERE name = $6 AND version_format = $7 ) - SELECT false, id FROM Vulnerability_FixedIn_Feature WHERE vulnerability_id = $1 AND feature_id = $2 - UNION - SELECT true, id FROM new_fixedinfeature` - - searchFeatureVersionByFeature = `SELECT id, version FROM FeatureVersion WHERE feature_id = $1` + INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at) + VALUES((SELECT id FROM ns), $1, $2, $3, $4, $5, CURRENT_TIMESTAMP) + RETURNING id` removeVulnerability = ` UPDATE Vulnerability - SET deleted_at = CURRENT_TIMESTAMP - WHERE namespace_id = (SELECT id FROM Namespace WHERE name = $1) - AND name = $2 - AND deleted_at IS NULL - RETURNING id` + SET deleted_at = CURRENT_TIMESTAMP + WHERE namespace_id = (SELECT id FROM Namespace WHERE name = $1) + AND name = $2 + AND deleted_at IS NULL + RETURNING id` // notification.go insertNotification = ` INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id) - VALUES($1, CURRENT_TIMESTAMP, $2, $3)` + VALUES ($1, $2, $3, $4)` updatedNotificationNotified = ` UPDATE Vulnerability_Notification @@ -202,10 +176,10 @@ const ( removeNotification = ` UPDATE Vulnerability_Notification SET deleted_at = CURRENT_TIMESTAMP - WHERE name = $1` + WHERE name = $1 AND deleted_at IS NULL` searchNotificationAvailable = ` - SELECT id, name, created_at, notified_at, deleted_at + SELECT name, created_at, notified_at, deleted_at FROM Vulnerability_Notification WHERE (notified_at IS NULL OR notified_at < $1) AND deleted_at IS NULL @@ -214,43 +188,231 @@ const ( LIMIT 1` searchNotification = ` - SELECT id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id + SELECT created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id FROM Vulnerability_Notification WHERE name = $1` - searchNotificationLayerIntroducingVulnerability = ` - WITH LDFV AS ( - SELECT DISTINCT ldfv.layer_id - FROM Vulnerability_Affects_FeatureVersion vafv, FeatureVersion fv, Layer_diff_FeatureVersion ldfv - WHERE ldfv.layer_id >= $2 - AND vafv.vulnerability_id = $1 - AND vafv.featureversion_id = fv.id - AND ldfv.featureversion_id = fv.id - AND ldfv.modification = 'add' - ORDER BY ldfv.layer_id - ) - SELECT l.id, l.name - FROM LDFV, Layer l - WHERE LDFV.layer_id = l.id - LIMIT $3` + searchNotificationVulnerableAncestry = ` + SELECT DISTINCT ON (a.id) + a.id, a.name + FROM vulnerability_affected_namespaced_feature AS vanf, + ancestry AS a, ancestry_feature AS af + WHERE vanf.vulnerability_id = $1 + AND a.id >= $2 + AND a.id = af.ancestry_id + AND af.namespaced_feature_id = vanf.namespaced_feature_id + ORDER BY a.id ASC + LIMIT $3;` - // complex_test.go - searchComplexTestFeatureVersionAffects = ` - SELECT v.name - FROM FeatureVersion fv - LEFT JOIN Vulnerability_Affects_FeatureVersion vaf ON fv.id = vaf.featureversion_id - JOIN Vulnerability v ON vaf.vulnerability_id = v.id - WHERE featureversion_id = $1` + // ancestry.go + persistAncestryLister = ` + INSERT INTO ancestry_lister (ancestry_id, lister) + SELECT CAST ($1 AS INTEGER), CAST ($2 AS TEXT) + WHERE NOT EXISTS (SELECT id FROM ancestry_lister WHERE ancestry_id = $1 AND lister = $2) ON CONFLICT DO NOTHING` + + persistAncestryDetector = ` + INSERT INTO ancestry_detector (ancestry_id, detector) + SELECT CAST ($1 AS INTEGER), CAST ($2 AS TEXT) + WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector = $2) ON CONFLICT DO NOTHING` + + insertAncestry = `INSERT INTO ancestry (name) VALUES ($1) RETURNING id` + + searchAncestryLayer = ` + SELECT layer.hash + FROM layer, ancestry_layer + WHERE ancestry_layer.ancestry_id = $1 + AND ancestry_layer.layer_id = layer.id + ORDER BY ancestry_layer.ancestry_index ASC` + + searchAncestryFeatures = ` + SELECT namespace.name, namespace.version_format, feature.name, feature.version + FROM namespace, feature, ancestry, namespaced_feature, ancestry_feature + WHERE ancestry.name = $1 + AND ancestry.id = ancestry_feature.ancestry_id + AND ancestry_feature.namespaced_feature_id = namespaced_feature.id + AND namespaced_feature.feature_id = feature.id + AND namespaced_feature.namespace_id = namespace.id` + + searchAncestry = `SELECT id FROM ancestry WHERE name = $1` + searchAncestryDetectors = `SELECT detector FROM ancestry_detector WHERE ancestry_id = $1` + searchAncestryListers = `SELECT lister FROM ancestry_lister WHERE ancestry_id = $1` + removeAncestry = `DELETE FROM ancestry WHERE name = $1` + insertAncestryLayer = `INSERT INTO ancestry_layer(ancestry_id, ancestry_index, layer_id) VALUES($1,$2,$3)` + insertAncestryFeature = `INSERT INTO ancestry_feature(ancestry_id, namespaced_feature_id) VALUES ($1, $2)` ) -// buildInputArray constructs a PostgreSQL input array from the specified integers. -// Useful to use the `= ANY($1::integer[])` syntax that let us use a IN clause while using -// a single placeholder. -func buildInputArray(ints []int) string { - str := "{" - for i := 0; i < len(ints)-1; i++ { - str = str + strconv.Itoa(ints[i]) + "," - } - str = str + strconv.Itoa(ints[len(ints)-1]) + "}" - return str +// NOTE(Sida): Every search query can only have count less than postgres set +// stack depth. IN will be resolved to nested OR_s and the parser might exceed +// stack depth. TODO(Sida): Generate different queries for different count: if +// count < 5120, use IN; for count > 5120 and < 65536, use temporary table; for +// count > 65535, use is expected to split data into batches. +func querySearchLastDeletedVulnerabilityID(count int) string { + return fmt.Sprintf(` + SELECT vid, vname, nname FROM ( + SELECT v.id AS vid, v.name AS vname, n.name AS nname, + row_number() OVER ( + PARTITION by (v.name, n.name) + ORDER BY v.deleted_at DESC + ) AS rownum + FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id + AND (v.name, n.name) IN ( %s ) + AND v.deleted_at IS NOT NULL + ) tmp WHERE rownum <= 1`, + queryString(2, count)) +} + +func querySearchNotDeletedVulnerabilityID(count int) string { + return fmt.Sprintf(` + SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s) + AND v.deleted_at IS NULL`, + queryString(2, count)) +} + +func querySearchFeatureID(featureCount int) string { + return fmt.Sprintf(` + SELECT id, name, version, version_format + FROM Feature WHERE (name, version, version_format) IN (%s)`, + queryString(3, featureCount), + ) +} + +func querySearchNamespacedFeature(nsfCount int) string { + return fmt.Sprintf(` + SELECT nf.id, f.name, f.version, f.version_format, n.name + FROM namespaced_feature AS nf, feature AS f, namespace AS n + WHERE nf.feature_id = f.id + AND nf.namespace_id = n.id + AND n.version_format = f.version_format + AND (f.name, f.version, f.version_format, n.name) IN (%s)`, + queryString(4, nsfCount), + ) +} + +func querySearchNamespace(nsCount int) string { + return fmt.Sprintf( + `SELECT id, name, version_format + FROM namespace WHERE (name, version_format) IN (%s)`, + queryString(2, nsCount), + ) +} + +func queryInsert(count int, table string, columns ...string) string { + base := `INSERT INTO %s (%s) VALUES %s` + t := pq.QuoteIdentifier(table) + cols := make([]string, len(columns)) + for i, c := range columns { + cols[i] = pq.QuoteIdentifier(c) + } + colsQuoted := strings.Join(cols, ",") + return fmt.Sprintf(base, t, colsQuoted, queryString(len(columns), count)) +} + +func queryPersist(count int, table, constraint string, columns ...string) string { + ct := "" + if constraint != "" { + ct = fmt.Sprintf("ON CONSTRAINT %s", constraint) + } + return fmt.Sprintf("%s ON CONFLICT %s DO NOTHING", queryInsert(count, table, columns...), ct) +} + +func queryInsertNotifications(count int) string { + return queryInsert(count, + "vulnerability_notification", + "name", + "created_at", + "old_vulnerability_id", + "new_vulnerability_id", + ) +} + +func queryPersistFeature(count int) string { + return queryPersist(count, + "feature", + "feature_name_version_version_format_key", + "name", + "version", + "version_format") +} + +func queryPersistLayerFeature(count int) string { + return queryPersist(count, + "layer_feature", + "layer_feature_layer_id_feature_id_key", + "layer_id", + "feature_id") +} + +func queryPersistNamespace(count int) string { + return queryPersist(count, + "namespace", + "namespace_name_version_format_key", + "name", + "version_format") +} + +func queryPersistLayerListers(count int) string { + return queryPersist(count, + "layer_lister", + "layer_lister_layer_id_lister_key", + "layer_id", + "lister") +} + +func queryPersistLayerDetectors(count int) string { + return queryPersist(count, + "layer_detector", + "layer_detector_layer_id_detector_key", + "layer_id", + "detector") +} + +func queryPersistLayerNamespace(count int) string { + return queryPersist(count, + "layer_namespace", + "layer_namespace_layer_id_namespace_id_key", + "layer_id", + "namespace_id") +} + +// size of key and array should be both greater than 0 +func queryString(keySize, arraySize int) string { + if arraySize <= 0 || keySize <= 0 { + panic("Bulk Query requires size of element tuple and number of elements to be greater than 0") + } + keys := make([]string, 0, arraySize) + for i := 0; i < arraySize; i++ { + key := make([]string, keySize) + for j := 0; j < keySize; j++ { + key[j] = fmt.Sprintf("$%d", i*keySize+j+1) + } + keys = append(keys, fmt.Sprintf("(%s)", strings.Join(key, ","))) + } + return strings.Join(keys, ",") +} + +func queryPersistNamespacedFeature(count int) string { + return queryPersist(count, "namespaced_feature", + "namespaced_feature_namespace_id_feature_id_key", + "feature_id", + "namespace_id") +} + +func queryPersistVulnerabilityAffectedNamespacedFeature(count int) string { + return queryPersist(count, "vulnerability_affected_namespaced_feature", + "vulnerability_affected_namesp_vulnerability_id_namespaced_f_key", + "vulnerability_id", + "namespaced_feature_id", + "added_by") +} + +func queryPersistLayer(count int) string { + return queryPersist(count, "layer", "", "hash") +} + +func queryInvalidateVulnerabilityCache(count int) string { + return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature + WHERE vulnerability_id = (%s)`, + queryString(1, count)) } diff --git a/database/pgsql/testdata/data.sql b/database/pgsql/testdata/data.sql index b01e170e..a4ccd31c 100644 --- a/database/pgsql/testdata/data.sql +++ b/database/pgsql/testdata/data.sql @@ -1,73 +1,117 @@ --- Copyright 2015 clair authors --- --- Licensed under the Apache License, Version 2.0 (the "License"); --- you may not use this file except in compliance with the License. --- You may obtain a copy of the License at --- --- http://www.apache.org/licenses/LICENSE-2.0 --- --- Unless required by applicable law or agreed to in writing, software --- distributed under the License is distributed on an "AS IS" BASIS, --- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. --- See the License for the specific language governing permissions and --- limitations under the License. - INSERT INTO namespace (id, name, version_format) VALUES - (1, 'debian:7', 'dpkg'), - (2, 'debian:8', 'dpkg'); +(1, 'debian:7', 'dpkg'), +(2, 'debian:8', 'dpkg'), +(3, 'fake:1.0', 'rpm'); -INSERT INTO feature (id, namespace_id, name) VALUES - (1, 1, 'wechat'), - (2, 1, 'openssl'), - (4, 1, 'libssl'), - (3, 2, 'openssl'); +INSERT INTO feature (id, name, version, version_format) VALUES +(1, 'wechat', '0.5', 'dpkg'), +(2, 'openssl', '1.0', 'dpkg'), +(3, 'openssl', '2.0', 'dpkg'), +(4, 'fake', '2.0', 'rpm'); -INSERT INTO featureversion (id, feature_id, version) VALUES - (1, 1, '0.5'), - (2, 2, '1.0'), - (3, 2, '2.0'), - (4, 3, '1.0'); +INSERT INTO layer (id, hash) VALUES + (1, 'layer-0'), -- blank + (2, 'layer-1'), -- debian:7; wechat 0.5, openssl 1.0 + (3, 'layer-2'), -- debian:7; wechat 0.5, openssl 2.0 + (4, 'layer-3a'),-- debian:7; + (5, 'layer-3b'),-- debian:8; wechat 0.5, openssl 1.0 + (6, 'layer-4'); -- debian:7, fake:1.0; openssl 2.0 (debian), fake 2.0 (fake) -INSERT INTO layer (id, name, engineversion, parent_id) VALUES - (1, 'layer-0', 1, NULL), - (2, 'layer-1', 1, 1), - (3, 'layer-2', 1, 2), - (4, 'layer-3a', 1, 3), - (5, 'layer-3b', 1, 3); - -INSERT INTO layer_namespace (id, layer_id, namespace_id) VALUES +INSERT INTO layer_namespace(id, layer_id, namespace_id) VALUES (1, 2, 1), (2, 3, 1), (3, 4, 1), (4, 5, 2), - (5, 5, 1); + (5, 6, 1), + (6, 6, 3); -INSERT INTO layer_diff_featureversion (id, layer_id, featureversion_id, modification) VALUES - (1, 2, 1, 'add'), - (2, 2, 2, 'add'), - (3, 3, 2, 'del'), -- layer-2: Update Debian:7 OpenSSL 1.0 -> 2.0 - (4, 3, 3, 'add'), -- ^ - (5, 5, 3, 'del'), -- layer-3b: Delete Debian:7 OpenSSL 2.0 - (6, 5, 4, 'add'); -- layer-3b: Add Debian:8 OpenSSL 1.0 +INSERT INTO layer_feature(id, layer_id, feature_id) VALUES + (1, 2, 1), + (2, 2, 2), + (3, 3, 1), + (4, 3, 3), + (5, 5, 1), + (6, 5, 2), + (7, 6, 4), + (8, 6, 3); + +INSERT INTO layer_lister(id, layer_id, lister) VALUES + (1, 1, 'dpkg'), + (2, 2, 'dpkg'), + (3, 3, 'dpkg'), + (4, 4, 'dpkg'), + (5, 5, 'dpkg'), + (6, 6, 'dpkg'), + (7, 6, 'rpm'); + +INSERT INTO layer_detector(id, layer_id, detector) VALUES + (1, 1, 'os-release'), + (2, 2, 'os-release'), + (3, 3, 'os-release'), + (4, 4, 'os-release'), + (5, 5, 'os-release'), + (6, 6, 'os-release'), + (7, 6, 'apt-sources'); + +INSERT INTO ancestry (id, name) VALUES + (1, 'ancestry-1'), -- layer-0, layer-1, layer-2, layer-3a + (2, 'ancestry-2'), -- layer-0, layer-1, layer-2, layer-3b + (3, 'ancestry-3'), -- empty; just for testing the vulnerable ancestry + (4, 'ancestry-4'); -- empty; just for testing the vulnerable ancestry + +INSERT INTO ancestry_lister (id, ancestry_id, lister) VALUES + (1, 1, 'dpkg'), + (2, 2, 'dpkg'); + +INSERT INTO ancestry_detector (id, ancestry_id, detector) VALUES + (1, 1, 'os-release'), + (2, 2, 'os-release'); + +INSERT INTO ancestry_layer (id, ancestry_id, layer_id, ancestry_index) VALUES + (1, 1, 1, 0),(2, 1, 2, 1),(3, 1, 3, 2),(4, 1, 4, 3), + (5, 2, 1, 0),(6, 2, 2, 1),(7, 2, 3, 2),(8, 2, 5, 3); + +INSERT INTO namespaced_feature(id, feature_id, namespace_id) VALUES + (1, 1, 1), -- wechat 0.5, debian:7 + (2, 2, 1), -- openssl 1.0, debian:7 + (3, 2, 2), -- openssl 1.0, debian:8 + (4, 3, 1); -- openssl 2.0, debian:7 + +INSERT INTO ancestry_feature (id, ancestry_id, namespaced_feature_id) VALUES + (1, 1, 1), (2, 1, 4), + (3, 2, 1), (4, 2, 3), + (5, 3, 2), (6, 4, 2); -- assume that ancestry-3 and ancestry-4 are vulnerable. INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES (1, 1, 'CVE-OPENSSL-1-DEB7', 'A vulnerability affecting OpenSSL < 2.0 on Debian 7.0', 'http://google.com/#q=CVE-OPENSSL-1-DEB7', 'High'), (2, 1, 'CVE-NOPE', 'A vulnerability affecting nothing', '', 'Unknown'); -INSERT INTO vulnerability_fixedin_feature (id, vulnerability_id, feature_id, version) VALUES - (1, 1, 2, '2.0'), - (2, 1, 4, '1.9-abc'); +INSERT INTO vulnerability (id, namespace_id, name, description, link, severity, deleted_at) VALUES + (3, 1, 'CVE-DELETED', '', '', 'Unknown', '2017-08-08 17:49:31.668483'); + +INSERT INTO vulnerability_affected_feature(id, vulnerability_id, feature_name, affected_version, fixedin) VALUES +(1, 1, 'openssl', '2.0', '2.0'), +(2, 1, 'libssl', '1.9-abc', '1.9-abc'); -INSERT INTO vulnerability_affects_featureversion (id, vulnerability_id, featureversion_id, fixedin_id) VALUES - (1, 1, 2, 1); -- CVE-OPENSSL-1-DEB7 affects Debian:7 OpenSSL 1.0 +INSERT INTO vulnerability_affected_namespaced_feature(id, vulnerability_id, namespaced_feature_id, added_by) VALUES + (1, 1, 2, 1); + +INSERT INTO vulnerability_notification(id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id) VALUES + (1, 'test', NULL, NULL, NULL, 2, 1); -- 'CVE-NOPE' -> 'CVE-OPENSSL-1-DEB7' SELECT pg_catalog.setval(pg_get_serial_sequence('namespace', 'id'), (SELECT MAX(id) FROM namespace)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry', 'id'), (SELECT MAX(id) FROM ancestry)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_layer', 'id'), (SELECT MAX(id) FROM ancestry_layer)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_feature', 'id'), (SELECT MAX(id) FROM ancestry_feature)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_detector', 'id'), (SELECT MAX(id) FROM ancestry_detector)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_lister', 'id'), (SELECT MAX(id) FROM ancestry_lister)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('feature', 'id'), (SELECT MAX(id) FROM feature)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('featureversion', 'id'), (SELECT MAX(id) FROM featureversion)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('namespaced_feature', 'id'), (SELECT MAX(id) FROM namespaced_feature)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('layer', 'id'), (SELECT MAX(id) FROM layer)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('layer_namespace', 'id'), (SELECT MAX(id) FROM layer_namespace)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('layer_diff_featureversion', 'id'), (SELECT MAX(id) FROM layer_diff_featureversion)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability', 'id'), (SELECT MAX(id) FROM vulnerability)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_fixedin_feature', 'id'), (SELECT MAX(id) FROM vulnerability_fixedin_feature)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affects_featureversion', 'id'), (SELECT MAX(id) FROM vulnerability_affects_featureversion)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('layer_detector', 'id'), (SELECT MAX(id) FROM layer_detector)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('layer_lister', 'id'), (SELECT MAX(id) FROM layer_lister)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability', 'id'), (SELECT MAX(id) FROM vulnerability)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_feature)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_namespaced_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_namespaced_feature)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_notification', 'id'), (SELECT MAX(id) FROM vulnerability_notification)+1); diff --git a/database/pgsql/vulnerability.go b/database/pgsql/vulnerability.go index efb57392..ab92c0e9 100644 --- a/database/pgsql/vulnerability.go +++ b/database/pgsql/vulnerability.go @@ -17,352 +17,207 @@ package pgsql import ( "database/sql" "encoding/json" - "reflect" + "errors" "time" - "github.com/guregu/null/zero" + "github.com/lib/pq" log "github.com/sirupsen/logrus" "github.com/coreos/clair/database" "github.com/coreos/clair/ext/versionfmt" - "github.com/coreos/clair/pkg/commonerr" ) -// compareStringLists returns the strings that are present in X but not in Y. -func compareStringLists(X, Y []string) []string { - m := make(map[string]bool) +var ( + errVulnerabilityNotFound = errors.New("vulnerability is not in database") +) - for _, y := range Y { - m[y] = true - } - - diff := []string{} - for _, x := range X { - if m[x] { - continue - } - - diff = append(diff, x) - m[x] = true - } - - return diff +type affectedAncestry struct { + name string + id int64 } -func compareStringListsInBoth(X, Y []string) []string { - m := make(map[string]struct{}) - - for _, y := range Y { - m[y] = struct{}{} - } - - diff := []string{} - for _, x := range X { - if _, e := m[x]; e { - diff = append(diff, x) - delete(m, x) - } - } - - return diff +type affectRelation struct { + vulnerabilityID int64 + namespacedFeatureID int64 + addedBy int64 } -func (pgSQL *pgSQL) ListVulnerabilities(namespaceName string, limit int, startID int) ([]database.Vulnerability, int, error) { - defer observeQueryTime("listVulnerabilities", "all", time.Now()) +type affectedFeatureRows struct { + rows map[int64]database.AffectedFeature +} - // Query Namespace. - var id int - err := pgSQL.QueryRow(searchNamespace, namespaceName).Scan(&id) +func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) { + defer observeQueryTime("findVulnerabilities", "", time.Now()) + resultVuln := make([]database.NullableVulnerability, len(vulnerabilities)) + vulnIDMap := map[int64][]*database.NullableVulnerability{} + + //TODO(Sida): Change to bulk search. + stmt, err := tx.Prepare(searchVulnerability) if err != nil { - return nil, -1, handleError("searchNamespace", err) - } else if id == 0 { - return nil, -1, commonerr.ErrNotFound + return nil, err } - // Query. - query := searchVulnerabilityBase + searchVulnerabilityByNamespace - rows, err := pgSQL.Query(query, namespaceName, startID, limit+1) - if err != nil { - return nil, -1, handleError("searchVulnerabilityByNamespace", err) - } - defer rows.Close() - - var vulns []database.Vulnerability - nextID := -1 - size := 0 - // Scan query. - for rows.Next() { - var vulnerability database.Vulnerability - - err := rows.Scan( - &vulnerability.ID, - &vulnerability.Name, - &vulnerability.Namespace.ID, - &vulnerability.Namespace.Name, - &vulnerability.Namespace.VersionFormat, - &vulnerability.Description, - &vulnerability.Link, - &vulnerability.Severity, - &vulnerability.Metadata, + // load vulnerabilities + for i, key := range vulnerabilities { + var ( + id sql.NullInt64 + vuln = database.NullableVulnerability{ + VulnerabilityWithAffected: database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: key.Name, + Namespace: database.Namespace{ + Name: key.Namespace, + }, + }, + }, + } ) + + err := stmt.QueryRow(key.Name, key.Namespace).Scan( + &id, + &vuln.Description, + &vuln.Link, + &vuln.Severity, + &vuln.Metadata, + &vuln.Namespace.VersionFormat, + ) + + if err != nil && err != sql.ErrNoRows { + stmt.Close() + return nil, handleError("searchVulnerability", err) + } + vuln.Valid = id.Valid + resultVuln[i] = vuln + if id.Valid { + vulnIDMap[id.Int64] = append(vulnIDMap[id.Int64], &resultVuln[i]) + } + } + + if err := stmt.Close(); err != nil { + return nil, handleError("searchVulnerability", err) + } + + toQuery := make([]int64, 0, len(vulnIDMap)) + for id := range vulnIDMap { + toQuery = append(toQuery, id) + } + + // load vulnerability affected features + rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery)) + if err != nil { + return nil, handleError("searchVulnerabilityAffected", err) + } + + for rows.Next() { + var ( + id int64 + f database.AffectedFeature + ) + + err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FixedInVersion) if err != nil { - return nil, -1, handleError("searchVulnerabilityByNamespace.Scan()", err) + return nil, handleError("searchVulnerabilityAffected", err) } - size++ - if size > limit { - nextID = vulnerability.ID - } else { - vulns = append(vulns, vulnerability) + + for _, vuln := range vulnIDMap[id] { + f.Namespace = vuln.Namespace + vuln.Affected = append(vuln.Affected, f) } } - if err := rows.Err(); err != nil { - return nil, -1, handleError("searchVulnerabilityByNamespace.Rows()", err) + return resultVuln, nil +} + +func (tx *pgSession) InsertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) error { + defer observeQueryTime("insertVulnerabilities", "all", time.Now()) + // bulk insert vulnerabilities + vulnIDs, err := tx.insertVulnerabilities(vulnerabilities) + if err != nil { + return err } - return vulns, nextID, nil -} - -func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vulnerability, error) { - return findVulnerability(pgSQL, namespaceName, name, false) -} - -func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bool) (database.Vulnerability, error) { - defer observeQueryTime("findVulnerability", "all", time.Now()) - - queryName := "searchVulnerabilityBase+searchVulnerabilityByNamespaceAndName" - query := searchVulnerabilityBase + searchVulnerabilityByNamespaceAndName - if forUpdate { - queryName = queryName + "+searchVulnerabilityForUpdate" - query = query + searchVulnerabilityForUpdate + // bulk insert vulnerability affected features + vulnFeatureMap, err := tx.insertVulnerabilityAffected(vulnIDs, vulnerabilities) + if err != nil { + return err } - return scanVulnerability(queryer, queryName, queryer.QueryRow(query, namespaceName, name)) + return tx.cacheVulnerabiltyAffectedNamespacedFeature(vulnFeatureMap) } -func (pgSQL *pgSQL) findVulnerabilityByIDWithDeleted(id int) (database.Vulnerability, error) { - defer observeQueryTime("findVulnerabilityByIDWithDeleted", "all", time.Now()) - - queryName := "searchVulnerabilityBase+searchVulnerabilityByID" - query := searchVulnerabilityBase + searchVulnerabilityByID - - return scanVulnerability(pgSQL, queryName, pgSQL.QueryRow(query, id)) -} - -func scanVulnerability(queryer Queryer, queryName string, vulnerabilityRow *sql.Row) (database.Vulnerability, error) { - var vulnerability database.Vulnerability - - err := vulnerabilityRow.Scan( - &vulnerability.ID, - &vulnerability.Name, - &vulnerability.Namespace.ID, - &vulnerability.Namespace.Name, - &vulnerability.Namespace.VersionFormat, - &vulnerability.Description, - &vulnerability.Link, - &vulnerability.Severity, - &vulnerability.Metadata, +// insertVulnerabilityAffected inserts a set of vulnerability affected features for each vulnerability provided. +// +// i_th vulnerabilityIDs corresponds to i_th vulnerabilities provided. +func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) { + var ( + vulnFeature = map[int64]affectedFeatureRows{} + affectedID int64 ) + //TODO(Sida): Change to bulk insert. + stmt, err := tx.Prepare(insertVulnerabilityAffected) if err != nil { - return vulnerability, handleError(queryName+".Scan()", err) + return nil, handleError("insertVulnerabilityAffected", err) } - if vulnerability.ID == 0 { - return vulnerability, commonerr.ErrNotFound - } - - // Query the FixedIn FeatureVersion now. - rows, err := queryer.Query(searchVulnerabilityFixedIn, vulnerability.ID) - if err != nil { - return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err) - } - defer rows.Close() - - for rows.Next() { - var featureVersionID zero.Int - var featureVersionVersion zero.String - var featureVersionFeatureName zero.String - - err := rows.Scan( - &featureVersionVersion, - &featureVersionID, - &featureVersionFeatureName, - ) - - if err != nil { - return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err) - } - - if !featureVersionID.IsZero() { - // Note that the ID we fill in featureVersion is actually a Feature ID, and not - // a FeatureVersion ID. - featureVersion := database.FeatureVersion{ - Model: database.Model{ID: int(featureVersionID.Int64)}, - Feature: database.Feature{ - Model: database.Model{ID: int(featureVersionID.Int64)}, - Namespace: vulnerability.Namespace, - Name: featureVersionFeatureName.String, - }, - Version: featureVersionVersion.String, + defer stmt.Close() + for i, vuln := range vulnerabilities { + // affected feature row ID -> affected feature + affectedFeatures := map[int64]database.AffectedFeature{} + for _, f := range vuln.Affected { + err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, f.FixedInVersion).Scan(&affectedID) + if err != nil { + return nil, handleError("insertVulnerabilityAffected", err) } - vulnerability.FixedIn = append(vulnerability.FixedIn, featureVersion) + affectedFeatures[affectedID] = f } + vulnFeature[vulnerabilityIDs[i]] = affectedFeatureRows{rows: affectedFeatures} } - if err := rows.Err(); err != nil { - return vulnerability, handleError("searchVulnerabilityFixedIn.Rows()", err) - } - - return vulnerability, nil + return vulnFeature, nil } -// FixedIn.Namespace are not necessary, they are overwritten by the vuln. -// By setting the fixed version to minVersion, we can say that the vuln does'nt affect anymore. -func (pgSQL *pgSQL) InsertVulnerabilities(vulnerabilities []database.Vulnerability, generateNotifications bool) error { - for _, vulnerability := range vulnerabilities { - err := pgSQL.insertVulnerability(vulnerability, false, generateNotifications) +// insertVulnerabilities inserts a set of unique vulnerabilities into database, +// under the assumption that all vulnerabilities are valid. +func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) { + var ( + vulnID int64 + vulnIDs = make([]int64, 0, len(vulnerabilities)) + vulnMap = map[database.VulnerabilityID]struct{}{} + ) + + for _, v := range vulnerabilities { + key := database.VulnerabilityID{ + Name: v.Name, + Namespace: v.Namespace.Name, + } + + // Ensure uniqueness of vulnerability IDs + if _, ok := vulnMap[key]; ok { + return nil, errors.New("inserting duplicated vulnerabilities is not allowed") + } + vulnMap[key] = struct{}{} + } + + //TODO(Sida): Change to bulk insert. + stmt, err := tx.Prepare(insertVulnerability) + if err != nil { + return nil, handleError("insertVulnerability", err) + } + + defer stmt.Close() + for _, vuln := range vulnerabilities { + err := stmt.QueryRow(vuln.Name, vuln.Description, + vuln.Link, &vuln.Severity, &vuln.Metadata, + vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID) if err != nil { - return err - } - } - return nil -} - -func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability, onlyFixedIn, generateNotification bool) error { - tf := time.Now() - - // Verify parameters - if vulnerability.Name == "" || vulnerability.Namespace.Name == "" { - return commonerr.NewBadRequestError("insertVulnerability needs at least the Name and the Namespace") - } - - for i := 0; i < len(vulnerability.FixedIn); i++ { - fifv := &vulnerability.FixedIn[i] - - if fifv.Feature.Namespace.Name == "" { - // As there is no Namespace on that FixedIn FeatureVersion, set it to the Vulnerability's - // Namespace. - fifv.Feature.Namespace = vulnerability.Namespace - } else if fifv.Feature.Namespace.Name != vulnerability.Namespace.Name { - msg := "could not insert an invalid vulnerability that contains FixedIn FeatureVersion that are not in the same namespace as the Vulnerability" - log.Warning(msg) - return commonerr.NewBadRequestError(msg) - } - } - - // We do `defer observeQueryTime` here because we don't want to observe invalid vulnerabilities. - defer observeQueryTime("insertVulnerability", "all", tf) - - // Begin transaction. - tx, err := pgSQL.Begin() - if err != nil { - tx.Rollback() - return handleError("insertVulnerability.Begin()", err) - } - - // Find existing vulnerability and its Vulnerability_FixedIn_Features (for update). - existingVulnerability, err := findVulnerability(tx, vulnerability.Namespace.Name, vulnerability.Name, true) - if err != nil && err != commonerr.ErrNotFound { - tx.Rollback() - return err - } - - if onlyFixedIn { - // Because this call tries to update FixedIn FeatureVersion, import all other data from the - // existing one. - if existingVulnerability.ID == 0 { - return commonerr.ErrNotFound + return nil, handleError("insertVulnerability", err) } - fixedIn := vulnerability.FixedIn - vulnerability = existingVulnerability - vulnerability.FixedIn = fixedIn + vulnIDs = append(vulnIDs, vulnID) } - if existingVulnerability.ID != 0 { - updateMetadata := vulnerability.Description != existingVulnerability.Description || - vulnerability.Link != existingVulnerability.Link || - vulnerability.Severity != existingVulnerability.Severity || - !reflect.DeepEqual(castMetadata(vulnerability.Metadata), existingVulnerability.Metadata) - - // Construct the entire list of FixedIn FeatureVersion, by using the - // the FixedIn list of the old vulnerability. - // - // TODO(Quentin-M): We could use !updateFixedIn to just copy FixedIn/Affects rows from the - // existing vulnerability in order to make metadata updates much faster. - var updateFixedIn bool - vulnerability.FixedIn, updateFixedIn = applyFixedInDiff(existingVulnerability.FixedIn, vulnerability.FixedIn) - - if !updateMetadata && !updateFixedIn { - tx.Commit() - return nil - } - - // Mark the old vulnerability as non latest. - _, err = tx.Exec(removeVulnerability, vulnerability.Namespace.Name, vulnerability.Name) - if err != nil { - tx.Rollback() - return handleError("removeVulnerability", err) - } - } else { - // The vulnerability is new, we don't want to have any - // versionfmt.MinVersion as they are only used for diffing existing - // vulnerabilities. - var fixedIn []database.FeatureVersion - for _, fv := range vulnerability.FixedIn { - if fv.Version != versionfmt.MinVersion { - fixedIn = append(fixedIn, fv) - } - } - vulnerability.FixedIn = fixedIn - } - - // Find or insert Vulnerability's Namespace. - namespaceID, err := pgSQL.insertNamespace(vulnerability.Namespace) - if err != nil { - return err - } - - // Insert vulnerability. - err = tx.QueryRow( - insertVulnerability, - namespaceID, - vulnerability.Name, - vulnerability.Description, - vulnerability.Link, - &vulnerability.Severity, - &vulnerability.Metadata, - ).Scan(&vulnerability.ID) - - if err != nil { - tx.Rollback() - return handleError("insertVulnerability", err) - } - - // Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now. - err = pgSQL.insertVulnerabilityFixedInFeatureVersions(tx, vulnerability.ID, vulnerability.FixedIn) - if err != nil { - tx.Rollback() - return err - } - - // Create a notification. - if generateNotification { - err = createNotification(tx, existingVulnerability.ID, vulnerability.ID) - if err != nil { - return err - } - } - - // Commit transaction. - err = tx.Commit() - if err != nil { - tx.Rollback() - return handleError("insertVulnerability.Commit()", err) - } - - return nil + return vulnIDs, nil } // castMetadata marshals the given database.MetadataMap and unmarshals it again to make sure that @@ -376,241 +231,208 @@ func castMetadata(m database.MetadataMap) database.MetadataMap { return c } -// applyFixedInDiff applies a FeatureVersion diff on a FeatureVersion list and returns the result. -func applyFixedInDiff(currentList, diff []database.FeatureVersion) ([]database.FeatureVersion, bool) { - currentMap, currentNames := createFeatureVersionNameMap(currentList) - diffMap, diffNames := createFeatureVersionNameMap(diff) - - addedNames := compareStringLists(diffNames, currentNames) - inBothNames := compareStringListsInBoth(diffNames, currentNames) - - different := false - - for _, name := range addedNames { - if diffMap[name].Version == versionfmt.MinVersion { - // MinVersion only makes sense when a Feature is already fixed in some version, - // in which case we would be in the "inBothNames". - continue - } - - currentMap[name] = diffMap[name] - different = true - } - - for _, name := range inBothNames { - fv := diffMap[name] - - if fv.Version == versionfmt.MinVersion { - // MinVersion means that the Feature doesn't affect the Vulnerability anymore. - delete(currentMap, name) - different = true - } else if fv.Version != currentMap[name].Version { - // The version got updated. - currentMap[name] = diffMap[name] - different = true - } - } - - // Convert currentMap to a slice and return it. - var newList []database.FeatureVersion - for _, fv := range currentMap { - newList = append(newList, fv) - } - - return newList, different -} - -func createFeatureVersionNameMap(features []database.FeatureVersion) (map[string]database.FeatureVersion, []string) { - m := make(map[string]database.FeatureVersion, 0) - s := make([]string, 0, len(features)) - - for i := 0; i < len(features); i++ { - featureVersion := features[i] - m[featureVersion.Feature.Name] = featureVersion - s = append(s, featureVersion.Feature.Name) - } - - return m, s -} - -// insertVulnerabilityFixedInFeatureVersions populates Vulnerability_FixedIn_Feature for the given -// vulnerability with the specified database.FeatureVersion list and uses -// linkVulnerabilityToFeatureVersions to propagate the changes on Vulnerability_FixedIn_Feature to -// Vulnerability_Affects_FeatureVersion. -func (pgSQL *pgSQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulnerabilityID int, fixedIn []database.FeatureVersion) error { - defer observeQueryTime("insertVulnerabilityFixedInFeatureVersions", "all", time.Now()) - - // Insert or find the Features. - // TODO(Quentin-M): Batch me. - var err error - var features []*database.Feature - for i := 0; i < len(fixedIn); i++ { - features = append(features, &fixedIn[i].Feature) - } - for _, feature := range features { - if feature.ID == 0 { - if feature.ID, err = pgSQL.insertFeature(*feature); err != nil { - return err - } - } - } - - // Lock Vulnerability_Affects_FeatureVersion exclusively. - // We want to prevent InsertFeatureVersion to modify it. - promConcurrentLockVAFV.Inc() - defer promConcurrentLockVAFV.Dec() - t := time.Now() - _, err = tx.Exec(lockVulnerabilityAffects) - observeQueryTime("insertVulnerability", "lock", t) - +func (tx *pgSession) lockFeatureVulnerabilityCache() error { + _, err := tx.Exec(lockVulnerabilityAffects) if err != nil { - tx.Rollback() - return handleError("insertVulnerability.lockVulnerabilityAffects", err) + return handleError("lockVulnerabilityAffects", err) } - - for _, fv := range fixedIn { - var fixedInID int - var created bool - - // Find or create entry in Vulnerability_FixedIn_Feature. - err = tx.QueryRow( - soiVulnerabilityFixedInFeature, - vulnerabilityID, fv.Feature.ID, - &fv.Version, - ).Scan(&created, &fixedInID) - - if err != nil { - return handleError("insertVulnerabilityFixedInFeature", err) - } - - if !created { - // The relationship between the feature and the vulnerability already - // existed, no need to update Vulnerability_Affects_FeatureVersion. - continue - } - - // Insert Vulnerability_Affects_FeatureVersion. - err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerabilityID, fv.Feature.ID, fv.Feature.Namespace.VersionFormat, fv.Version) - if err != nil { - return err - } - } - return nil } -func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, versionFormat, fixedInVersion string) error { - // Find every FeatureVersions of the Feature that the vulnerability affects. - // TODO(Quentin-M): LIMIT - rows, err := tx.Query(searchFeatureVersionByFeature, featureID) - if err != nil { - return handleError("searchFeatureVersionByFeature", err) - } - defer rows.Close() - - var affecteds []database.FeatureVersion - for rows.Next() { - var affected database.FeatureVersion - - err := rows.Scan(&affected.ID, &affected.Version) - if err != nil { - return handleError("searchFeatureVersionByFeature.Scan()", err) - } - - cmp, err := versionfmt.Compare(versionFormat, affected.Version, fixedInVersion) - if err != nil { - return err - } - if cmp < 0 { - // The version of the FeatureVersion is lower than the fixed version of this vulnerability, - // thus, this FeatureVersion is affected by it. - affecteds = append(affecteds, affected) - } - } - if err = rows.Err(); err != nil { - return handleError("searchFeatureVersionByFeature.Rows()", err) - } - rows.Close() - - // Insert into Vulnerability_Affects_FeatureVersion. - for _, affected := range affecteds { - // TODO(Quentin-M): Batch me. - _, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, vulnerabilityID, affected.ID, fixedInID) - if err != nil { - return handleError("insertVulnerabilityAffectsFeatureVersion", err) - } - } - - return nil -} - -func (pgSQL *pgSQL) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []database.FeatureVersion) error { - defer observeQueryTime("InsertVulnerabilityFixes", "all", time.Now()) - - v := database.Vulnerability{ - Name: vulnerabilityName, - Namespace: database.Namespace{ - Name: vulnerabilityNamespace, - }, - FixedIn: fixes, - } - - return pgSQL.insertVulnerability(v, true, true) -} - -func (pgSQL *pgSQL) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error { - defer observeQueryTime("DeleteVulnerabilityFix", "all", time.Now()) - - v := database.Vulnerability{ - Name: vulnerabilityName, - Namespace: database.Namespace{ - Name: vulnerabilityNamespace, - }, - FixedIn: []database.FeatureVersion{ - { - Feature: database.Feature{ - Name: featureName, - Namespace: database.Namespace{ - Name: vulnerabilityNamespace, - }, - }, - Version: versionfmt.MinVersion, - }, - }, - } - - return pgSQL.insertVulnerability(v, true, true) -} - -func (pgSQL *pgSQL) DeleteVulnerability(namespaceName, name string) error { - defer observeQueryTime("DeleteVulnerability", "all", time.Now()) - - // Begin transaction. - tx, err := pgSQL.Begin() - if err != nil { - tx.Rollback() - return handleError("DeleteVulnerability.Begin()", err) - } - - var vulnerabilityID int - err = tx.QueryRow(removeVulnerability, namespaceName, name).Scan(&vulnerabilityID) - if err != nil { - tx.Rollback() - return handleError("removeVulnerability", err) - } - - // Create a notification. - err = createNotification(tx, vulnerabilityID, 0) +// cacheVulnerabiltyAffectedNamespacedFeature takes in a map of vulnerability ID +// to affected feature rows and caches them. +func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int64]affectedFeatureRows) error { + // Prevent InsertNamespacedFeatures to modify it. + err := tx.lockFeatureVulnerabilityCache() if err != nil { return err } - // Commit transaction. - err = tx.Commit() + vulnIDs := []int64{} + for id := range affected { + vulnIDs = append(vulnIDs, id) + } + + rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs)) if err != nil { - tx.Rollback() - return handleError("DeleteVulnerability.Commit()", err) + return handleError("searchVulnerabilityPotentialAffected", err) + } + + defer rows.Close() + + relation := []affectRelation{} + for rows.Next() { + var ( + vulnID int64 + nsfID int64 + fVersion string + addedBy int64 + ) + + err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy) + if err != nil { + return handleError("searchVulnerabilityPotentialAffected", err) + } + + candidate, ok := affected[vulnID].rows[addedBy] + + if !ok { + return errors.New("vulnerability affected feature not found") + } + + if in, err := versionfmt.InRange(candidate.Namespace.VersionFormat, + fVersion, + candidate.AffectedVersion); err == nil { + if in { + relation = append(relation, + affectRelation{ + vulnerabilityID: vulnID, + namespacedFeatureID: nsfID, + addedBy: addedBy, + }) + } + } else { + return err + } + } + + //TODO(Sida): Change to bulk insert. + for _, r := range relation { + result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy) + if err != nil { + return handleError("insertVulnerabilityAffectedNamespacedFeature", err) + } + + if num, err := result.RowsAffected(); err == nil { + if num <= 0 { + return errors.New("Nothing cached in database") + } + } else { + return err + } + } + + log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", len(relation)) + return nil +} + +func (tx *pgSession) DeleteVulnerabilities(vulnerabilities []database.VulnerabilityID) error { + defer observeQueryTime("DeleteVulnerability", "all", time.Now()) + + vulnIDs, err := tx.markVulnerabilitiesAsDeleted(vulnerabilities) + if err != nil { + return err + } + + if err := tx.invalidateVulnerabilityCache(vulnIDs); err != nil { + return err + } + return nil +} + +func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) error { + if len(vulnerabilityIDs) == 0 { + return nil + } + + // Prevent InsertNamespacedFeatures to modify it. + err := tx.lockFeatureVulnerabilityCache() + if err != nil { + return err + } + + //TODO(Sida): Make a nicer interface for bulk inserting. + keys := make([]interface{}, len(vulnerabilityIDs)) + for i, id := range vulnerabilityIDs { + keys[i] = id + } + + _, err = tx.Exec(queryInvalidateVulnerabilityCache(len(vulnerabilityIDs)), keys...) + if err != nil { + return handleError("removeVulnerabilityAffectedFeature", err) } return nil } + +func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.VulnerabilityID) ([]int64, error) { + var ( + vulnID sql.NullInt64 + vulnIDs []int64 + ) + + // mark vulnerabilities deleted + stmt, err := tx.Prepare(removeVulnerability) + if err != nil { + return nil, handleError("removeVulnerability", err) + } + + defer stmt.Close() + for _, vuln := range vulnerabilities { + err := stmt.QueryRow(vuln.Namespace, vuln.Name).Scan(&vulnID) + if err != nil { + return nil, handleError("removeVulnerability", err) + } + if !vulnID.Valid { + return nil, handleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database")) + } + vulnIDs = append(vulnIDs, vulnID.Int64) + } + return vulnIDs, nil +} + +// findLatestDeletedVulnerabilityIDs requires all elements in vulnIDs are in +// database and the order of output array is not guaranteed. +func (tx *pgSession) findLatestDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { + return tx.findVulnerabilityIDs(vulnIDs, true) +} + +func (tx *pgSession) findNotDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { + return tx.findVulnerabilityIDs(vulnIDs, false) +} + +func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) { + if len(vulnIDs) == 0 { + return nil, nil + } + + vulnIDMap := map[database.VulnerabilityID]sql.NullInt64{} + keys := make([]interface{}, len(vulnIDs)*2) + for i, vulnID := range vulnIDs { + keys[i*2] = vulnID.Name + keys[i*2+1] = vulnID.Namespace + vulnIDMap[vulnID] = sql.NullInt64{} + } + + query := "" + if withLatestDeleted { + query = querySearchLastDeletedVulnerabilityID(len(vulnIDs)) + } else { + query = querySearchNotDeletedVulnerabilityID(len(vulnIDs)) + } + + rows, err := tx.Query(query, keys...) + if err != nil { + return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Query", err) + } + + defer rows.Close() + var ( + id sql.NullInt64 + vulnID database.VulnerabilityID + ) + for rows.Next() { + err := rows.Scan(&id, &vulnID.Name, &vulnID.Namespace) + if err != nil { + return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Scan", err) + } + vulnIDMap[vulnID] = id + } + + ids := make([]sql.NullInt64, len(vulnIDs)) + for i, v := range vulnIDs { + ids[i] = vulnIDMap[v] + } + + return ids, nil +} diff --git a/database/pgsql/vulnerability_test.go b/database/pgsql/vulnerability_test.go index 61d835bb..9fe2c23b 100644 --- a/database/pgsql/vulnerability_test.go +++ b/database/pgsql/vulnerability_test.go @@ -15,282 +15,329 @@ package pgsql import ( - "reflect" "testing" "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" - "github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/ext/versionfmt/dpkg" - "github.com/coreos/clair/pkg/commonerr" ) -func TestFindVulnerability(t *testing.T) { - datastore, err := openDatabaseForTest("FindVulnerability", true) - if err != nil { - t.Error(err) - return +func TestInsertVulnerabilities(t *testing.T) { + store, tx := openSessionForTest(t, "InsertVulnerabilities", true) + + ns1 := database.Namespace{ + Name: "name", + VersionFormat: "random stuff", } - defer datastore.Close() - // Find a vulnerability that does not exist. - _, err = datastore.FindVulnerability("", "") - assert.Equal(t, commonerr.ErrNotFound, err) + ns2 := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } - // Find a normal vulnerability. + // invalid vulnerability v1 := database.Vulnerability{ - Name: "CVE-OPENSSL-1-DEB7", - Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", - Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", - Severity: database.HighSeverity, - Namespace: database.Namespace{ - Name: "debian:7", - VersionFormat: dpkg.ParserName, - }, - FixedIn: []database.FeatureVersion{ - { - Feature: database.Feature{Name: "openssl"}, - Version: "2.0", - }, - { - Feature: database.Feature{Name: "libssl"}, - Version: "1.9-abc", - }, - }, + Name: "invalid", + Namespace: ns1, } - v1f, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") - if assert.Nil(t, err) { - equalsVuln(t, &v1, &v1f) + vwa1 := database.VulnerabilityWithAffected{ + Vulnerability: v1, } - - // Find a vulnerability that has no link, no severity and no FixedIn. + // valid vulnerability v2 := database.Vulnerability{ - Name: "CVE-NOPE", - Description: "A vulnerability affecting nothing", - Namespace: database.Namespace{ - Name: "debian:7", + Name: "valid", + Namespace: ns2, + Severity: database.UnknownSeverity, + } + + vwa2 := database.VulnerabilityWithAffected{ + Vulnerability: v2, + } + + // empty + err := tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{}) + assert.Nil(t, err) + + // invalid content: vwa1 is invalid + err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa1, vwa2}) + assert.NotNil(t, err) + + tx = restartSession(t, store, tx, false) + // invalid content: duplicated input + err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2, vwa2}) + assert.NotNil(t, err) + + tx = restartSession(t, store, tx, false) + // valid content + err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2}) + assert.Nil(t, err) + + tx = restartSession(t, store, tx, true) + // ensure the content is in database + vulns, err := tx.FindVulnerabilities([]database.VulnerabilityID{{Name: "valid", Namespace: "debian:7"}}) + if assert.Nil(t, err) && assert.Len(t, vulns, 1) { + assert.True(t, vulns[0].Valid) + } + + tx = restartSession(t, store, tx, false) + // valid content: vwa2 removed and inserted + err = tx.DeleteVulnerabilities([]database.VulnerabilityID{{Name: vwa2.Name, Namespace: vwa2.Namespace.Name}}) + assert.Nil(t, err) + + err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2}) + assert.Nil(t, err) + + closeTest(t, store, tx) +} + +func TestCachingVulnerable(t *testing.T) { + datastore, tx := openSessionForTest(t, "CachingVulnerable", true) + defer closeTest(t, datastore, tx) + + ns := database.Namespace{ + Name: "debian:8", + VersionFormat: dpkg.ParserName, + } + + f := database.NamespacedFeature{ + Feature: database.Feature{ + Name: "openssl", + Version: "1.0", VersionFormat: dpkg.ParserName, }, - Severity: database.UnknownSeverity, + Namespace: ns, } - v2f, err := datastore.FindVulnerability("debian:7", "CVE-NOPE") - if assert.Nil(t, err) { - equalsVuln(t, &v2, &v2f) - } -} - -func TestDeleteVulnerability(t *testing.T) { - datastore, err := openDatabaseForTest("InsertVulnerability", true) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() - - // Delete non-existing Vulnerability. - err = datastore.DeleteVulnerability("TestDeleteVulnerabilityNamespace1", "CVE-OPENSSL-1-DEB7") - assert.Equal(t, commonerr.ErrNotFound, err) - err = datastore.DeleteVulnerability("debian:7", "TestDeleteVulnerabilityVulnerability1") - assert.Equal(t, commonerr.ErrNotFound, err) - - // Delete Vulnerability. - err = datastore.DeleteVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") - if assert.Nil(t, err) { - _, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") - assert.Equal(t, commonerr.ErrNotFound, err) - } -} - -func TestInsertVulnerability(t *testing.T) { - datastore, err := openDatabaseForTest("InsertVulnerability", false) - if err != nil { - t.Error(err) - return - } - defer datastore.Close() - - // Create some data. - n1 := database.Namespace{ - Name: "TestInsertVulnerabilityNamespace1", - VersionFormat: dpkg.ParserName, - } - n2 := database.Namespace{ - Name: "TestInsertVulnerabilityNamespace2", - VersionFormat: dpkg.ParserName, - } - - f1 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion1", - Namespace: n1, + vuln := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "CVE-YAY", + Namespace: ns, + Severity: database.HighSeverity, }, - Version: "1.0", - } - f2 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion1", - Namespace: n2, + Affected: []database.AffectedFeature{ + { + Namespace: ns, + FeatureName: "openssl", + AffectedVersion: "2.0", + FixedInVersion: "2.1", + }, }, - Version: "1.0", - } - f3 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion2", - }, - Version: versionfmt.MaxVersion, - } - f4 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion2", - }, - Version: "1.4", - } - f5 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion3", - }, - Version: "1.5", - } - f6 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion4", - }, - Version: "0.1", - } - f7 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion5", - }, - Version: versionfmt.MaxVersion, - } - f8 := database.FeatureVersion{ - Feature: database.Feature{ - Name: "TestInsertVulnerabilityFeatureVersion5", - }, - Version: versionfmt.MinVersion, } - // Insert invalid vulnerabilities. - for _, vulnerability := range []database.Vulnerability{ - { - Name: "", - Namespace: n1, - FixedIn: []database.FeatureVersion{f1}, - Severity: database.UnknownSeverity, + vuln2 := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "CVE-YAY2", + Namespace: ns, + Severity: database.HighSeverity, }, - { - Name: "TestInsertVulnerability0", - Namespace: database.Namespace{}, - FixedIn: []database.FeatureVersion{f1}, - Severity: database.UnknownSeverity, + Affected: []database.AffectedFeature{ + { + Namespace: ns, + FeatureName: "openssl", + AffectedVersion: "2.1", + FixedInVersion: "2.2", + }, }, - { - Name: "TestInsertVulnerability0-", - Namespace: database.Namespace{}, - FixedIn: []database.FeatureVersion{f1}, + } + + vulnFixed1 := database.VulnerabilityWithFixedIn{ + Vulnerability: database.Vulnerability{ + Name: "CVE-YAY", + Namespace: ns, + Severity: database.HighSeverity, }, - { - Name: "TestInsertVulnerability0", - Namespace: n1, - FixedIn: []database.FeatureVersion{f2}, - Severity: database.UnknownSeverity, + FixedInVersion: "2.1", + } + + vulnFixed2 := database.VulnerabilityWithFixedIn{ + Vulnerability: database.Vulnerability{ + Name: "CVE-YAY2", + Namespace: ns, + Severity: database.HighSeverity, }, - } { - err := datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true) - assert.Error(t, err) + FixedInVersion: "2.2", } - // Insert a simple vulnerability and find it. - v1meta := make(map[string]interface{}) - v1meta["TestInsertVulnerabilityMetadata1"] = "TestInsertVulnerabilityMetadataValue1" - v1meta["TestInsertVulnerabilityMetadata2"] = struct { - Test string - }{ - Test: "TestInsertVulnerabilityMetadataValue1", + if !assert.Nil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vuln, vuln2})) { + t.FailNow() } - v1 := database.Vulnerability{ - Name: "TestInsertVulnerability1", - Namespace: n1, - FixedIn: []database.FeatureVersion{f1, f3, f6, f7}, - Severity: database.LowSeverity, - Description: "TestInsertVulnerabilityDescription1", - Link: "TestInsertVulnerabilityLink1", - Metadata: v1meta, - } - err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true) - if assert.Nil(t, err) { - v1f, err := datastore.FindVulnerability(n1.Name, v1.Name) - if assert.Nil(t, err) { - equalsVuln(t, &v1, &v1f) - } - } - - // Update vulnerability. - v1.Description = "TestInsertVulnerabilityLink2" - v1.Link = "TestInsertVulnerabilityLink2" - v1.Severity = database.HighSeverity - // Update f3 in f4, add fixed in f5, add fixed in f6 which already exists, - // removes fixed in f7 by adding f8 which is f7 but with MinVersion, and - // add fixed by f5 a second time (duplicated). - v1.FixedIn = []database.FeatureVersion{f4, f5, f6, f8, f5} - - err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true) - if assert.Nil(t, err) { - v1f, err := datastore.FindVulnerability(n1.Name, v1.Name) - if assert.Nil(t, err) { - // Remove f8 from the struct for comparison as it was just here to cancel f7. - // Remove one of the f5 too as it was twice in the struct but the database - // implementation should have dedup'd it. - v1.FixedIn = v1.FixedIn[:len(v1.FixedIn)-2] - - // We already had f1 before the update. - // Add it to the struct for comparison. - v1.FixedIn = append(v1.FixedIn, f1) - - equalsVuln(t, &v1, &v1f) - } - } -} - -func equalsVuln(t *testing.T, expected, actual *database.Vulnerability) { - assert.Equal(t, expected.Name, actual.Name) - assert.Equal(t, expected.Namespace.Name, actual.Namespace.Name) - assert.Equal(t, expected.Description, actual.Description) - assert.Equal(t, expected.Link, actual.Link) - assert.Equal(t, expected.Severity, actual.Severity) - assert.True(t, reflect.DeepEqual(castMetadata(expected.Metadata), actual.Metadata), "Got metadata %#v, expected %#v", actual.Metadata, castMetadata(expected.Metadata)) - - if assert.Len(t, actual.FixedIn, len(expected.FixedIn)) { - for _, actualFeatureVersion := range actual.FixedIn { - found := false - for _, expectedFeatureVersion := range expected.FixedIn { - if expectedFeatureVersion.Feature.Name == actualFeatureVersion.Feature.Name { - found = true - - assert.Equal(t, expected.Namespace.Name, actualFeatureVersion.Feature.Namespace.Name) - assert.Equal(t, expectedFeatureVersion.Version, actualFeatureVersion.Version) + r, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{f}) + assert.Nil(t, err) + assert.Len(t, r, 1) + for _, anf := range r { + if assert.True(t, anf.Valid) && assert.Len(t, anf.AffectedBy, 2) { + for _, a := range anf.AffectedBy { + if a.Name == "CVE-YAY" { + assert.Equal(t, vulnFixed1, a) + } else if a.Name == "CVE-YAY2" { + assert.Equal(t, vulnFixed2, a) + } else { + t.FailNow() } } - if !found { - t.Errorf("unexpected package %s in %s", actualFeatureVersion.Feature.Name, expected.Name) + } + } +} + +func TestFindVulnerabilities(t *testing.T) { + datastore, tx := openSessionForTest(t, "FindVulnerabilities", true) + defer closeTest(t, datastore, tx) + + vuln, err := tx.FindVulnerabilities([]database.VulnerabilityID{ + {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, + {Name: "CVE-NOPE", Namespace: "debian:7"}, + {Name: "CVE-NOT HERE"}, + }) + + ns := database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + } + + expectedExisting := []database.VulnerabilityWithAffected{ + { + Vulnerability: database.Vulnerability{ + Namespace: ns, + Name: "CVE-OPENSSL-1-DEB7", + Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", + Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", + Severity: database.HighSeverity, + }, + Affected: []database.AffectedFeature{ + { + FeatureName: "openssl", + AffectedVersion: "2.0", + FixedInVersion: "2.0", + Namespace: ns, + }, + { + FeatureName: "libssl", + AffectedVersion: "1.9-abc", + FixedInVersion: "1.9-abc", + Namespace: ns, + }, + }, + }, + { + Vulnerability: database.Vulnerability{ + Namespace: ns, + Name: "CVE-NOPE", + Description: "A vulnerability affecting nothing", + Severity: database.UnknownSeverity, + }, + }, + } + + expectedExistingMap := map[database.VulnerabilityID]database.VulnerabilityWithAffected{} + for _, v := range expectedExisting { + expectedExistingMap[database.VulnerabilityID{Name: v.Name, Namespace: v.Namespace.Name}] = v + } + + nonexisting := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{Name: "CVE-NOT HERE"}, + } + + if assert.Nil(t, err) { + for _, v := range vuln { + if v.Valid { + key := database.VulnerabilityID{ + Name: v.Name, + Namespace: v.Namespace.Name, + } + + expected, ok := expectedExistingMap[key] + if assert.True(t, ok, "vulnerability not found: "+key.Name+":"+key.Namespace) { + assertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected) + } + } else if !assert.Equal(t, nonexisting, v.VulnerabilityWithAffected) { + t.FailNow() + } + } + } + + // same vulnerability + r, err := tx.FindVulnerabilities([]database.VulnerabilityID{ + {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, + {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, + }) + + if assert.Nil(t, err) { + for _, vuln := range r { + if assert.True(t, vuln.Valid) { + expected, _ := expectedExistingMap[database.VulnerabilityID{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}] + assertVulnerabilityWithAffectedEqual(t, expected, vuln.VulnerabilityWithAffected) } } } } -func TestStringComparison(t *testing.T) { - cmp := compareStringLists([]string{"a", "b", "b", "a"}, []string{"a", "c"}) - assert.Len(t, cmp, 1) - assert.NotContains(t, cmp, "a") - assert.Contains(t, cmp, "b") +func TestDeleteVulnerabilities(t *testing.T) { + datastore, tx := openSessionForTest(t, "DeleteVulnerabilities", true) + defer closeTest(t, datastore, tx) - cmp = compareStringListsInBoth([]string{"a", "a", "b", "c"}, []string{"a", "c", "c"}) - assert.Len(t, cmp, 2) - assert.NotContains(t, cmp, "b") - assert.Contains(t, cmp, "a") - assert.Contains(t, cmp, "c") + remove := []database.VulnerabilityID{} + // empty case + assert.Nil(t, tx.DeleteVulnerabilities(remove)) + // invalid case + remove = append(remove, database.VulnerabilityID{}) + assert.NotNil(t, tx.DeleteVulnerabilities(remove)) + + // valid case + validRemove := []database.VulnerabilityID{ + {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, + {Name: "CVE-NOPE", Namespace: "debian:7"}, + } + + assert.Nil(t, tx.DeleteVulnerabilities(validRemove)) + vuln, err := tx.FindVulnerabilities(validRemove) + if assert.Nil(t, err) { + for _, v := range vuln { + assert.False(t, v.Valid) + } + } +} + +func TestFindVulnerabilityIDs(t *testing.T) { + store, tx := openSessionForTest(t, "FindVulnerabilityIDs", true) + defer closeTest(t, store, tx) + + ids, err := tx.findLatestDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}}) + if assert.Nil(t, err) { + if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 3, ids[0].Int64)) { + assert.Fail(t, "") + } + } + + ids, err = tx.findNotDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}}) + if assert.Nil(t, err) { + if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 2, ids[0].Int64)) { + assert.Fail(t, "") + } + } +} + +func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool { + return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected) +} + +func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.AffectedFeature]bool{} + for _, i := range expected { + has[i] = false + } + for _, i := range actual { + if visited, ok := has[i]; !ok { + return false + } else if visited { + return false + } + has[i] = true + } + return true + } + return false }