From 0c1b80b2ed54dcbe227f7233468a5bdc66d4a17e Mon Sep 17 00:00:00 2001 From: Sida Chen Date: Mon, 8 Oct 2018 11:11:30 -0400 Subject: [PATCH] pgsql: Implement database queries for detector relationship * Refactor layer and ancestry * Add tests * Fix bugs introduced when the queries were moved --- database/pgsql/ancestry.go | 346 +++++++++------- database/pgsql/ancestry_test.go | 263 +++++------- database/pgsql/complex_test.go | 4 +- database/pgsql/detector.go | 198 +++++++++ database/pgsql/detector_test.go | 119 ++++++ database/pgsql/feature.go | 17 +- database/pgsql/feature_test.go | 23 +- database/pgsql/layer.go | 383 ++++++++++-------- database/pgsql/layer_test.go | 241 +++++++---- .../pgsql/migrations/00001_initial_schema.go | 8 +- database/pgsql/namespace.go | 5 - database/pgsql/namespace_test.go | 39 -- database/pgsql/notification.go | 16 +- database/pgsql/notification_test.go | 225 +++++----- database/pgsql/pgsql.go | 3 +- database/pgsql/pgsql_test.go | 8 +- database/pgsql/queries.go | 22 +- database/pgsql/testdata/data.sql | 161 ++++---- database/pgsql/testutil.go | 263 ++++++++++++ database/pgsql/vulnerability_test.go | 4 +- pkg/testutil/testutil.go | 285 +++++++++++++ 21 files changed, 1766 insertions(+), 867 deletions(-) create mode 100644 database/pgsql/detector.go create mode 100644 database/pgsql/detector_test.go create mode 100644 database/pgsql/testutil.go create mode 100644 pkg/testutil/testutil.go diff --git a/database/pgsql/ancestry.go b/database/pgsql/ancestry.go index 36d8fcde..fa0c0ad5 100644 --- a/database/pgsql/ancestry.go +++ b/database/pgsql/ancestry.go @@ -14,15 +14,17 @@ const ( insertAncestry = ` INSERT INTO ancestry (name) VALUES ($1) RETURNING id` - searchAncestryLayer = ` - SELECT layer.hash, layer.id, ancestry_layer.ancestry_index + findAncestryLayerHashes = ` + SELECT layer.hash, ancestry_layer.ancestry_index FROM layer, ancestry_layer WHERE ancestry_layer.ancestry_id = $1 AND ancestry_layer.layer_id = layer.id ORDER BY ancestry_layer.ancestry_index ASC` - searchAncestryFeatures = ` - SELECT namespace.name, namespace.version_format, feature.name, feature.version, feature.version_format, ancestry_layer.ancestry_index + findAncestryFeatures = ` + SELECT namespace.name, namespace.version_format, feature.name, + feature.version, feature.version_format, ancestry_layer.ancestry_index, + ancestry_feature.feature_detector_id, ancestry_feature.namespace_detector_id FROM namespace, feature, namespaced_feature, ancestry_layer, ancestry_feature WHERE ancestry_layer.ancestry_id = $1 AND ancestry_feature.ancestry_layer_id = ancestry_layer.id @@ -30,203 +32,220 @@ const ( AND namespaced_feature.feature_id = feature.id AND namespaced_feature.namespace_id = namespace.id` - searchAncestry = `SELECT id FROM ancestry WHERE name = $1` - removeAncestry = `DELETE FROM ancestry WHERE name = $1` - insertAncestryLayer = ` - INSERT INTO ancestry_layer (ancestry_id, ancestry_index, layer_id) VALUES - ($1, $2, (SELECT layer.id FROM layer WHERE hash = $3 LIMIT 1)) + findAncestryID = `SELECT id FROM ancestry WHERE name = $1` + removeAncestry = `DELETE FROM ancestry WHERE name = $1` + insertAncestryLayers = ` + INSERT INTO ancestry_layer (ancestry_id, ancestry_index, layer_id) VALUES ($1, $2, $3) RETURNING id` - insertAncestryLayerFeature = ` + insertAncestryFeatures = ` INSERT INTO ancestry_feature (ancestry_layer_id, namespaced_feature_id, feature_detector_id, namespace_detector_id) VALUES ($1, $2, $3, $4)` ) -type ancestryLayerWithID struct { - database.AncestryLayer +func (tx *pgSession) FindAncestry(name string) (database.Ancestry, bool, error) { + var ( + ancestry = database.Ancestry{Name: name} + err error + ) + + id, ok, err := tx.findAncestryID(name) + if !ok || err != nil { + return ancestry, ok, err + } + + if ancestry.By, err = tx.findAncestryDetectors(id); err != nil { + return ancestry, false, err + } + + if ancestry.Layers, err = tx.findAncestryLayers(id); err != nil { + return ancestry, false, err + } - layerID int64 + return ancestry, true, nil } func (tx *pgSession) UpsertAncestry(ancestry database.Ancestry) error { - if ancestry.Name == "" { - log.Error("Empty ancestry name is not allowed") - return commonerr.NewBadRequestError("could not insert an ancestry with empty name") + if !ancestry.Valid() { + return database.ErrInvalidParameters } - if len(ancestry.Layers) == 0 { - log.Error("Empty ancestry is not allowed") - return commonerr.NewBadRequestError("could not insert an ancestry with 0 layers") + if err := tx.removeAncestry(ancestry.Name); err != nil { + return err } - if err := tx.deleteAncestry(ancestry.Name); err != nil { + id, err := tx.insertAncestry(ancestry.Name) + if err != nil { return err } - var ancestryID int64 - if err := tx.QueryRow(insertAncestry, ancestry.Name).Scan(&ancestryID); err != nil { - if isErrUniqueViolation(err) { - return handleError("insertAncestry", errors.New("other Go-routine is processing this ancestry (skip)")) - } - return handleError("insertAncestry", err) + detectorIDs, err := tx.findDetectorIDs(ancestry.By) + if err != nil { + return err } - if err := tx.insertAncestryLayers(ancestryID, ancestry.Layers); err != nil { + // insert ancestry metadata + if err := tx.insertAncestryDetectors(id, detectorIDs); err != nil { return err } - return tx.persistProcessors(persistAncestryLister, - "persistAncestryLister", - persistAncestryDetector, - "persistAncestryDetector", - ancestryID, ancestry.ProcessedBy) -} - -func (tx *pgSession) findAncestryID(name string) (int64, bool, error) { - var id sql.NullInt64 - if err := tx.QueryRow(searchAncestry, name).Scan(&id); err != nil { - if err == sql.ErrNoRows { - return 0, false, nil - } - - return 0, false, handleError("searchAncestry", err) + layers := make([]string, 0, len(ancestry.Layers)) + for _, layer := range ancestry.Layers { + layers = append(layers, layer.Hash) } - return id.Int64, true, nil -} + layerIDs, ok, err := tx.findLayerIDs(layers) + if err != nil { + return err + } -func (tx *pgSession) findAncestryProcessors(id int64) (database.Processors, error) { - var ( - processors database.Processors - err error - ) + if !ok { + log.Error("layer cannot be found, this indicates that the internal logic of calling UpsertAncestry is wrong or the database is corrupted.") + return database.ErrMissingEntities + } - if processors.Detectors, err = tx.findProcessors(searchAncestryDetectors, id); err != nil { - return processors, handleError("searchAncestryDetectors", err) + ancestryLayerIDs, err := tx.insertAncestryLayers(id, layerIDs) + if err != nil { + return err } - if processors.Listers, err = tx.findProcessors(searchAncestryListers, id); err != nil { - return processors, handleError("searchAncestryListers", err) + for i, id := range ancestryLayerIDs { + if err := tx.insertAncestryFeatures(id, ancestry.Layers[i]); err != nil { + return err + } } - return processors, err + return nil } -func (tx *pgSession) FindAncestry(name string) (database.Ancestry, bool, error) { - var ( - ancestry = database.Ancestry{Name: name} - err error - ) +func (tx *pgSession) insertAncestry(name string) (int64, error) { + var id int64 + err := tx.QueryRow(insertAncestry, name).Scan(&id) + if err != nil { + if isErrUniqueViolation(err) { + return 0, handleError("insertAncestry", errors.New("other Go-routine is processing this ancestry (skip)")) + } - id, ok, err := tx.findAncestryID(name) - if !ok || err != nil { - return ancestry, ok, err + return 0, handleError("insertAncestry", err) } - if ancestry.ProcessedBy, err = tx.findAncestryProcessors(id); err != nil { - return ancestry, false, err - } + log.WithFields(log.Fields{"ancestry": name, "id": id}).Debug("database: inserted ancestry") + return id, nil +} - if ancestry.Layers, err = tx.findAncestryLayers(id); err != nil { - return ancestry, false, err +func (tx *pgSession) findAncestryID(name string) (int64, bool, error) { + var id sql.NullInt64 + if err := tx.QueryRow(findAncestryID, name).Scan(&id); err != nil { + if err == sql.ErrNoRows { + return 0, false, nil + } + + return 0, false, handleError("findAncestryID", err) } - return ancestry, true, nil + return id.Int64, true, nil } -func (tx *pgSession) deleteAncestry(name string) error { +func (tx *pgSession) removeAncestry(name string) error { result, err := tx.Exec(removeAncestry, name) if err != nil { return handleError("removeAncestry", err) } - _, err = result.RowsAffected() + affected, err := result.RowsAffected() if err != nil { return handleError("removeAncestry", err) } + if affected != 0 { + log.WithField("ancestry", name).Debug("removed ancestry") + } + return nil } -func (tx *pgSession) findProcessors(query string, id int64) ([]string, error) { - var ( - processors []string - processor string - ) +func (tx *pgSession) findAncestryLayers(id int64) ([]database.AncestryLayer, error) { + detectors, err := tx.findAllDetectors() + if err != nil { + return nil, err + } - rows, err := tx.Query(query, id) + layerMap, err := tx.findAncestryLayerHashes(id) if err != nil { - if err == sql.ErrNoRows { - return nil, nil - } + return nil, err + } + log.WithField("map", layerMap).Debug("found layer hashes") + featureMap, err := tx.findAncestryFeatures(id, detectors) + if err != nil { return nil, err } - for rows.Next() { - if err := rows.Scan(&processor); err != nil { - return nil, err + layers := make([]database.AncestryLayer, len(layerMap)) + for index, layer := range layerMap { + // index MUST match the ancestry layer slice index. + if layers[index].Hash == "" && len(layers[index].Features) == 0 { + layers[index] = database.AncestryLayer{ + Hash: layer, + Features: featureMap[index], + } + } else { + log.WithFields(log.Fields{ + "ancestry ID": id, + "duplicated ancestry index": index, + }).WithError(database.ErrInconsistent).Error("ancestry layers with same ancestry_index is not allowed") + return nil, database.ErrInconsistent } - - processors = append(processors, processor) } - return processors, nil + return layers, nil } -func (tx *pgSession) findAncestryLayers(id int64) ([]database.AncestryLayer, error) { - var ( - err error - rows *sql.Rows - // layer index -> Ancestry Layer + Layer ID - layers = map[int64]ancestryLayerWithID{} - // layer index -> layer-wise features - features = map[int64][]database.NamespacedFeature{} - ancestryLayers []database.AncestryLayer - ) - - // retrieve ancestry layer metadata - if rows, err = tx.Query(searchAncestryLayer, id); err != nil { - return nil, handleError("searchAncestryLayer", err) +func (tx *pgSession) findAncestryLayerHashes(ancestryID int64) (map[int64]string, error) { + // retrieve layer indexes and hashes + rows, err := tx.Query(findAncestryLayerHashes, ancestryID) + if err != nil { + return nil, handleError("findAncestryLayerHashes", err) } + layerHashes := map[int64]string{} for rows.Next() { var ( - layer database.AncestryLayer - index sql.NullInt64 - id sql.NullInt64 + hash string + index int64 ) - if err = rows.Scan(&layer.Hash, &id, &index); err != nil { - return nil, handleError("searchAncestryLayer", err) + if err = rows.Scan(&hash, &index); err != nil { + return nil, handleError("findAncestryLayerHashes", err) } - if !index.Valid || !id.Valid { - panic("null ancestry ID or ancestry index violates database constraints") - } - - if _, ok := layers[index.Int64]; ok { + if _, ok := layerHashes[index]; ok { // one ancestry index should correspond to only one layer return nil, database.ErrInconsistent } - layers[index.Int64] = ancestryLayerWithID{layer, id.Int64} + layerHashes[index] = hash } - for _, layer := range layers { - if layer.ProcessedBy, err = tx.findLayerProcessors(layer.layerID); err != nil { - return nil, err - } - } + return layerHashes, nil +} +func (tx *pgSession) findAncestryFeatures(ancestryID int64, detectors detectorMap) (map[int64][]database.AncestryFeature, error) { + // ancestry_index -> ancestry features + featureMap := make(map[int64][]database.AncestryFeature) // retrieve ancestry layer's namespaced features - if rows, err = tx.Query(searchAncestryFeatures, id); err != nil { - return nil, handleError("searchAncestryFeatures", err) + rows, err := tx.Query(findAncestryFeatures, ancestryID) + if err != nil { + return nil, handleError("findAncestryFeatures", err) } + defer rows.Close() + for rows.Next() { var ( - feature database.NamespacedFeature + featureDetectorID int64 + namespaceDetectorID int64 + feature database.NamespacedFeature // index is used to determine which layer the feature belongs to. index sql.NullInt64 ) @@ -238,8 +257,10 @@ func (tx *pgSession) findAncestryLayers(id int64) ([]database.AncestryLayer, err &feature.Feature.Version, &feature.Feature.VersionFormat, &index, + &featureDetectorID, + &namespaceDetectorID, ); err != nil { - return nil, handleError("searchAncestryFeatures", err) + return nil, handleError("findAncestryFeatures", err) } if feature.Feature.VersionFormat != feature.Namespace.VersionFormat { @@ -248,59 +269,88 @@ func (tx *pgSession) findAncestryLayers(id int64) ([]database.AncestryLayer, err return nil, database.ErrInconsistent } - features[index.Int64] = append(features[index.Int64], feature) - } + fDetector, ok := detectors.byID[featureDetectorID] + if !ok { + return nil, database.ErrInconsistent + } + + nsDetector, ok := detectors.byID[namespaceDetectorID] + if !ok { + return nil, database.ErrInconsistent + } - for index, layer := range layers { - layer.DetectedFeatures = features[index] - ancestryLayers = append(ancestryLayers, layer.AncestryLayer) + featureMap[index.Int64] = append(featureMap[index.Int64], database.AncestryFeature{ + NamespacedFeature: feature, + FeatureBy: fDetector, + NamespaceBy: nsDetector, + }) } - return ancestryLayers, nil + return featureMap, nil } // insertAncestryLayers inserts the ancestry layers along with its content into // the database. The layers are 0 based indexed in the original order. -func (tx *pgSession) insertAncestryLayers(ancestryID int64, layers []database.AncestryLayer) error { - //TODO(Sida): use bulk insert. - stmt, err := tx.Prepare(insertAncestryLayer) +func (tx *pgSession) insertAncestryLayers(ancestryID int64, layers []int64) ([]int64, error) { + stmt, err := tx.Prepare(insertAncestryLayers) if err != nil { - return handleError("insertAncestryLayer", err) + return nil, handleError("insertAncestryLayers", err) } - ancestryLayerIDs := []sql.NullInt64{} - for index, layer := range layers { + ancestryLayerIDs := []int64{} + for index, layerID := range layers { var ancestryLayerID sql.NullInt64 - if err := stmt.QueryRow(ancestryID, index, layer.Hash).Scan(&ancestryLayerID); err != nil { - return handleError("insertAncestryLayer", commonerr.CombineErrors(err, stmt.Close())) + if err := stmt.QueryRow(ancestryID, index, layerID).Scan(&ancestryLayerID); err != nil { + return nil, handleError("insertAncestryLayers", commonerr.CombineErrors(err, stmt.Close())) + } + + if !ancestryLayerID.Valid { + return nil, database.ErrInconsistent } - ancestryLayerIDs = append(ancestryLayerIDs, ancestryLayerID) + ancestryLayerIDs = append(ancestryLayerIDs, ancestryLayerID.Int64) } if err := stmt.Close(); err != nil { - return handleError("Failed to close insertAncestryLayer statement", err) + return nil, handleError("insertAncestryLayers", err) } - stmt, err = tx.Prepare(insertAncestryLayerFeature) - defer stmt.Close() + return ancestryLayerIDs, nil +} - for i, layer := range layers { - var ( - nsFeatureIDs []sql.NullInt64 - layerID = ancestryLayerIDs[i] - ) +func (tx *pgSession) insertAncestryFeatures(ancestryLayerID int64, layer database.AncestryLayer) error { + detectors, err := tx.findAllDetectors() + if err != nil { + return err + } - if nsFeatureIDs, err = tx.findNamespacedFeatureIDs(layer.DetectedFeatures); err != nil { - return err + nsFeatureIDs, err := tx.findNamespacedFeatureIDs(layer.GetFeatures()) + if err != nil { + return err + } + + // find the detectors for each feature + stmt, err := tx.Prepare(insertAncestryFeatures) + if err != nil { + return handleError("insertAncestryFeatures", err) + } + + defer stmt.Close() + + for index, id := range nsFeatureIDs { + namespaceDetectorID, ok := detectors.byValue[layer.Features[index].NamespaceBy] + if !ok { + return database.ErrMissingEntities } - for _, id := range nsFeatureIDs { - if _, err := stmt.Exec(layerID, id); err != nil { - return handleError("insertAncestryLayerFeature", commonerr.CombineErrors(err, stmt.Close())) - } + featureDetectorID, ok := detectors.byValue[layer.Features[index].FeatureBy] + if !ok { + return database.ErrMissingEntities } + if _, err := stmt.Exec(ancestryLayerID, id, featureDetectorID, namespaceDetectorID); err != nil { + return handleError("insertAncestryFeatures", commonerr.CombineErrors(err, stmt.Close())) + } } return nil diff --git a/database/pgsql/ancestry_test.go b/database/pgsql/ancestry_test.go index 9d1f1c5c..6cceb718 100644 --- a/database/pgsql/ancestry_test.go +++ b/database/pgsql/ancestry_test.go @@ -15,198 +15,125 @@ package pgsql import ( - "sort" "testing" "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" + "github.com/coreos/clair/pkg/testutil" ) -func TestUpsertAncestry(t *testing.T) { - store, tx := openSessionForTest(t, "UpsertAncestry", true) - defer closeTest(t, store, tx) - a1 := database.Ancestry{ - Name: "a1", - Layers: []database.AncestryLayer{ - { - LayerMetadata: database.LayerMetadata{ - Hash: "layer-N", +var upsertAncestryTests = []struct { + in *database.Ancestry + err string + title string +}{ + { + title: "ancestry with invalid layer", + in: &database.Ancestry{ + Name: "a1", + Layers: []database.AncestryLayer{ + { + Hash: "layer-non-existing", }, }, }, - } - - a2 := database.Ancestry{} - - a3 := database.Ancestry{ - Name: "a", - Layers: []database.AncestryLayer{ - { - LayerMetadata: database.LayerMetadata{ - Hash: "layer-0", - }, - }, + err: database.ErrMissingEntities.Error(), + }, + { + title: "ancestry with invalid name", + in: &database.Ancestry{}, + err: database.ErrInvalidParameters.Error(), + }, + { + title: "new valid ancestry", + in: &database.Ancestry{ + Name: "a", + Layers: []database.AncestryLayer{{Hash: "layer-0"}}, }, - } - - a4 := database.Ancestry{ - Name: "a", - Layers: []database.AncestryLayer{ - { - LayerMetadata: database.LayerMetadata{ - Hash: "layer-1", - }, + }, + { + title: "ancestry with invalid feature", + in: &database.Ancestry{ + Name: "a", + By: []database.Detector{realDetectors[1], realDetectors[2]}, + Layers: []database.AncestryLayer{{Hash: "layer-1", Features: []database.AncestryFeature{ + {fakeNamespacedFeatures[1], fakeDetector[1], fakeDetector[2]}, + }}}, + }, + err: database.ErrMissingEntities.Error(), + }, + { + title: "replace old ancestry", + in: &database.Ancestry{ + Name: "a", + By: []database.Detector{realDetectors[1], realDetectors[2]}, + Layers: []database.AncestryLayer{ + {"layer-1", []database.AncestryFeature{{realNamespacedFeatures[1], realDetectors[2], realDetectors[1]}}}, }, }, - } - - f1 := database.Feature{ - Name: "wechat", - Version: "0.5", - VersionFormat: "dpkg", - } - - // not in database - f2 := database.Feature{ - Name: "wechat", - Version: "0.6", - VersionFormat: "dpkg", - } - - n1 := database.Namespace{ - Name: "debian:7", - VersionFormat: "dpkg", - } - - p := database.Processors{ - Listers: []string{"dpkg", "non-existing"}, - Detectors: []string{"os-release", "non-existing"}, - } - - nsf1 := database.NamespacedFeature{ - Namespace: n1, - Feature: f1, - } - - // not in database - nsf2 := database.NamespacedFeature{ - Namespace: n1, - Feature: f2, - } - - a4.ProcessedBy = p - // invalid case - assert.NotNil(t, tx.UpsertAncestry(a1)) - assert.NotNil(t, tx.UpsertAncestry(a2)) - // valid case - assert.Nil(t, tx.UpsertAncestry(a3)) - a4.Layers[0].DetectedFeatures = []database.NamespacedFeature{nsf1, nsf2} - // replace invalid case - assert.NotNil(t, tx.UpsertAncestry(a4)) - a4.Layers[0].DetectedFeatures = []database.NamespacedFeature{nsf1} - // replace valid case - assert.Nil(t, tx.UpsertAncestry(a4)) - // validate - ancestry, ok, err := tx.FindAncestry("a") - assert.Nil(t, err) - assert.True(t, ok) - assertAncestryEqual(t, a4, ancestry) -} - -func assertProcessorsEqual(t *testing.T, expected database.Processors, actual database.Processors) bool { - sort.Strings(expected.Detectors) - sort.Strings(actual.Detectors) - sort.Strings(expected.Listers) - sort.Strings(actual.Listers) - return assert.Equal(t, expected.Detectors, actual.Detectors) && assert.Equal(t, expected.Listers, actual.Listers) + }, } -func assertAncestryEqual(t *testing.T, expected database.Ancestry, actual database.Ancestry) bool { - assert.Equal(t, expected.Name, actual.Name) - assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy) - if assert.Equal(t, len(expected.Layers), len(actual.Layers)) { - for index, layer := range expected.Layers { - if !assertAncestryLayerEqual(t, layer, actual.Layers[index]) { - return false +func TestUpsertAncestry(t *testing.T) { + store, tx := openSessionForTest(t, "UpsertAncestry", true) + defer closeTest(t, store, tx) + for _, test := range upsertAncestryTests { + t.Run(test.title, func(t *testing.T) { + err := tx.UpsertAncestry(*test.in) + if test.err != "" { + assert.EqualError(t, err, test.err, "unexpected error") + return } - } - return true + assert.Nil(t, err) + actual, ok, err := tx.FindAncestry(test.in.Name) + assert.Nil(t, err) + assert.True(t, ok) + testutil.AssertAncestryEqual(t, test.in, &actual) + }) } - return false } -func assertAncestryLayerEqual(t *testing.T, expected database.AncestryLayer, actual database.AncestryLayer) bool { - return assertLayerEqual(t, expected.LayerMetadata, actual.LayerMetadata) && - assertNamespacedFeatureEqual(t, expected.DetectedFeatures, actual.DetectedFeatures) +var findAncestryTests = []struct { + title string + in string + + ancestry *database.Ancestry + err string + ok bool +}{ + { + title: "missing ancestry", + in: "ancestry-non", + err: "", + ancestry: nil, + ok: false, + }, + { + title: "valid ancestry", + in: "ancestry-2", + err: "", + ok: true, + ancestry: takeAncestryPointerFromMap(realAncestries, 2), + }, } func TestFindAncestry(t *testing.T) { store, tx := openSessionForTest(t, "FindAncestry", true) defer closeTest(t, store, tx) + for _, test := range findAncestryTests { + t.Run(test.title, func(t *testing.T) { + ancestry, ok, err := tx.FindAncestry(test.in) + if test.err != "" { + assert.EqualError(t, err, test.err, "unexpected error") + return + } - // invalid - _, ok, err := tx.FindAncestry("ancestry-non") - if assert.Nil(t, err) { - assert.False(t, ok) - } - - expected := database.Ancestry{ - Name: "ancestry-2", - ProcessedBy: database.Processors{ - Detectors: []string{"os-release"}, - Listers: []string{"dpkg"}, - }, - Layers: []database.AncestryLayer{ - { - LayerMetadata: database.LayerMetadata{ - Hash: "layer-0", - }, - DetectedFeatures: []database.NamespacedFeature{ - { - Namespace: database.Namespace{ - Name: "debian:7", - VersionFormat: "dpkg", - }, - Feature: database.Feature{ - Name: "wechat", - Version: "0.5", - VersionFormat: "dpkg", - }, - }, - { - Namespace: database.Namespace{ - Name: "debian:8", - VersionFormat: "dpkg", - }, - Feature: database.Feature{ - Name: "openssl", - Version: "1.0", - VersionFormat: "dpkg", - }, - }, - }, - }, - { - LayerMetadata: database.LayerMetadata{ - Hash: "layer-1", - }, - }, - { - LayerMetadata: database.LayerMetadata{ - Hash: "layer-2", - }, - }, - { - LayerMetadata: database.LayerMetadata{ - Hash: "layer-3b", - }, - }, - }, - } - // valid - ancestry, ok, err := tx.FindAncestry("ancestry-2") - if assert.Nil(t, err) && assert.True(t, ok) { - assertAncestryEqual(t, expected, ancestry) + assert.Nil(t, err) + assert.Equal(t, test.ok, ok) + if test.ok { + testutil.AssertAncestryEqual(t, test.ancestry, &ancestry) + } + }) } } diff --git a/database/pgsql/complex_test.go b/database/pgsql/complex_test.go index 07d6f55f..de8b0f20 100644 --- a/database/pgsql/complex_test.go +++ b/database/pgsql/complex_test.go @@ -220,7 +220,7 @@ func TestCaching(t *testing.T) { actualAffectedNames = append(actualAffectedNames, s.Name) } - assert.Len(t, strutil.CompareStringLists(expectedAffectedNames, actualAffectedNames), 0) - assert.Len(t, strutil.CompareStringLists(actualAffectedNames, expectedAffectedNames), 0) + assert.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0) + assert.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0) } } diff --git a/database/pgsql/detector.go b/database/pgsql/detector.go new file mode 100644 index 00000000..99b5c283 --- /dev/null +++ b/database/pgsql/detector.go @@ -0,0 +1,198 @@ +// Copyright 2018 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pgsql + +import ( + "database/sql" + + "github.com/deckarep/golang-set" + log "github.com/sirupsen/logrus" + + "github.com/coreos/clair/database" +) + +const ( + soiDetector = ` + INSERT INTO detector (name, version, dtype) + SELECT CAST ($1 AS TEXT), CAST ($2 AS TEXT), CAST ($3 AS detector_type ) + WHERE NOT EXISTS (SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3);` + + selectAncestryDetectors = ` + SELECT d.name, d.version, d.dtype + FROM ancestry_detector, detector AS d + WHERE ancestry_detector.detector_id = d.id AND ancestry_detector.ancestry_id = $1;` + + selectLayerDetectors = ` + SELECT d.name, d.version, d.dtype + FROM layer_detector, detector AS d + WHERE layer_detector.detector_id = d.id AND layer_detector.layer_id = $1;` + + insertAncestryDetectors = ` + INSERT INTO ancestry_detector (ancestry_id, detector_id) + SELECT $1, $2 + WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector_id = $2)` + + persistLayerDetector = ` + INSERT INTO layer_detector (layer_id, detector_id) + SELECT $1, $2 + WHERE NOT EXISTS (SELECT id FROM layer_detector WHERE layer_id = $1 AND detector_id = $2)` + + findDetectorID = `SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3` + findAllDetectors = `SELECT id, name, version, dtype FROM detector` +) + +type detectorMap struct { + byID map[int64]database.Detector + byValue map[database.Detector]int64 +} + +func (tx *pgSession) PersistDetectors(detectors []database.Detector) error { + for _, d := range detectors { + if !d.Valid() { + log.WithField("detector", d).Debug("Invalid Detector") + return database.ErrInvalidParameters + } + + r, err := tx.Exec(soiDetector, d.Name, d.Version, d.DType) + if err != nil { + return handleError("soiDetector", err) + } + + count, err := r.RowsAffected() + if err != nil { + return handleError("soiDetector", err) + } + + if count == 0 { + log.Debug("detector already exists: ", d) + } + } + + return nil +} + +func (tx *pgSession) persistLayerDetector(layerID int64, detectorID int64) error { + if _, err := tx.Exec(persistLayerDetector, layerID, detectorID); err != nil { + return handleError("persistLayerDetector", err) + } + + return nil +} + +func (tx *pgSession) persistLayerDetectors(layerID int64, detectorIDs []int64) error { + alreadySaved := mapset.NewSet() + for _, id := range detectorIDs { + if alreadySaved.Contains(id) { + continue + } + + alreadySaved.Add(id) + if err := tx.persistLayerDetector(layerID, id); err != nil { + return err + } + } + + return nil +} + +func (tx *pgSession) insertAncestryDetectors(ancestryID int64, detectorIDs []int64) error { + for _, detectorID := range detectorIDs { + if _, err := tx.Exec(insertAncestryDetectors, ancestryID, detectorID); err != nil { + return handleError("insertAncestryDetectors", err) + } + } + + return nil +} + +func (tx *pgSession) findAncestryDetectors(id int64) ([]database.Detector, error) { + detectors, err := tx.getDetectors(selectAncestryDetectors, id) + log.WithField("detectors", detectors).Debug("found ancestry detectors") + return detectors, err +} + +func (tx *pgSession) findLayerDetectors(id int64) ([]database.Detector, error) { + detectors, err := tx.getDetectors(selectLayerDetectors, id) + log.WithField("detectors", detectors).Debug("found layer detectors") + return detectors, err +} + +// findDetectorIDs retrieve ids of the detectors from the database, if any is not +// found, return the error. +func (tx *pgSession) findDetectorIDs(detectors []database.Detector) ([]int64, error) { + ids := []int64{} + for _, d := range detectors { + id := sql.NullInt64{} + err := tx.QueryRow(findDetectorID, d.Name, d.Version, d.DType).Scan(&id) + if err != nil { + return nil, handleError("findDetectorID", err) + } + + if !id.Valid { + return nil, database.ErrInconsistent + } + + ids = append(ids, id.Int64) + } + + return ids, nil +} + +func (tx *pgSession) getDetectors(query string, id int64) ([]database.Detector, error) { + rows, err := tx.Query(query, id) + if err != nil { + return nil, handleError("getDetectors", err) + } + + detectors := []database.Detector{} + for rows.Next() { + d := database.Detector{} + err := rows.Scan(&d.Name, &d.Version, &d.DType) + if err != nil { + return nil, handleError("getDetectors", err) + } + + if !d.Valid() { + return nil, database.ErrInvalidDetector + } + + detectors = append(detectors, d) + } + + return detectors, nil +} + +func (tx *pgSession) findAllDetectors() (detectorMap, error) { + rows, err := tx.Query(findAllDetectors) + if err != nil { + return detectorMap{}, handleError("searchAllDetectors", err) + } + + detectors := detectorMap{byID: make(map[int64]database.Detector), byValue: make(map[database.Detector]int64)} + for rows.Next() { + var ( + id int64 + d database.Detector + ) + if err := rows.Scan(&id, &d.Name, &d.Version, &d.DType); err != nil { + return detectorMap{}, handleError("searchAllDetectors", err) + } + + detectors.byID[id] = d + detectors.byValue[d] = id + } + + return detectors, nil +} diff --git a/database/pgsql/detector_test.go b/database/pgsql/detector_test.go new file mode 100644 index 00000000..582da60b --- /dev/null +++ b/database/pgsql/detector_test.go @@ -0,0 +1,119 @@ +// Copyright 2018 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pgsql + +import ( + "testing" + + "github.com/deckarep/golang-set" + "github.com/stretchr/testify/require" + + "github.com/coreos/clair/database" +) + +func testGetAllDetectors(tx *pgSession) []database.Detector { + query := `SELECT name, version, dtype FROM detector` + rows, err := tx.Query(query) + if err != nil { + panic(err) + } + + detectors := []database.Detector{} + for rows.Next() { + d := database.Detector{} + if err := rows.Scan(&d.Name, &d.Version, &d.DType); err != nil { + panic(err) + } + + detectors = append(detectors, d) + } + + return detectors +} + +var persistDetectorTests = []struct { + title string + in []database.Detector + err string +}{ + { + title: "invalid detector", + in: []database.Detector{ + {}, + database.NewFeatureDetector("name", "2.0"), + }, + err: database.ErrInvalidParameters.Error(), + }, + { + title: "invalid detector 2", + in: []database.Detector{ + database.NewFeatureDetector("name", "2.0"), + {"name", "1.0", "random not valid dtype"}, + }, + err: database.ErrInvalidParameters.Error(), + }, + { + title: "detectors with some different fields", + in: []database.Detector{ + database.NewFeatureDetector("name", "2.0"), + database.NewFeatureDetector("name", "1.0"), + database.NewNamespaceDetector("name", "1.0"), + }, + }, + { + title: "duplicated detectors (parameter level)", + in: []database.Detector{ + database.NewFeatureDetector("name", "1.0"), + database.NewFeatureDetector("name", "1.0"), + }, + }, + { + title: "duplicated detectors (db level)", + in: []database.Detector{ + database.NewNamespaceDetector("os-release", "1.0"), + database.NewNamespaceDetector("os-release", "1.0"), + database.NewFeatureDetector("dpkg", "1.0"), + }, + }, +} + +func TestPersistDetector(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistDetector", true) + defer closeTest(t, datastore, tx) + + for _, test := range persistDetectorTests { + t.Run(test.title, func(t *testing.T) { + err := tx.PersistDetectors(test.in) + if test.err != "" { + require.EqualError(t, err, test.err) + return + } + + detectors := testGetAllDetectors(tx) + + // ensure no duplicated detectors + detectorSet := mapset.NewSet() + for _, d := range detectors { + require.False(t, detectorSet.Contains(d), "duplicated: %v", d) + detectorSet.Add(d) + } + + // ensure all persisted detectors are actually saved + for _, d := range test.in { + require.True(t, detectorSet.Contains(d), "detector: %v, detectors: %v", d, detectorSet) + } + }) + } +} diff --git a/database/pgsql/feature.go b/database/pgsql/feature.go index e1c0781c..345b73a3 100644 --- a/database/pgsql/feature.go +++ b/database/pgsql/feature.go @@ -16,7 +16,6 @@ package pgsql import ( "database/sql" - "errors" "sort" "github.com/lib/pq" @@ -28,7 +27,6 @@ import ( ) const ( - // feature.go soiNamespacedFeature = ` WITH new_feature_ns AS ( INSERT INTO namespaced_feature(feature_id, namespace_id) @@ -65,15 +63,6 @@ const ( AND v.deleted_at IS NULL` ) -var ( - errFeatureNotFound = errors.New("Feature not found") -) - -type vulnerabilityAffecting struct { - vulnerabilityID int64 - addedByID int64 -} - func (tx *pgSession) PersistFeatures(features []database.Feature) error { if len(features) == 0 { return nil @@ -126,7 +115,7 @@ func (tx *pgSession) searchAffectingVulnerabilities(features []database.Namespac fMap := map[int64]database.NamespacedFeature{} for i, f := range features { if !ids[i].Valid { - return nil, errFeatureNotFound + return nil, database.ErrMissingEntities } fMap[ids[i].Int64] = f } @@ -218,7 +207,7 @@ func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFea if ids, err := tx.findFeatureIDs(fToFind); err == nil { for i, id := range ids { if !id.Valid { - return errFeatureNotFound + return database.ErrMissingEntities } fIDs[fToFind[i]] = id } @@ -234,7 +223,7 @@ func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFea if ids, err := tx.findNamespaceIDs(nsToFind); err == nil { for i, id := range ids { if !id.Valid { - return errNamespaceNotFound + return database.ErrMissingEntities } nsIDs[nsToFind[i]] = id } diff --git a/database/pgsql/feature_test.go b/database/pgsql/feature_test.go index 934b8cc1..2823e1e8 100644 --- a/database/pgsql/feature_test.go +++ b/database/pgsql/feature_test.go @@ -52,7 +52,7 @@ func TestPersistNamespacedFeatures(t *testing.T) { // existing features f1 := database.Feature{ - Name: "wechat", + Name: "ourchat", Version: "0.5", VersionFormat: "dpkg", } @@ -213,27 +213,6 @@ func listFeatures(t *testing.T, tx *pgSession) []database.Feature { return fs } -func assertFeaturesEqual(t *testing.T, expected []database.Feature, actual []database.Feature) bool { - if assert.Len(t, actual, len(expected)) { - has := map[database.Feature]bool{} - for _, nf := range expected { - has[nf] = false - } - - for _, nf := range actual { - has[nf] = true - } - - for nf, visited := range has { - if !assert.True(t, visited, nf.Name+" is expected") { - return false - } - return true - } - } - return false -} - func assertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool { if assert.Len(t, actual, len(expected)) { has := map[database.NamespacedFeature]bool{} diff --git a/database/pgsql/layer.go b/database/pgsql/layer.go index e474164e..ebea1849 100644 --- a/database/pgsql/layer.go +++ b/database/pgsql/layer.go @@ -18,6 +18,8 @@ import ( "database/sql" "sort" + "github.com/deckarep/golang-set" + "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" ) @@ -34,300 +36,331 @@ const ( UNION SELECT id FROM layer WHERE hash = $1` - searchLayerFeatures = ` - SELECT feature_id, detector_id - FROM layer_feature - WHERE layer_id = $1` + findLayerFeatures = ` + SELECT f.name, f.version, f.version_format, lf.detector_id + FROM layer_feature AS lf, feature AS f + WHERE lf.feature_id = f.id + AND lf.layer_id = $1` - searchLayerNamespaces = ` - SELECT namespace.Name, namespace.version_format - FROM namespace, layer_namespace - WHERE layer_namespace.layer_id = $1 - AND layer_namespace.namespace_id = namespace.id` + findLayerNamespaces = ` + SELECT ns.name, ns.version_format, ln.detector_id + FROM layer_namespace AS ln, namespace AS ns + WHERE ln.namespace_id = ns.id + AND ln.layer_id = $1` - searchLayer = `SELECT id FROM layer WHERE hash = $1` + findLayerID = `SELECT id FROM layer WHERE hash = $1` ) + +// dbLayerNamespace represents the layer_namespace table. +type dbLayerNamespace struct { + layerID int64 + namespaceID int64 + detectorID int64 +} + +// dbLayerFeature represents the layer_feature table +type dbLayerFeature struct { + layerID int64 + featureID int64 + detectorID int64 +} + +func (tx *pgSession) FindLayer(hash string) (database.Layer, bool, error) { + layer := database.Layer{Hash: hash} + if hash == "" { + return layer, false, commonerr.NewBadRequestError("non empty layer hash is expected.") + } + + layerID, ok, err := tx.findLayerID(hash) + if err != nil || !ok { + return layer, ok, err + } + + detectorMap, err := tx.findAllDetectors() if err != nil { return layer, false, err } - if !ok { - return layer, false, nil + if layer.By, err = tx.findLayerDetectors(layerID); err != nil { + return layer, false, err + } + + if layer.Features, err = tx.findLayerFeatures(layerID, detectorMap); err != nil { + return layer, false, err + } + + if layer.Namespaces, err = tx.findLayerNamespaces(layerID, detectorMap); err != nil { + return layer, false, err } - layer.Features, err = tx.findLayerFeatures(layerID) - layer.Namespaces, err = tx.findLayerNamespaces(layerID) return layer, true, nil } -func (tx *pgSession) persistLayer(hash string) (int64, error) { +func sanitizePersistLayerInput(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error { if hash == "" { - return -1, commonerr.NewBadRequestError("Empty Layer Hash is not allowed") + return commonerr.NewBadRequestError("expected non-empty layer hash") } - id := sql.NullInt64{} - if err := tx.QueryRow(soiLayer, hash).Scan(&id); err != nil { - return -1, handleError("queryPersistLayer", err) + detectedBySet := mapset.NewSet() + for _, d := range detectedBy { + detectedBySet.Add(d) } - if !id.Valid { - panic("null layer.id violates database constraint") + for _, f := range features { + if !detectedBySet.Contains(f.By) { + return database.ErrInvalidParameters + } } - return id.Int64, nil -} - -// PersistLayer relates layer identified by hash with namespaces, -// features and processors provided. If the layer, namespaces, features are not -// in database, the function returns an error. -func (tx *pgSession) PersistLayer(hash string, namespaces []database.Namespace, features []database.Feature, processedBy database.Processors) error { - if hash == "" { - return commonerr.NewBadRequestError("Empty layer hash is not allowed") + for _, n := range namespaces { + if !detectedBySet.Contains(n.By) { + return database.ErrInvalidParameters + } } + return nil +} + +// PersistLayer saves the content of a layer to the database. +func (tx *pgSession) PersistLayer(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error { var ( - err error - id int64 + err error + id int64 + detectorIDs []int64 ) - if id, err = tx.persistLayer(hash); err != nil { + if err = sanitizePersistLayerInput(hash, features, namespaces, detectedBy); err != nil { return err } - if err = tx.persistLayerNamespace(id, namespaces); err != nil { + if id, err = tx.soiLayer(hash); err != nil { return err } - if err = tx.persistLayerFeatures(id, features); err != nil { + if detectorIDs, err = tx.findDetectorIDs(detectedBy); err != nil { + if err == commonerr.ErrNotFound { + return database.ErrMissingEntities + } + return err } - if err = tx.persistLayerDetectors(id, processedBy.Detectors); err != nil { + if err = tx.persistLayerDetectors(id, detectorIDs); err != nil { return err } - if err = tx.persistLayerListers(id, processedBy.Listers); err != nil { + if err = tx.persistAllLayerFeatures(id, features); err != nil { + return err + } + + if err = tx.persistAllLayerNamespaces(id, namespaces); err != nil { return err } return nil } -func (tx *pgSession) persistLayerDetectors(id int64, detectors []string) error { - if len(detectors) == 0 { - return nil +func (tx *pgSession) persistAllLayerNamespaces(layerID int64, namespaces []database.LayerNamespace) error { + detectorMap, err := tx.findAllDetectors() + if err != nil { + return err } - // Sorting is needed before inserting into database to prevent deadlock. - sort.Strings(detectors) - keys := make([]interface{}, len(detectors)*2) - for i, d := range detectors { - keys[i*2] = id - keys[i*2+1] = d + // TODO(sidac): This kind of type conversion is very useless and wasteful, + // we need interfaces around the database models to reduce these kind of + // operations. + rawNamespaces := make([]database.Namespace, 0, len(namespaces)) + for _, ns := range namespaces { + rawNamespaces = append(rawNamespaces, ns.Namespace) } - _, err := tx.Exec(queryPersistLayerDetectors(len(detectors)), keys...) + + rawNamespaceIDs, err := tx.findNamespaceIDs(rawNamespaces) if err != nil { - return handleError("queryPersistLayerDetectors", err) + return err } - return nil -} -func (tx *pgSession) persistLayerListers(id int64, listers []string) error { - if len(listers) == 0 { - return nil - } + dbLayerNamespaces := make([]dbLayerNamespace, 0, len(namespaces)) + for i, ns := range namespaces { + detectorID := detectorMap.byValue[ns.By] + namespaceID := rawNamespaceIDs[i].Int64 + if !rawNamespaceIDs[i].Valid { + return database.ErrMissingEntities + } - sort.Strings(listers) - keys := make([]interface{}, len(listers)*2) - for i, d := range listers { - keys[i*2] = id - keys[i*2+1] = d + dbLayerNamespaces = append(dbLayerNamespaces, dbLayerNamespace{layerID, namespaceID, detectorID}) } - _, err := tx.Exec(queryPersistLayerListers(len(listers)), keys...) + return tx.persistLayerNamespaces(dbLayerNamespaces) +} + +func (tx *pgSession) persistAllLayerFeatures(layerID int64, features []database.LayerFeature) error { + detectorMap, err := tx.findAllDetectors() if err != nil { - return handleError("queryPersistLayerDetectors", err) + return err } - return nil -} -func (tx *pgSession) persistLayerFeatures(id int64, features []database.Feature) error { - if len(features) == 0 { - return nil + rawFeatures := make([]database.Feature, 0, len(features)) + for _, f := range features { + rawFeatures = append(rawFeatures, f.Feature) } - fIDs, err := tx.findFeatureIDs(features) + featureIDs, err := tx.findFeatureIDs(rawFeatures) if err != nil { return err } - ids := make([]int, len(fIDs)) - for i, fID := range fIDs { - if !fID.Valid { - return errNamespaceNotFound + dbFeatures := make([]dbLayerFeature, 0, len(features)) + for i, f := range features { + detectorID := detectorMap.byValue[f.By] + featureID := featureIDs[i].Int64 + if !featureIDs[i].Valid { + return database.ErrMissingEntities } - ids[i] = int(fID.Int64) - } - sort.IntSlice(ids).Sort() - keys := make([]interface{}, len(features)*2) - for i, fID := range ids { - keys[i*2] = id - keys[i*2+1] = fID + dbFeatures = append(dbFeatures, dbLayerFeature{layerID, featureID, detectorID}) } - _, err = tx.Exec(queryPersistLayerFeature(len(features)), keys...) - if err != nil { - return handleError("queryPersistLayerFeature", err) + if err := tx.persistLayerFeatures(dbFeatures); err != nil { + return err } + return nil } -func (tx *pgSession) persistLayerNamespace(id int64, namespaces []database.Namespace) error { - if len(namespaces) == 0 { +func (tx *pgSession) persistLayerFeatures(features []dbLayerFeature) error { + if len(features) == 0 { return nil } - nsIDs, err := tx.findNamespaceIDs(namespaces) - if err != nil { - return err - } - - // for every bulk persist operation, the input data should be sorted. - ids := make([]int, len(nsIDs)) - for i, nsID := range nsIDs { - if !nsID.Valid { - panic(errNamespaceNotFound) - } - ids[i] = int(nsID.Int64) - } - - sort.IntSlice(ids).Sort() + sort.Slice(features, func(i, j int) bool { + return features[i].featureID < features[j].featureID + }) - keys := make([]interface{}, len(namespaces)*2) - for i, nsID := range ids { - keys[i*2] = id - keys[i*2+1] = nsID + keys := make([]interface{}, len(features)*3) + for i, feature := range features { + keys[i*3] = feature.layerID + keys[i*3+1] = feature.featureID + keys[i*3+2] = feature.detectorID } - _, err = tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...) + _, err := tx.Exec(queryPersistLayerFeature(len(features)), keys...) if err != nil { - return handleError("queryPersistLayerNamespace", err) + return handleError("queryPersistLayerFeature", err) } return nil } -func (tx *pgSession) persistProcessors(listerQuery, listerQueryName, detectorQuery, detectorQueryName string, id int64, processors database.Processors) error { - stmt, err := tx.Prepare(listerQuery) - if err != nil { - return handleError(listerQueryName, err) +func (tx *pgSession) persistLayerNamespaces(namespaces []dbLayerNamespace) error { + if len(namespaces) == 0 { + return nil } - for _, l := range processors.Listers { - _, err := stmt.Exec(id, l) - if err != nil { - stmt.Close() - return handleError(listerQueryName, err) - } - } + // for every bulk persist operation, the input data should be sorted. + sort.Slice(namespaces, func(i, j int) bool { + return namespaces[i].namespaceID < namespaces[j].namespaceID + }) - if err := stmt.Close(); err != nil { - return handleError(listerQueryName, err) + elementSize := 3 + keys := make([]interface{}, len(namespaces)*elementSize) + for i, row := range namespaces { + keys[i*3] = row.layerID + keys[i*3+1] = row.namespaceID + keys[i*3+2] = row.detectorID } - stmt, err = tx.Prepare(detectorQuery) + _, err := tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...) if err != nil { - return handleError(detectorQueryName, err) - } - - for _, d := range processors.Detectors { - _, err := stmt.Exec(id, d) - if err != nil { - stmt.Close() - return handleError(detectorQueryName, err) - } - } - - if err := stmt.Close(); err != nil { - return handleError(detectorQueryName, err) + return handleError("queryPersistLayerNamespace", err) } return nil } -func (tx *pgSession) findLayerNamespaces(layerID int64) ([]database.Namespace, error) { - var namespaces []database.Namespace - - rows, err := tx.Query(searchLayerNamespaces, layerID) +func (tx *pgSession) findLayerNamespaces(layerID int64, detectors detectorMap) ([]database.LayerNamespace, error) { + rows, err := tx.Query(findLayerNamespaces, layerID) if err != nil { - return nil, handleError("searchLayerFeatures", err) + return nil, handleError("findLayerNamespaces", err) } + namespaces := []database.LayerNamespace{} for rows.Next() { - ns := database.Namespace{} - err := rows.Scan(&ns.Name, &ns.VersionFormat) - if err != nil { + var ( + namespace database.LayerNamespace + detectorID int64 + ) + + if err := rows.Scan(&namespace.Name, &namespace.VersionFormat, &detectorID); err != nil { return nil, err } - namespaces = append(namespaces, ns) + + namespace.By = detectors.byID[detectorID] + namespaces = append(namespaces, namespace) } + return namespaces, nil } -func (tx *pgSession) findLayerFeatures(layerID int64) ([]database.Feature, error) { - var features []database.Feature - - rows, err := tx.Query(searchLayerFeatures, layerID) +func (tx *pgSession) findLayerFeatures(layerID int64, detectors detectorMap) ([]database.LayerFeature, error) { + rows, err := tx.Query(findLayerFeatures, layerID) if err != nil { - return nil, handleError("searchLayerFeatures", err) + return nil, handleError("findLayerFeatures", err) } + defer rows.Close() + features := []database.LayerFeature{} for rows.Next() { - f := database.Feature{} - err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat) - if err != nil { - return nil, err + var ( + detectorID int64 + feature database.LayerFeature + ) + if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &detectorID); err != nil { + return nil, handleError("findLayerFeatures", err) } - features = append(features, f) + + feature.By = detectors.byID[detectorID] + features = append(features, feature) } + return features, nil } -func (tx *pgSession) findLayer(hash string) (database.LayerMetadata, int64, bool, error) { - var ( - layerID int64 - layer = database.LayerMetadata{Hash: hash, ProcessedBy: database.Processors{}} - ) - - if hash == "" { - return layer, layerID, false, commonerr.NewBadRequestError("Empty Layer Hash is not allowed") - } - - err := tx.QueryRow(searchLayer, hash).Scan(&layerID) +func (tx *pgSession) findLayerID(hash string) (int64, bool, error) { + var layerID int64 + err := tx.QueryRow(findLayerID, hash).Scan(&layerID) if err != nil { if err == sql.ErrNoRows { - return layer, layerID, false, nil + return layerID, false, nil } - return layer, layerID, false, err + + return layerID, false, handleError("findLayerID", err) } - layer.ProcessedBy, err = tx.findLayerProcessors(layerID) - return layer, layerID, true, err + return layerID, true, nil } -func (tx *pgSession) findLayerProcessors(id int64) (database.Processors, error) { - var ( - err error - processors database.Processors - ) +func (tx *pgSession) findLayerIDs(hashes []string) ([]int64, bool, error) { + layerIDs := make([]int64, 0, len(hashes)) + for _, hash := range hashes { + id, ok, err := tx.findLayerID(hash) + if !ok { + return nil, false, nil + } - if processors.Detectors, err = tx.findProcessors(searchLayerDetectors, id); err != nil { - return processors, handleError("searchLayerDetectors", err) + if err != nil { + return nil, false, err + } + + layerIDs = append(layerIDs, id) } - if processors.Listers, err = tx.findProcessors(searchLayerListers, id); err != nil { - return processors, handleError("searchLayerListers", err) + return layerIDs, true, nil +} + +func (tx *pgSession) soiLayer(hash string) (int64, error) { + var id int64 + if err := tx.QueryRow(soiLayer, hash).Scan(&id); err != nil { + return 0, handleError("soiLayer", err) } - return processors, nil + return id, nil } diff --git a/database/pgsql/layer_test.go b/database/pgsql/layer_test.go index 6fe8bed3..fc22a83c 100644 --- a/database/pgsql/layer_test.go +++ b/database/pgsql/layer_test.go @@ -20,107 +20,172 @@ import ( "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" + "github.com/coreos/clair/pkg/testutil" ) -func TestPersistLayer(t *testing.T) { - datastore, tx := openSessionForTest(t, "PersistLayer", false) - defer closeTest(t, datastore, tx) - - // invalid - assert.NotNil(t, tx.PersistLayer("", nil, nil, database.Processors{})) - // insert namespaces + features to - namespaces := []database.Namespace{ - { - Name: "sushi shop", - VersionFormat: "apk", +var persistLayerTests = []struct { + title string + name string + by []database.Detector + features []database.LayerFeature + namespaces []database.LayerNamespace + layer *database.Layer + err string +}{ + { + title: "invalid layer name", + name: "", + err: "expected non-empty layer hash", + }, + { + title: "layer with inconsistent feature and detectors", + name: "random-forest", + by: []database.Detector{realDetectors[2]}, + features: []database.LayerFeature{ + {realFeatures[1], realDetectors[1]}, }, - } - - features := []database.Feature{ - { - Name: "blue fin sashimi", - Version: "v1.0", - VersionFormat: "apk", + err: "database: parameters are not valid", + }, + { + title: "layer with non-existing feature", + name: "random-forest", + err: "database: associated immutable entities are missing in the database", + by: []database.Detector{realDetectors[2]}, + features: []database.LayerFeature{ + {fakeFeatures[1], realDetectors[2]}, }, - } - - processors := database.Processors{ - Listers: []string{"release"}, - Detectors: []string{"apk"}, - } - - assert.Nil(t, tx.PersistNamespaces(namespaces)) - assert.Nil(t, tx.PersistFeatures(features)) - - // Valid - assert.Nil(t, tx.PersistLayer("RANDOM_FOREST", namespaces, features, processors)) - - nonExistingFeature := []database.Feature{{Name: "lobster sushi", Version: "v0.1", VersionFormat: "apk"}} - // Invalid: - assert.NotNil(t, tx.PersistLayer("RANDOM_FOREST", namespaces, nonExistingFeature, processors)) - - assert.Nil(t, tx.PersistFeatures(nonExistingFeature)) - // Update the layer - assert.Nil(t, tx.PersistLayer("RANDOM_FOREST", namespaces, nonExistingFeature, processors)) - - // confirm update - layer, ok, err := tx.FindLayer("RANDOM_FOREST") - assert.Nil(t, err) - assert.True(t, ok) - - expectedLayer := database.Layer{ - LayerMetadata: database.LayerMetadata{ - Hash: "RANDOM_FOREST", - ProcessedBy: processors, + }, + { + title: "layer with non-existing namespace", + name: "random-forest2", + err: "database: associated immutable entities are missing in the database", + by: []database.Detector{realDetectors[1]}, + namespaces: []database.LayerNamespace{ + {fakeNamespaces[1], realDetectors[1]}, }, - Features: append(features, nonExistingFeature...), - Namespaces: namespaces, - } - - assertLayerWithContentEqual(t, expectedLayer, layer) -} - -func TestFindLayer(t *testing.T) { - datastore, tx := openSessionForTest(t, "FindLayer", true) - defer closeTest(t, datastore, tx) - - _, _, err := tx.FindLayer("") - assert.NotNil(t, err) - _, ok, err := tx.FindLayer("layer-non") - assert.Nil(t, err) - assert.False(t, ok) - - expectedL := database.Layer{ - LayerMetadata: database.LayerMetadata{ - Hash: "layer-4", - ProcessedBy: database.Processors{ - Detectors: []string{"os-release", "apt-sources"}, - Listers: []string{"dpkg", "rpm"}, + }, + { + title: "layer with non-existing detector", + name: "random-forest3", + err: "database: associated immutable entities are missing in the database", + by: []database.Detector{fakeDetector[1]}, + }, + { + title: "valid layer", + name: "hamsterhouse", + by: []database.Detector{realDetectors[1], realDetectors[2]}, + features: []database.LayerFeature{ + {realFeatures[1], realDetectors[2]}, + {realFeatures[2], realDetectors[2]}, + }, + namespaces: []database.LayerNamespace{ + {realNamespaces[1], realDetectors[1]}, + }, + layer: &database.Layer{ + Hash: "hamsterhouse", + By: []database.Detector{realDetectors[1], realDetectors[2]}, + Features: []database.LayerFeature{ + {realFeatures[1], realDetectors[2]}, + {realFeatures[2], realDetectors[2]}, + }, + Namespaces: []database.LayerNamespace{ + {realNamespaces[1], realDetectors[1]}, }, }, - Features: []database.Feature{ - {Name: "fake", Version: "2.0", VersionFormat: "rpm"}, - {Name: "openssl", Version: "2.0", VersionFormat: "dpkg"}, + }, + { + title: "update existing layer", + name: "layer-1", + by: []database.Detector{realDetectors[3], realDetectors[4]}, + features: []database.LayerFeature{ + {realFeatures[4], realDetectors[3]}, }, - Namespaces: []database.Namespace{ - {Name: "debian:7", VersionFormat: "dpkg"}, - {Name: "fake:1.0", VersionFormat: "rpm"}, + namespaces: []database.LayerNamespace{ + {realNamespaces[3], realDetectors[4]}, }, - } + layer: &database.Layer{ + Hash: "layer-1", + By: []database.Detector{realDetectors[1], realDetectors[2], realDetectors[3], realDetectors[4]}, + Features: []database.LayerFeature{ + {realFeatures[1], realDetectors[2]}, + {realFeatures[2], realDetectors[2]}, + {realFeatures[4], realDetectors[3]}, + }, + Namespaces: []database.LayerNamespace{ + {realNamespaces[1], realDetectors[1]}, + {realNamespaces[3], realDetectors[4]}, + }, + }, + }, +} - layer, ok2, err := tx.FindLayer("layer-4") - if assert.Nil(t, err) && assert.True(t, ok2) { - assertLayerWithContentEqual(t, expectedL, layer) +func TestPersistLayer(t *testing.T) { + datastore, tx := openSessionForTest(t, "PersistLayer", true) + defer closeTest(t, datastore, tx) + + for _, test := range persistLayerTests { + t.Run(test.title, func(t *testing.T) { + err := tx.PersistLayer(test.name, test.features, test.namespaces, test.by) + if test.err != "" { + assert.EqualError(t, err, test.err, "unexpected error") + return + } + + assert.Nil(t, err) + if test.layer != nil { + layer, ok, err := tx.FindLayer(test.name) + assert.Nil(t, err) + assert.True(t, ok) + testutil.AssertLayerEqual(t, test.layer, &layer) + } + }) } } -func assertLayerWithContentEqual(t *testing.T, expected database.Layer, actual database.Layer) bool { - return assertLayerEqual(t, expected.LayerMetadata, actual.LayerMetadata) && - assertFeaturesEqual(t, expected.Features, actual.Features) && - assertNamespacesEqual(t, expected.Namespaces, actual.Namespaces) +var findLayerTests = []struct { + title string + in string + + out *database.Layer + err string + ok bool +}{ + { + title: "invalid layer name", + in: "", + err: "non empty layer hash is expected.", + }, + { + title: "non-existing layer", + in: "layer-non-existing", + ok: false, + out: nil, + }, + { + title: "existing layer", + in: "layer-4", + ok: true, + out: takeLayerPointerFromMap(realLayers, 6), + }, } -func assertLayerEqual(t *testing.T, expected database.LayerMetadata, actual database.LayerMetadata) bool { - return assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy) && - assert.Equal(t, expected.Hash, actual.Hash) +func TestFindLayer(t *testing.T) { + datastore, tx := openSessionForTest(t, "FindLayer", true) + defer closeTest(t, datastore, tx) + + for _, test := range findLayerTests { + t.Run(test.title, func(t *testing.T) { + layer, ok, err := tx.FindLayer(test.in) + if test.err != "" { + assert.EqualError(t, err, test.err, "unexpected error") + return + } + + assert.Nil(t, err) + assert.Equal(t, test.ok, ok) + if test.ok { + testutil.AssertLayerEqual(t, test.out, &layer) + } + }) + } } diff --git a/database/pgsql/migrations/00001_initial_schema.go b/database/pgsql/migrations/00001_initial_schema.go index 69bd4081..c073e286 100644 --- a/database/pgsql/migrations/00001_initial_schema.go +++ b/database/pgsql/migrations/00001_initial_schema.go @@ -38,8 +38,8 @@ var ( `CREATE TABLE IF NOT EXISTS namespaced_feature ( id SERIAL PRIMARY KEY, - namespace_id INT REFERENCES namespace, - feature_id INT REFERENCES feature, + namespace_id INT REFERENCES namespace ON DELETE CASCADE, + feature_id INT REFERENCES feature ON DELETE CASCADE, UNIQUE (namespace_id, feature_id));`, }, Down: []string{ @@ -116,7 +116,7 @@ var ( id SERIAL PRIMARY KEY, ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, ancestry_index INT NOT NULL, - layer_id INT REFERENCES layer ON DELETE RESTRICT, + layer_id INT NOT NULL REFERENCES layer ON DELETE RESTRICT, UNIQUE (ancestry_id, ancestry_index));`, `CREATE INDEX ON ancestry_layer(ancestry_id);`, @@ -130,7 +130,7 @@ var ( `CREATE TABLE IF NOT EXISTS ancestry_detector( id SERIAL PRIMARY KEY, - ancestry_id INT REFERENCES layer ON DELETE CASCADE, + ancestry_id INT REFERENCES ancestry ON DELETE CASCADE, detector_id INT REFERENCES detector ON DELETE CASCADE, UNIQUE(ancestry_id, detector_id));`, `CREATE INDEX ON ancestry_detector(ancestry_id);`, diff --git a/database/pgsql/namespace.go b/database/pgsql/namespace.go index bd0dae34..87d25e33 100644 --- a/database/pgsql/namespace.go +++ b/database/pgsql/namespace.go @@ -16,7 +16,6 @@ package pgsql import ( "database/sql" - "errors" "sort" "github.com/coreos/clair/database" @@ -27,10 +26,6 @@ const ( searchNamespaceID = `SELECT id FROM Namespace WHERE name = $1 AND version_format = $2` ) -var ( - errNamespaceNotFound = errors.New("Requested Namespace is not in database") -) - // PersistNamespaces soi namespaces into database. func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error { if len(namespaces) == 0 { diff --git a/database/pgsql/namespace_test.go b/database/pgsql/namespace_test.go index 27ceefef..8f2af288 100644 --- a/database/pgsql/namespace_test.go +++ b/database/pgsql/namespace_test.go @@ -42,42 +42,3 @@ func TestPersistNamespaces(t *testing.T) { assert.Len(t, nsList, 1) assert.Equal(t, ns2, nsList[0]) } - -func assertNamespacesEqual(t *testing.T, expected []database.Namespace, actual []database.Namespace) bool { - if assert.Len(t, actual, len(expected)) { - has := map[database.Namespace]bool{} - for _, i := range expected { - has[i] = false - } - for _, i := range actual { - has[i] = true - } - for key, v := range has { - if !assert.True(t, v, key.Name+"is expected") { - return false - } - } - return true - } - return false -} - -func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace { - rows, err := tx.Query("SELECT name, version_format FROM namespace") - if err != nil { - t.FailNow() - } - defer rows.Close() - - namespaces := []database.Namespace{} - for rows.Next() { - var ns database.Namespace - err := rows.Scan(&ns.Name, &ns.VersionFormat) - if err != nil { - t.FailNow() - } - namespaces = append(namespaces, ns) - } - - return namespaces -} diff --git a/database/pgsql/notification.go b/database/pgsql/notification.go index 44eff64b..7d2b750d 100644 --- a/database/pgsql/notification.go +++ b/database/pgsql/notification.go @@ -27,7 +27,6 @@ import ( ) const ( - // notification.go insertNotification = ` INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id) VALUES ($1, $2, $3, $4)` @@ -60,9 +59,10 @@ const ( SELECT DISTINCT ON (a.id) a.id, a.name FROM vulnerability_affected_namespaced_feature AS vanf, - ancestry_layer AS al, ancestry_feature AS af + ancestry_layer AS al, ancestry_feature AS af, ancestry AS a WHERE vanf.vulnerability_id = $1 - AND al.ancestry_id >= $2 + AND a.id >= $2 + AND al.ancestry_id = a.id AND al.id = af.ancestry_layer_id AND af.namespaced_feature_id = vanf.namespaced_feature_id ORDER BY a.id ASC @@ -211,14 +211,12 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr vulnPage := database.PagedVulnerableAncestries{Limit: limit} currentPage := Page{0} if currentToken != pagination.FirstPageToken { - var err error - err = tx.key.UnmarshalToken(currentToken, ¤tPage) - if err != nil { + if err := tx.key.UnmarshalToken(currentToken, ¤tPage); err != nil { return vulnPage, err } } - err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan( + if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan( &vulnPage.Name, &vulnPage.Description, &vulnPage.Link, @@ -226,8 +224,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr &vulnPage.Metadata, &vulnPage.Namespace.Name, &vulnPage.Namespace.VersionFormat, - ) - if err != nil { + ); err != nil { return vulnPage, handleError("searchVulnerabilityByID", err) } @@ -290,7 +287,6 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa } noti.Name = name - err := tx.QueryRow(searchNotification, name).Scan(&created, ¬ified, &deleted, &oldVulnID, &newVulnID) diff --git a/database/pgsql/notification_test.go b/database/pgsql/notification_test.go index 9d36f4cb..0a23abca 100644 --- a/database/pgsql/notification_test.go +++ b/database/pgsql/notification_test.go @@ -19,121 +19,144 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/coreos/clair/database" + "github.com/coreos/clair/pkg/pagination" ) -func TestPagination(t *testing.T) { - datastore, tx := openSessionForTest(t, "Pagination", true) - defer closeTest(t, datastore, tx) - - ns := database.Namespace{ - Name: "debian:7", - VersionFormat: "dpkg", - } - - vNew := database.Vulnerability{ - Namespace: ns, - Name: "CVE-OPENSSL-1-DEB7", - Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", - Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", - Severity: database.HighSeverity, - } - - vOld := database.Vulnerability{ - Namespace: ns, - Name: "CVE-NOPE", - Description: "A vulnerability affecting nothing", - Severity: database.UnknownSeverity, - } - - noti, ok, err := tx.FindVulnerabilityNotification("test", 1, "", "") - oldPage := database.PagedVulnerableAncestries{ - Vulnerability: vOld, - Limit: 1, - Affected: make(map[int]string), - End: true, - } - - newPage1 := database.PagedVulnerableAncestries{ - Vulnerability: vNew, - Limit: 1, - Affected: map[int]string{3: "ancestry-3"}, - End: false, - } +type findVulnerabilityNotificationIn struct { + notificationName string + pageSize int + oldAffectedAncestryPage pagination.Token + newAffectedAncestryPage pagination.Token +} - newPage2 := database.PagedVulnerableAncestries{ - Vulnerability: vNew, - Limit: 1, - Affected: map[int]string{4: "ancestry-4"}, - Next: "", - End: true, - } +type findVulnerabilityNotificationOut struct { + notification *database.VulnerabilityNotificationWithVulnerable + ok bool + err string +} - if assert.Nil(t, err) && assert.True(t, ok) { - assert.Equal(t, "test", noti.Name) - if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { - var oldPage Page - err := tx.key.UnmarshalToken(noti.Old.Current, &oldPage) - if !assert.Nil(t, err) { - assert.FailNow(t, "") - } +var findVulnerabilityNotificationTests = []struct { + title string + in findVulnerabilityNotificationIn + out findVulnerabilityNotificationOut +}{ + { + title: "find notification with invalid page", + in: findVulnerabilityNotificationIn{ + notificationName: "test", + pageSize: 1, + oldAffectedAncestryPage: pagination.FirstPageToken, + newAffectedAncestryPage: pagination.Token("random non sense"), + }, + out: findVulnerabilityNotificationOut{ + err: pagination.ErrInvalidToken.Error(), + }, + }, + { + title: "find non-existing notification", + in: findVulnerabilityNotificationIn{ + notificationName: "non-existing", + pageSize: 1, + oldAffectedAncestryPage: pagination.FirstPageToken, + newAffectedAncestryPage: pagination.FirstPageToken, + }, + out: findVulnerabilityNotificationOut{ + ok: false, + }, + }, + { + title: "find existing notification first page", + in: findVulnerabilityNotificationIn{ + notificationName: "test", + pageSize: 1, + oldAffectedAncestryPage: pagination.FirstPageToken, + newAffectedAncestryPage: pagination.FirstPageToken, + }, + out: findVulnerabilityNotificationOut{ + &database.VulnerabilityNotificationWithVulnerable{ + NotificationHook: realNotification[1].NotificationHook, + Old: &database.PagedVulnerableAncestries{ + Vulnerability: realVulnerability[2], + Limit: 1, + Affected: make(map[int]string), + Current: mustMarshalToken(testPaginationKey, Page{0}), + Next: mustMarshalToken(testPaginationKey, Page{0}), + End: true, + }, + New: &database.PagedVulnerableAncestries{ + Vulnerability: realVulnerability[1], + Limit: 1, + Affected: map[int]string{3: "ancestry-3"}, + Current: mustMarshalToken(testPaginationKey, Page{0}), + Next: mustMarshalToken(testPaginationKey, Page{4}), + End: false, + }, + }, - assert.Equal(t, int64(0), oldPage.StartID) - var newPage Page - err = tx.key.UnmarshalToken(noti.New.Current, &newPage) - if !assert.Nil(t, err) { - assert.FailNow(t, "") - } - var newPageNext Page - err = tx.key.UnmarshalToken(noti.New.Next, &newPageNext) - if !assert.Nil(t, err) { - assert.FailNow(t, "") - } - assert.Equal(t, int64(0), newPage.StartID) - assert.Equal(t, int64(4), newPageNext.StartID) - - noti.Old.Current = "" - noti.New.Current = "" - noti.New.Next = "" - assert.Equal(t, oldPage, *noti.Old) - assert.Equal(t, newPage1, *noti.New) - } - } + true, + "", + }, + }, + + { + title: "find existing notification of second page of new affected ancestry", + in: findVulnerabilityNotificationIn{ + notificationName: "test", + pageSize: 1, + oldAffectedAncestryPage: pagination.FirstPageToken, + newAffectedAncestryPage: mustMarshalToken(testPaginationKey, Page{4}), + }, + out: findVulnerabilityNotificationOut{ + &database.VulnerabilityNotificationWithVulnerable{ + NotificationHook: realNotification[1].NotificationHook, + Old: &database.PagedVulnerableAncestries{ + Vulnerability: realVulnerability[2], + Limit: 1, + Affected: make(map[int]string), + Current: mustMarshalToken(testPaginationKey, Page{0}), + Next: mustMarshalToken(testPaginationKey, Page{0}), + End: true, + }, + New: &database.PagedVulnerableAncestries{ + Vulnerability: realVulnerability[1], + Limit: 1, + Affected: map[int]string{4: "ancestry-4"}, + Current: mustMarshalToken(testPaginationKey, Page{4}), + Next: mustMarshalToken(testPaginationKey, Page{0}), + End: true, + }, + }, - pageNum1, err := tx.key.MarshalToken(Page{0}) - if !assert.Nil(t, err) { - assert.FailNow(t, "") - } + true, + "", + }, + }, +} - pageNum2, err := tx.key.MarshalToken(Page{4}) - if !assert.Nil(t, err) { - assert.FailNow(t, "") - } +func TestFindVulnerabilityNotification(t *testing.T) { + datastore, tx := openSessionForTest(t, "pagination", true) + defer closeTest(t, datastore, tx) - noti, ok, err = tx.FindVulnerabilityNotification("test", 1, pageNum1, pageNum2) - if assert.Nil(t, err) && assert.True(t, ok) { - assert.Equal(t, "test", noti.Name) - if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { - var oldCurrentPage Page - err = tx.key.UnmarshalToken(noti.Old.Current, &oldCurrentPage) - if !assert.Nil(t, err) { - assert.FailNow(t, "") + for _, test := range findVulnerabilityNotificationTests { + t.Run(test.title, func(t *testing.T) { + notification, ok, err := tx.FindVulnerabilityNotification(test.in.notificationName, test.in.pageSize, test.in.oldAffectedAncestryPage, test.in.newAffectedAncestryPage) + if test.out.err != "" { + require.EqualError(t, err, test.out.err) + return } - var newCurrentPage Page - err = tx.key.UnmarshalToken(noti.New.Current, &newCurrentPage) - if !assert.Nil(t, err) { - assert.FailNow(t, "") + require.Nil(t, err) + if !test.out.ok { + require.Equal(t, test.out.ok, ok) + return } - assert.Equal(t, int64(0), oldCurrentPage.StartID) - assert.Equal(t, int64(4), newCurrentPage.StartID) - noti.Old.Current = "" - noti.New.Current = "" - assert.Equal(t, oldPage, *noti.Old) - assert.Equal(t, newPage2, *noti.New) - } + require.True(t, ok) + assertVulnerabilityNotificationWithVulnerableEqual(t, testPaginationKey, test.out.notification, ¬ification) + }) } } diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index 9af010b6..4b23e014 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -270,6 +270,7 @@ func migrateDatabase(db *sql.DB) error { // createDatabase creates a new database. // The source parameter should not contain a dbname. func createDatabase(source, dbName string) error { + log.WithFields(log.Fields{"source": source, "dbName": dbName}).Debug("creating database...") // Open database. db, err := sql.Open("postgres", source) if err != nil { @@ -325,7 +326,7 @@ func handleError(desc string, err error) error { return commonerr.ErrNotFound } - log.WithError(err).WithField("Description", desc).Error("Handled Database Error") + log.WithError(err).WithField("Description", desc).Error("database: handled database error") promErrorsTotal.WithLabelValues(desc).Inc() if _, o := err.(*pq.Error); o || err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") { diff --git a/database/pgsql/pgsql_test.go b/database/pgsql/pgsql_test.go index e4a8c8b4..863445a5 100644 --- a/database/pgsql/pgsql_test.go +++ b/database/pgsql/pgsql_test.go @@ -37,6 +37,8 @@ var ( withFixtureName, withoutFixtureName string ) +var testPaginationKey = pagination.Must(pagination.NewKey()) + func genTemplateDatabase(name string, loadFixture bool) (sourceURL string, dbName string) { config := generateTestConfig(name, loadFixture, false) source := config.Options["source"].(string) @@ -215,13 +217,15 @@ func generateTestConfig(testName string, loadFixture bool, manageLife bool) data source = fmt.Sprintf(sourceEnv, dbName) } + log.Infof("pagination key for current test: %s", testPaginationKey.String()) + return database.RegistrableComponentConfig{ Options: map[string]interface{}{ "source": source, "cachesize": 0, "managedatabaselifecycle": manageLife, "fixturepath": fixturePath, - "paginationkey": pagination.Must(pagination.NewKey()).String(), + "paginationkey": testPaginationKey.String(), }, } } @@ -247,6 +251,8 @@ func openSessionForTest(t *testing.T, name string, loadFixture bool) (*pgSQL, *p t.Error(err) t.FailNow() } + + log.Infof("transaction pagination key: '%s'", tx.(*pgSession).key.String()) return store, tx.(*pgSession) } diff --git a/database/pgsql/queries.go b/database/pgsql/queries.go index ad7cfc44..2d4b7e99 100644 --- a/database/pgsql/queries.go +++ b/database/pgsql/queries.go @@ -121,7 +121,8 @@ func queryPersistLayerFeature(count int) string { "layer_feature", "layer_feature_layer_id_feature_id_key", "layer_id", - "feature_id") + "feature_id", + "detector_id") } func queryPersistNamespace(count int) string { @@ -132,28 +133,13 @@ func queryPersistNamespace(count int) string { "version_format") } -func queryPersistLayerListers(count int) string { - return queryPersist(count, - "layer_lister", - "layer_lister_layer_id_lister_key", - "layer_id", - "lister") -} - -func queryPersistLayerDetectors(count int) string { - return queryPersist(count, - "layer_detector", - "layer_detector_layer_id_detector_key", - "layer_id", - "detector") -} - func queryPersistLayerNamespace(count int) string { return queryPersist(count, "layer_namespace", "layer_namespace_layer_id_namespace_id_key", "layer_id", - "namespace_id") + "namespace_id", + "detector_id") } // size of key and array should be both greater than 0 diff --git a/database/pgsql/testdata/data.sql b/database/pgsql/testdata/data.sql index 9d8e0323..e7484209 100644 --- a/database/pgsql/testdata/data.sql +++ b/database/pgsql/testdata/data.sql @@ -1,57 +1,69 @@ +-- initialize entities INSERT INTO namespace (id, name, version_format) VALUES -(1, 'debian:7', 'dpkg'), -(2, 'debian:8', 'dpkg'), -(3, 'fake:1.0', 'rpm'); + (1, 'debian:7', 'dpkg'), + (2, 'debian:8', 'dpkg'), + (3, 'fake:1.0', 'rpm'); INSERT INTO feature (id, name, version, version_format) VALUES -(1, 'wechat', '0.5', 'dpkg'), -(2, 'openssl', '1.0', 'dpkg'), -(3, 'openssl', '2.0', 'dpkg'), -(4, 'fake', '2.0', 'rpm'); + (1, 'ourchat', '0.5', 'dpkg'), + (2, 'openssl', '1.0', 'dpkg'), + (3, 'openssl', '2.0', 'dpkg'), + (4, 'fake', '2.0', 'rpm'); +INSERT INTO namespaced_feature(id, feature_id, namespace_id) VALUES + (1, 1, 1), -- ourchat 0.5, debian:7 + (2, 2, 1), -- openssl 1.0, debian:7 + (3, 2, 2), -- openssl 1.0, debian:8 + (4, 3, 1); -- openssl 2.0, debian:7 + +INSERT INTO detector(id, name, version, dtype) VALUES + (1, 'os-release', '1.0', 'namespace'), + (2, 'dpkg', '1.0', 'feature'), + (3, 'rpm', '1.0', 'feature'), + (4, 'apt-sources', '1.0', 'namespace'); + +-- initialize layers INSERT INTO layer (id, hash) VALUES (1, 'layer-0'), -- blank - (2, 'layer-1'), -- debian:7; wechat 0.5, openssl 1.0 - (3, 'layer-2'), -- debian:7; wechat 0.5, openssl 2.0 + (2, 'layer-1'), -- debian:7; ourchat 0.5, openssl 1.0 + (3, 'layer-2'), -- debian:7; ourchat 0.5, openssl 2.0 (4, 'layer-3a'),-- debian:7; - (5, 'layer-3b'),-- debian:8; wechat 0.5, openssl 1.0 + (5, 'layer-3b'),-- debian:8; ourchat 0.5, openssl 1.0 (6, 'layer-4'); -- debian:7, fake:1.0; openssl 2.0 (debian), fake 2.0 (fake) -INSERT INTO layer_namespace(id, layer_id, namespace_id) VALUES - (1, 2, 1), - (2, 3, 1), - (3, 4, 1), - (4, 5, 2), - (5, 6, 1), - (6, 6, 3); - -INSERT INTO layer_feature(id, layer_id, feature_id) VALUES - (1, 2, 1), - (2, 2, 2), - (3, 3, 1), - (4, 3, 3), - (5, 5, 1), - (6, 5, 2), - (7, 6, 4), - (8, 6, 3); - -INSERT INTO layer_lister(id, layer_id, lister) VALUES - (1, 1, 'dpkg'), - (2, 2, 'dpkg'), - (3, 3, 'dpkg'), - (4, 4, 'dpkg'), - (5, 5, 'dpkg'), - (6, 6, 'dpkg'), - (7, 6, 'rpm'); - -INSERT INTO layer_detector(id, layer_id, detector) VALUES - (1, 1, 'os-release'), - (2, 2, 'os-release'), - (3, 3, 'os-release'), - (4, 4, 'os-release'), - (5, 5, 'os-release'), - (6, 6, 'os-release'), - (7, 6, 'apt-sources'); +INSERT INTO layer_namespace(id, layer_id, namespace_id, detector_id) VALUES + (1, 2, 1, 1), -- layer-1: debian:7 + (2, 3, 1, 1), -- layer-2: debian:7 + (3, 4, 1, 1), -- layer-3a: debian:7 + (4, 5, 2, 1), -- layer-3b: debian:8 + (5, 6, 1, 1), -- layer-4: debian:7 + (6, 6, 3, 4); -- layer-4: fake:1.0 + +INSERT INTO layer_feature(id, layer_id, feature_id, detector_id) VALUES + (1, 2, 1, 2), -- layer-1: ourchat 0.5 + (2, 2, 2, 2), -- layer-1: openssl 1.0 + (3, 3, 1, 2), -- layer-2: ourchat 0.5 + (4, 3, 3, 2), -- layer-2: openssl 2.0 + (5, 5, 1, 2), -- layer-3b: ourchat 0.5 + (6, 5, 2, 2), -- layer-3b: openssl 1.0 + (7, 6, 4, 3), -- layer-4: fake 2.0 + (8, 6, 3, 2); -- layer-4: openssl 2.0 + +INSERT INTO layer_detector(layer_id, detector_id) VALUES + (1, 1), + (2, 1), + (3, 1), + (4, 1), + (5, 1), + (6, 1), + (6, 4), + (1, 2), + (2, 2), + (3, 2), + (4, 2), + (5, 2), + (6, 2), + (6, 3); INSERT INTO ancestry (id, name) VALUES (1, 'ancestry-1'), -- layer-0, layer-1, layer-2, layer-3a @@ -59,32 +71,39 @@ INSERT INTO ancestry (id, name) VALUES (3, 'ancestry-3'), -- layer-0 (4, 'ancestry-4'); -- layer-0 -INSERT INTO ancestry_lister (id, ancestry_id, lister) VALUES - (1, 1, 'dpkg'), - (2, 2, 'dpkg'); - -INSERT INTO ancestry_detector (id, ancestry_id, detector) VALUES - (1, 1, 'os-release'), - (2, 2, 'os-release'); +INSERT INTO ancestry_detector (ancestry_id, detector_id) VALUES + (1, 2), + (2, 2), + (1, 1), + (2, 1); INSERT INTO ancestry_layer (id, ancestry_id, layer_id, ancestry_index) VALUES + -- ancestry-1: layer-0, layer-1, layer-2, layer-3a (1, 1, 1, 0),(2, 1, 2, 1),(3, 1, 3, 2),(4, 1, 4, 3), + -- ancestry-2: layer-0, layer-1, layer-2, layer-3b (5, 2, 1, 0),(6, 2, 2, 1),(7, 2, 3, 2),(8, 2, 5, 3), - (9, 3, 1, 0), - (10, 4, 1, 0); - -INSERT INTO namespaced_feature(id, feature_id, namespace_id) VALUES - (1, 1, 1), -- wechat 0.5, debian:7 - (2, 2, 1), -- openssl 1.0, debian:7 - (3, 2, 2), -- openssl 1.0, debian:8 - (4, 3, 1); -- openssl 2.0, debian:7 + -- ancestry-3: layer-1 + (9, 3, 2, 0), + -- ancestry-4: layer-1 + (10, 4, 2, 0); -- assume that ancestry-3 and ancestry-4 are vulnerable. -INSERT INTO ancestry_feature (id, ancestry_layer_id, namespaced_feature_id) VALUES - (1, 1, 1), (2, 1, 4), -- ancestry-1, layer 0 introduces 1, 4 - (3, 5, 1), (4, 5, 3), -- ancestry-2, layer 0 introduces 1, 3 - (5, 9, 2), -- ancestry-3, layer 0 introduces 2 - (6, 10, 2); -- ancestry-4, layer 0 introduces 2 +INSERT INTO ancestry_feature (id, ancestry_layer_id, namespaced_feature_id, feature_detector_id, namespace_detector_id) VALUES + -- ancestry-1: + -- layer-2: ourchat 0.5 <- detected by dpkg 1.0 (2); debian: 7 <- detected by os-release 1.0 (1) + -- layer-2: openssl 2.0, debian:7 + (1, 3, 1, 2, 1), (2, 3, 4, 2, 1), + -- ancestry 2: + -- 1(ourchat 0.5; debian:7 layer-2) + -- 3(openssl 1.0; debian:8 layer-3b) + (3, 7, 1, 2, 1), (4, 8, 3, 2, 1), + -- ancestry-3: + -- 2(openssl 1.0, debian:7 layer-1) + -- 1(ourchat 0.5, debian:7 layer-1) + (5, 9, 2, 2, 1), (6, 9, 1, 2, 1), -- vulnerable + -- ancestry-4: + -- same as ancestry-3 + (7, 10, 2, 2, 1), (8, 10, 1, 2, 1); -- vulnerable INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES (1, 1, 'CVE-OPENSSL-1-DEB7', 'A vulnerability affecting OpenSSL < 2.0 on Debian 7.0', 'http://google.com/#q=CVE-OPENSSL-1-DEB7', 'High'), @@ -103,19 +122,23 @@ INSERT INTO vulnerability_affected_namespaced_feature(id, vulnerability_id, name INSERT INTO vulnerability_notification(id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id) VALUES (1, 'test', NULL, NULL, NULL, 2, 1); -- 'CVE-NOPE' -> 'CVE-OPENSSL-1-DEB7' +SELECT pg_catalog.setval(pg_get_serial_sequence('feature', 'id'), (SELECT MAX(id) FROM feature)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('namespace', 'id'), (SELECT MAX(id) FROM namespace)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('namespaced_feature', 'id'), (SELECT MAX(id) FROM namespaced_feature)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('detector', 'id'), (SELECT MAX(id) FROM detector)+1); + SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry', 'id'), (SELECT MAX(id) FROM ancestry)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_layer', 'id'), (SELECT MAX(id) FROM ancestry_layer)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_feature', 'id'), (SELECT MAX(id) FROM ancestry_feature)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_detector', 'id'), (SELECT MAX(id) FROM ancestry_detector)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_lister', 'id'), (SELECT MAX(id) FROM ancestry_lister)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('feature', 'id'), (SELECT MAX(id) FROM feature)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('namespaced_feature', 'id'), (SELECT MAX(id) FROM namespaced_feature)+1); + SELECT pg_catalog.setval(pg_get_serial_sequence('layer', 'id'), (SELECT MAX(id) FROM layer)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('layer_feature', 'id'), (SELECT MAX(id) FROM layer_feature)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('layer_namespace', 'id'), (SELECT MAX(id) FROM layer_namespace)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('layer_detector', 'id'), (SELECT MAX(id) FROM layer_detector)+1); -SELECT pg_catalog.setval(pg_get_serial_sequence('layer_lister', 'id'), (SELECT MAX(id) FROM layer_lister)+1); + SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability', 'id'), (SELECT MAX(id) FROM vulnerability)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_feature)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_namespaced_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_namespaced_feature)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_notification', 'id'), (SELECT MAX(id) FROM vulnerability_notification)+1); +SELECT pg_catalog.setval(pg_get_serial_sequence('detector', 'id'), (SELECT MAX(id) FROM detector)+1); diff --git a/database/pgsql/testutil.go b/database/pgsql/testutil.go new file mode 100644 index 00000000..05ed4fbf --- /dev/null +++ b/database/pgsql/testutil.go @@ -0,0 +1,263 @@ +// Copyright 2018 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pgsql + +import ( + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/pkg/pagination" + "github.com/coreos/clair/pkg/testutil" +) + +// int keys must be the consistent with the database ID. +var ( + realFeatures = map[int]database.Feature{ + 1: {"ourchat", "0.5", "dpkg"}, + 2: {"openssl", "1.0", "dpkg"}, + 3: {"openssl", "2.0", "dpkg"}, + 4: {"fake", "2.0", "rpm"}, + } + + realNamespaces = map[int]database.Namespace{ + 1: {"debian:7", "dpkg"}, + 2: {"debian:8", "dpkg"}, + 3: {"fake:1.0", "rpm"}, + } + + realNamespacedFeatures = map[int]database.NamespacedFeature{ + 1: {realFeatures[1], realNamespaces[1]}, + 2: {realFeatures[2], realNamespaces[1]}, + 3: {realFeatures[2], realNamespaces[2]}, + 4: {realFeatures[3], realNamespaces[1]}, + } + + realDetectors = map[int]database.Detector{ + 1: database.NewNamespaceDetector("os-release", "1.0"), + 2: database.NewFeatureDetector("dpkg", "1.0"), + 3: database.NewFeatureDetector("rpm", "1.0"), + 4: database.NewNamespaceDetector("apt-sources", "1.0"), + } + + realLayers = map[int]database.Layer{ + 2: { + Hash: "layer-1", + By: []database.Detector{realDetectors[1], realDetectors[2]}, + Features: []database.LayerFeature{ + {realFeatures[1], realDetectors[2]}, + {realFeatures[2], realDetectors[2]}, + }, + Namespaces: []database.LayerNamespace{ + {realNamespaces[1], realDetectors[1]}, + }, + }, + 6: { + Hash: "layer-4", + By: []database.Detector{realDetectors[1], realDetectors[2], realDetectors[3], realDetectors[4]}, + Features: []database.LayerFeature{ + {realFeatures[4], realDetectors[3]}, + {realFeatures[3], realDetectors[2]}, + }, + Namespaces: []database.LayerNamespace{ + {realNamespaces[1], realDetectors[1]}, + {realNamespaces[3], realDetectors[4]}, + }, + }, + } + + realAncestries = map[int]database.Ancestry{ + 2: { + Name: "ancestry-2", + By: []database.Detector{realDetectors[2], realDetectors[1]}, + Layers: []database.AncestryLayer{ + { + "layer-0", + []database.AncestryFeature{}, + }, + { + "layer-1", + []database.AncestryFeature{}, + }, + { + "layer-2", + []database.AncestryFeature{ + { + realNamespacedFeatures[1], + realDetectors[2], + realDetectors[1], + }, + }, + }, + { + "layer-3b", + []database.AncestryFeature{ + { + realNamespacedFeatures[3], + realDetectors[2], + realDetectors[1], + }, + }, + }, + }, + }, + } + + realVulnerability = map[int]database.Vulnerability{ + 1: { + Name: "CVE-OPENSSL-1-DEB7", + Namespace: realNamespaces[1], + Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", + Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", + Severity: database.HighSeverity, + }, + 2: { + Name: "CVE-NOPE", + Namespace: realNamespaces[1], + Description: "A vulnerability affecting nothing", + Severity: database.UnknownSeverity, + }, + } + + realNotification = map[int]database.VulnerabilityNotification{ + 1: { + NotificationHook: database.NotificationHook{ + Name: "test", + }, + Old: takeVulnerabilityPointerFromMap(realVulnerability, 2), + New: takeVulnerabilityPointerFromMap(realVulnerability, 1), + }, + } + + fakeFeatures = map[int]database.Feature{ + 1: { + Name: "ourchat", + Version: "0.6", + VersionFormat: "dpkg", + }, + } + + fakeNamespaces = map[int]database.Namespace{ + 1: {"green hat", "rpm"}, + } + fakeNamespacedFeatures = map[int]database.NamespacedFeature{ + 1: { + Feature: fakeFeatures[0], + Namespace: realNamespaces[0], + }, + } + + fakeDetector = map[int]database.Detector{ + 1: { + Name: "fake", + Version: "1.0", + DType: database.FeatureDetectorType, + }, + 2: { + Name: "fake2", + Version: "2.0", + DType: database.NamespaceDetectorType, + }, + } +) + +func takeVulnerabilityPointerFromMap(m map[int]database.Vulnerability, id int) *database.Vulnerability { + x := m[id] + return &x +} + +func takeAncestryPointerFromMap(m map[int]database.Ancestry, id int) *database.Ancestry { + x := m[id] + return &x +} + +func takeLayerPointerFromMap(m map[int]database.Layer, id int) *database.Layer { + x := m[id] + return &x +} + +func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace { + rows, err := tx.Query("SELECT name, version_format FROM namespace") + if err != nil { + t.FailNow() + } + defer rows.Close() + + namespaces := []database.Namespace{} + for rows.Next() { + var ns database.Namespace + err := rows.Scan(&ns.Name, &ns.VersionFormat) + if err != nil { + t.FailNow() + } + namespaces = append(namespaces, ns) + } + + return namespaces +} + +func assertVulnerabilityNotificationWithVulnerableEqual(t *testing.T, key pagination.Key, expected, actual *database.VulnerabilityNotificationWithVulnerable) bool { + if expected == actual { + return true + } + + if expected == nil || actual == nil { + return assert.Equal(t, expected, actual) + } + + return assert.Equal(t, expected.NotificationHook, actual.NotificationHook) && + AssertPagedVulnerableAncestriesEqual(t, key, expected.Old, actual.Old) && + AssertPagedVulnerableAncestriesEqual(t, key, expected.New, actual.New) +} + +func AssertPagedVulnerableAncestriesEqual(t *testing.T, key pagination.Key, expected, actual *database.PagedVulnerableAncestries) bool { + if expected == actual { + return true + } + + if expected == nil || actual == nil { + return assert.Equal(t, expected, actual) + } + + return testutil.AssertVulnerabilityEqual(t, &expected.Vulnerability, &actual.Vulnerability) && + assert.Equal(t, expected.Limit, actual.Limit) && + assert.Equal(t, mustUnmarshalToken(key, expected.Current), mustUnmarshalToken(key, actual.Current)) && + assert.Equal(t, mustUnmarshalToken(key, expected.Next), mustUnmarshalToken(key, actual.Next)) && + assert.Equal(t, expected.End, actual.End) && + testutil.AssertIntStringMapEqual(t, expected.Affected, actual.Affected) +} + +func mustUnmarshalToken(key pagination.Key, token pagination.Token) Page { + if token == pagination.FirstPageToken { + return Page{} + } + + p := Page{} + if err := key.UnmarshalToken(token, &p); err != nil { + panic(err) + } + + return p +} + +func mustMarshalToken(key pagination.Key, v interface{}) pagination.Token { + token, err := key.MarshalToken(v) + if err != nil { + panic(err) + } + + return token +} diff --git a/database/pgsql/vulnerability_test.go b/database/pgsql/vulnerability_test.go index 9fe2c23b..bfa465b2 100644 --- a/database/pgsql/vulnerability_test.go +++ b/database/pgsql/vulnerability_test.go @@ -306,14 +306,14 @@ func TestFindVulnerabilityIDs(t *testing.T) { ids, err := tx.findLatestDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}}) if assert.Nil(t, err) { - if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 3, ids[0].Int64)) { + if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 3, int(ids[0].Int64))) { assert.Fail(t, "") } } ids, err = tx.findNotDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}}) if assert.Nil(t, err) { - if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 2, ids[0].Int64)) { + if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 2, int(ids[0].Int64))) { assert.Fail(t, "") } } diff --git a/pkg/testutil/testutil.go b/pkg/testutil/testutil.go new file mode 100644 index 00000000..91c2abd4 --- /dev/null +++ b/pkg/testutil/testutil.go @@ -0,0 +1,285 @@ +package testutil + +import ( + "encoding/json" + "sort" + "testing" + + "github.com/deckarep/golang-set" + "github.com/stretchr/testify/assert" + + "github.com/coreos/clair/database" +) + +// AssertDetectorsEqual asserts actual detectors are content wise equal to +// expected detectors regardless of the ordering. +func AssertDetectorsEqual(t *testing.T, expected, actual []database.Detector) bool { + if len(expected) != len(actual) { + return assert.Fail(t, "detectors are not equal", "expected: '%v', actual: '%v'", expected, actual) + } + + sort.Slice(expected, func(i, j int) bool { + return expected[i].String() < expected[j].String() + }) + + sort.Slice(actual, func(i, j int) bool { + return actual[i].String() < actual[j].String() + }) + + for i := range expected { + if expected[i] != actual[i] { + return assert.Fail(t, "detectors are not equal", "expected: '%v', actual: '%v'", expected, actual) + } + } + + return true +} + +// AssertAncestryEqual asserts actual ancestry equals to expected ancestry +// content wise. +func AssertAncestryEqual(t *testing.T, expected, actual *database.Ancestry) bool { + if expected == actual { + return true + } + + if actual == nil || expected == nil { + return assert.Equal(t, expected, actual) + } + + if !assert.Equal(t, expected.Name, actual.Name) || !AssertDetectorsEqual(t, expected.By, actual.By) { + return false + } + + if assert.Equal(t, len(expected.Layers), len(actual.Layers)) { + for index := range expected.Layers { + if !AssertAncestryLayerEqual(t, &expected.Layers[index], &actual.Layers[index]) { + return false + } + } + return true + } + return false +} + +// AssertAncestryLayerEqual asserts actual ancestry layer equals to expected +// ancestry layer content wise. +func AssertAncestryLayerEqual(t *testing.T, expected, actual *database.AncestryLayer) bool { + if !assert.Equal(t, expected.Hash, actual.Hash) { + return false + } + + if !assert.Equal(t, len(expected.Features), len(actual.Features), + "layer: %s\nExpected: %v\n Actual: %v", + expected.Hash, expected.Features, actual.Features, + ) { + return false + } + + // feature -> is in actual layer + hitCounter := map[database.AncestryFeature]bool{} + for _, f := range expected.Features { + hitCounter[f] = false + } + + // if there's no extra features and no duplicated features, since expected + // and actual have the same length, their result must equal. + for _, f := range actual.Features { + v, ok := hitCounter[f] + assert.True(t, ok, "unexpected feature %s", f) + assert.False(t, v, "duplicated feature %s", f) + hitCounter[f] = true + } + + for f, visited := range hitCounter { + assert.True(t, visited, "missing feature %s", f) + } + + return true +} + +// AssertElementsEqual asserts that content in actual equals to content in +// expected array regardless of ordering. +// +// Note: This function uses interface wise comparison. +func AssertElementsEqual(t *testing.T, expected, actual []interface{}) bool { + counter := map[interface{}]bool{} + for _, f := range expected { + counter[f] = false + } + + for _, f := range actual { + v, ok := counter[f] + if !assert.True(t, ok, "unexpected element %v\nExpected: %v\n Actual: %v\n", f, expected, actual) { + return false + } + + if !assert.False(t, v, "duplicated element %v\nExpected: %v\n Actual: %v\n", f, expected, actual) { + return false + } + + counter[f] = true + } + + for f, visited := range counter { + if !assert.True(t, visited, "missing feature %v\nExpected: %v\n Actual: %v\n", f, expected, actual) { + return false + } + } + + return true +} + +// AssertFeaturesEqual asserts content in actual equals content in expected +// regardless of ordering. +func AssertFeaturesEqual(t *testing.T, expected, actual []database.Feature) bool { + if assert.Len(t, actual, len(expected)) { + has := map[database.Feature]bool{} + for _, nf := range expected { + has[nf] = false + } + + for _, nf := range actual { + has[nf] = true + } + + for nf, visited := range has { + if !assert.True(t, visited, nf.Name+" is expected") { + return false + } + return true + } + } + return false +} + +// AssertLayerFeaturesEqual asserts content in actual equals to content in +// expected regardless of ordering. +func AssertLayerFeaturesEqual(t *testing.T, expected, actual []database.LayerFeature) bool { + if !assert.Len(t, actual, len(expected)) { + return false + } + + expectedInterfaces := []interface{}{} + for _, e := range expected { + expectedInterfaces = append(expectedInterfaces, e) + } + + actualInterfaces := []interface{}{} + for _, a := range actual { + actualInterfaces = append(actualInterfaces, a) + } + + return AssertElementsEqual(t, expectedInterfaces, actualInterfaces) +} + +// AssertNamespacesEqual asserts content in actual equals to content in +// expected regardless of ordering. +func AssertNamespacesEqual(t *testing.T, expected, actual []database.Namespace) bool { + expectedInterfaces := []interface{}{} + for _, e := range expected { + expectedInterfaces = append(expectedInterfaces, e) + } + + actualInterfaces := []interface{}{} + for _, a := range actual { + actualInterfaces = append(actualInterfaces, a) + } + + return AssertElementsEqual(t, expectedInterfaces, actualInterfaces) +} + +// AssertLayerNamespacesEqual asserts content in actual equals to content in +// expected regardless of ordering. +func AssertLayerNamespacesEqual(t *testing.T, expected, actual []database.LayerNamespace) bool { + expectedInterfaces := []interface{}{} + for _, e := range expected { + expectedInterfaces = append(expectedInterfaces, e) + } + + actualInterfaces := []interface{}{} + for _, a := range actual { + actualInterfaces = append(actualInterfaces, a) + } + + return AssertElementsEqual(t, expectedInterfaces, actualInterfaces) +} + +// AssertLayerEqual asserts actual layer equals to expected layer content wise. +func AssertLayerEqual(t *testing.T, expected, actual *database.Layer) bool { + if expected == actual { + return true + } + + if expected == nil || actual == nil { + return assert.Equal(t, expected, actual) + } + + return assert.Equal(t, expected.Hash, actual.Hash) && + AssertDetectorsEqual(t, expected.By, actual.By) && + AssertLayerFeaturesEqual(t, expected.Features, actual.Features) && + AssertLayerNamespacesEqual(t, expected.Namespaces, actual.Namespaces) +} + +// AssertIntStringMapEqual asserts two maps with integer as key and string as +// value are equal. +func AssertIntStringMapEqual(t *testing.T, expected, actual map[int]string) bool { + checked := mapset.NewSet() + for k, v := range expected { + assert.Equal(t, v, actual[k]) + checked.Add(k) + } + + for k := range actual { + if !assert.True(t, checked.Contains(k)) { + return false + } + } + + return true +} + +// AssertVulnerabilityEqual asserts two vulnerabilities are equal. +func AssertVulnerabilityEqual(t *testing.T, expected, actual *database.Vulnerability) bool { + return assert.Equal(t, expected.Name, actual.Name) && + assert.Equal(t, expected.Link, actual.Link) && + assert.Equal(t, expected.Description, actual.Description) && + assert.Equal(t, expected.Namespace, actual.Namespace) && + assert.Equal(t, expected.Severity, actual.Severity) && + AssertMetadataMapEqual(t, expected.Metadata, actual.Metadata) +} + +func castMetadataMapToInterface(metadata database.MetadataMap) map[string]interface{} { + content, err := json.Marshal(metadata) + if err != nil { + panic(err) + } + + data := make(map[string]interface{}) + if err := json.Unmarshal(content, &data); err != nil { + panic(err) + } + + return data +} + +// AssertMetadataMapEqual asserts two metadata maps are equal. +func AssertMetadataMapEqual(t *testing.T, expected, actual database.MetadataMap) bool { + expectedMap := castMetadataMapToInterface(expected) + actualMap := castMetadataMapToInterface(actual) + checked := mapset.NewSet() + for k, v := range expectedMap { + if !assert.Equal(t, v, (actualMap)[k]) { + return false + } + + checked.Add(k) + } + + for k := range actual { + if !assert.True(t, checked.Contains(k)) { + return false + } + } + + return true +}