diff --git a/database/pgsql/ancestry.go b/database/pgsql/ancestry.go index fa0c0ad5..90ddd811 100644 --- a/database/pgsql/ancestry.go +++ b/database/pgsql/ancestry.go @@ -23,10 +23,11 @@ const ( findAncestryFeatures = ` SELECT namespace.name, namespace.version_format, feature.name, - feature.version, feature.version_format, ancestry_layer.ancestry_index, + feature.version, feature.version_format, feature_type.name, ancestry_layer.ancestry_index, ancestry_feature.feature_detector_id, ancestry_feature.namespace_detector_id - FROM namespace, feature, namespaced_feature, ancestry_layer, ancestry_feature + FROM namespace, feature, feature_type, namespaced_feature, ancestry_layer, ancestry_feature WHERE ancestry_layer.ancestry_id = $1 + AND feature_type.id = feature.type AND ancestry_feature.ancestry_layer_id = ancestry_layer.id AND ancestry_feature.namespaced_feature_id = namespaced_feature.id AND namespaced_feature.feature_id = feature.id @@ -256,6 +257,7 @@ func (tx *pgSession) findAncestryFeatures(ancestryID int64, detectors detectorMa &feature.Feature.Name, &feature.Feature.Version, &feature.Feature.VersionFormat, + &feature.Feature.Type, &index, &featureDetectorID, &namespaceDetectorID, diff --git a/database/pgsql/complex_test.go b/database/pgsql/complex_test.go index de8b0f20..e4a10928 100644 --- a/database/pgsql/complex_test.go +++ b/database/pgsql/complex_test.go @@ -25,6 +25,7 @@ import ( "github.com/pborman/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/coreos/clair/database" "github.com/coreos/clair/ext/versionfmt" @@ -65,19 +66,13 @@ func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database for i := 0; i < numFeatures; i++ { version := rand.Intn(numFeatures) - features[i] = database.Feature{ - Name: featureName, - VersionFormat: featureVersionFormat, - Version: strconv.Itoa(version), - } - + features[i] = *database.NewSourcePackage(featureName, strconv.Itoa(version), featureVersionFormat) nsFeatures[i] = database.NamespacedFeature{ Namespace: namespace, Feature: features[i], } } - // insert features if !assert.Nil(t, tx.PersistFeatures(features)) { t.FailNow() } @@ -98,6 +93,7 @@ func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database { Namespace: namespace, FeatureName: featureName, + FeatureType: database.SourcePackage, AffectedVersion: strconv.Itoa(version), FixedInVersion: strconv.Itoa(version), }, @@ -117,7 +113,6 @@ func TestConcurrency(t *testing.T) { t.FailNow() } defer store.Close() - start := time.Now() var wg sync.WaitGroup wg.Add(100) @@ -137,65 +132,39 @@ func TestConcurrency(t *testing.T) { fmt.Println("total", time.Since(start)) } -func genRandomNamespaces(t *testing.T, count int) []database.Namespace { - r := make([]database.Namespace, count) - for i := 0; i < count; i++ { - r[i] = database.Namespace{ - Name: uuid.New(), - VersionFormat: "dpkg", - } - } - return r -} - func TestCaching(t *testing.T) { store, err := openDatabaseForTest("Caching", false) if !assert.Nil(t, err) { t.FailNow() } defer store.Close() - nsFeatures, vulnerabilities := testGenRandomVulnerabilityAndNamespacedFeature(t, store) - - fmt.Printf("%d features, %d vulnerabilities are generated", len(nsFeatures), len(vulnerabilities)) - - var wg sync.WaitGroup - wg.Add(2) - go func() { - defer wg.Done() - tx, err := store.Begin() - if !assert.Nil(t, err) { - t.FailNow() - } - - assert.Nil(t, tx.PersistNamespacedFeatures(nsFeatures)) - fmt.Println("finished to insert namespaced features") - - tx.Commit() - }() - - go func() { - defer wg.Done() - tx, err := store.Begin() - if !assert.Nil(t, err) { - t.FailNow() - } - - assert.Nil(t, tx.InsertVulnerabilities(vulnerabilities)) - fmt.Println("finished to insert vulnerabilities") - tx.Commit() - - }() - - wg.Wait() - tx, err := store.Begin() if !assert.Nil(t, err) { t.FailNow() } + + require.Nil(t, tx.PersistNamespacedFeatures(nsFeatures)) + if err := tx.Commit(); err != nil { + panic(err) + } + + tx, err = store.Begin() + if !assert.Nil(t, err) { + t.FailNow() + } + + require.Nil(t, tx.InsertVulnerabilities(vulnerabilities)) + if err := tx.Commit(); err != nil { + panic(err) + } + + tx, err = store.Begin() + if !assert.Nil(t, err) { + t.FailNow() + } defer tx.Rollback() - // Verify consistency now. affected, err := tx.FindAffectedNamespacedFeatures(nsFeatures) if !assert.Nil(t, err) { t.FailNow() @@ -220,7 +189,7 @@ func TestCaching(t *testing.T) { actualAffectedNames = append(actualAffectedNames, s.Name) } - assert.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0) - assert.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0) + require.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0, "\nvulns: %#v\nfeature:%#v\nexpected:%#v\nactual:%#v", vulnerabilities, ansf.NamespacedFeature, expectedAffectedNames, actualAffectedNames) + require.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0) } } diff --git a/database/pgsql/feature.go b/database/pgsql/feature.go index f716130f..3df7ac6a 100644 --- a/database/pgsql/feature.go +++ b/database/pgsql/feature.go @@ -46,6 +46,7 @@ const ( AND nf.feature_id = f.id AND nf.namespace_id = v.namespace_id AND vaf.feature_name = f.name + AND vaf.feature_type = f.type AND vaf.vulnerability_id = v.id AND v.deleted_at IS NULL` @@ -68,6 +69,11 @@ func (tx *pgSession) PersistFeatures(features []database.Feature) error { return nil } + types, err := tx.getFeatureTypeMap() + if err != nil { + return err + } + // Sorting is needed before inserting into database to prevent deadlock. sort.Slice(features, func(i, j int) bool { return features[i].Name < features[j].Name || @@ -78,13 +84,13 @@ func (tx *pgSession) PersistFeatures(features []database.Feature) error { // TODO(Sida): A better interface for bulk insertion is needed. keys := make([]interface{}, 0, len(features)*3) for _, f := range features { - keys = append(keys, f.Name, f.Version, f.VersionFormat) + keys = append(keys, f.Name, f.Version, f.VersionFormat, types.byName[f.Type]) if f.Name == "" || f.Version == "" || f.VersionFormat == "" { return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed") } } - _, err := tx.Exec(queryPersistFeature(len(features)), keys...) + _, err = tx.Exec(queryPersistFeature(len(features)), keys...) return handleError("queryPersistFeature", err) } @@ -240,55 +246,31 @@ func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFea return nil } -// FindAffectedNamespacedFeatures looks up cache table and retrieves all -// vulnerabilities associated with the features. +// FindAffectedNamespacedFeatures retrieves vulnerabilities associated with the +// feature. func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) { if len(features) == 0 { return nil, nil } - returnFeatures := make([]database.NullableAffectedNamespacedFeature, len(features)) - - // featureMap is used to keep track of duplicated features. - featureMap := map[database.NamespacedFeature][]*database.NullableAffectedNamespacedFeature{} - // initialize return value and generate unique feature request queries. - for i, f := range features { - returnFeatures[i] = database.NullableAffectedNamespacedFeature{ - AffectedNamespacedFeature: database.AffectedNamespacedFeature{ - NamespacedFeature: f, - }, - } - - featureMap[f] = append(featureMap[f], &returnFeatures[i]) - } - - // query unique namespaced features - distinctFeatures := []database.NamespacedFeature{} - for f := range featureMap { - distinctFeatures = append(distinctFeatures, f) - } - - nsFeatureIDs, err := tx.findNamespacedFeatureIDs(distinctFeatures) + vulnerableFeatures := make([]database.NullableAffectedNamespacedFeature, len(features)) + featureIDs, err := tx.findNamespacedFeatureIDs(features) if err != nil { return nil, err } - toQuery := []int64{} - featureIDMap := map[int64][]*database.NullableAffectedNamespacedFeature{} - for i, id := range nsFeatureIDs { + for i, id := range featureIDs { if id.Valid { - toQuery = append(toQuery, id.Int64) - for _, f := range featureMap[distinctFeatures[i]] { - f.Valid = id.Valid - featureIDMap[id.Int64] = append(featureIDMap[id.Int64], f) - } + vulnerableFeatures[i].Valid = true + vulnerableFeatures[i].NamespacedFeature = features[i] } } - rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(toQuery)) + rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(featureIDs)) if err != nil { return nil, handleError("searchNamespacedFeaturesVulnerabilities", err) } + defer rows.Close() for rows.Next() { @@ -296,6 +278,7 @@ func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.Namespac featureID int64 vuln database.VulnerabilityWithFixedIn ) + err := rows.Scan(&featureID, &vuln.Name, &vuln.Description, @@ -306,16 +289,19 @@ func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.Namespac &vuln.Namespace.Name, &vuln.Namespace.VersionFormat, ) + if err != nil { return nil, handleError("searchNamespacedFeaturesVulnerabilities", err) } - for _, f := range featureIDMap[featureID] { - f.AffectedBy = append(f.AffectedBy, vuln) + for i, id := range featureIDs { + if id.Valid && id.Int64 == featureID { + vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy = append(vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy, vuln) + } } } - return returnFeatures, nil + return vulnerableFeatures, nil } func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) ([]sql.NullInt64, error) { @@ -323,11 +309,10 @@ func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) return nil, nil } - nfsMap := map[database.NamespacedFeature]sql.NullInt64{} - keys := make([]interface{}, 0, len(nfs)*4) + nfsMap := map[database.NamespacedFeature]int64{} + keys := make([]interface{}, 0, len(nfs)*5) for _, nf := range nfs { - keys = append(keys, nf.Name, nf.Version, nf.VersionFormat, nf.Namespace.Name) - nfsMap[nf] = sql.NullInt64{} + keys = append(keys, nf.Name, nf.Version, nf.VersionFormat, nf.Type, nf.Namespace.Name) } rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...) @@ -337,12 +322,12 @@ func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) defer rows.Close() var ( - id sql.NullInt64 + id int64 nf database.NamespacedFeature ) for rows.Next() { - err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Namespace.Name) + err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Type, &nf.Namespace.Name) nf.Namespace.VersionFormat = nf.VersionFormat if err != nil { return nil, handleError("searchNamespacedFeature", err) @@ -352,7 +337,11 @@ func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) ids := make([]sql.NullInt64, len(nfs)) for i, nf := range nfs { - ids[i] = nfsMap[nf] + if id, ok := nfsMap[nf]; ok { + ids[i] = sql.NullInt64{id, true} + } else { + ids[i] = sql.NullInt64{} + } } return ids, nil @@ -363,11 +352,17 @@ func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, err return nil, nil } + types, err := tx.getFeatureTypeMap() + if err != nil { + return nil, err + } + fMap := map[database.Feature]sql.NullInt64{} - keys := make([]interface{}, 0, len(fs)*3) + keys := make([]interface{}, 0, len(fs)*4) for _, f := range fs { - keys = append(keys, f.Name, f.Version, f.VersionFormat) + typeID := types.byName[f.Type] + keys = append(keys, f.Name, f.Version, f.VersionFormat, typeID) fMap[f] = sql.NullInt64{} } @@ -382,10 +377,13 @@ func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, err f database.Feature ) for rows.Next() { - err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat) + var typeID int + err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat, &typeID) if err != nil { return nil, handleError("querySearchFeatureID", err) } + + f.Type = types.byID[typeID] fMap[f] = id } diff --git a/database/pgsql/feature_test.go b/database/pgsql/feature_test.go index 2823e1e8..574bfeab 100644 --- a/database/pgsql/feature_test.go +++ b/database/pgsql/feature_test.go @@ -18,134 +18,53 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/coreos/clair/database" - - // register dpkg feature lister for testing - _ "github.com/coreos/clair/ext/featurefmt/dpkg" ) func TestPersistFeatures(t *testing.T) { - datastore, tx := openSessionForTest(t, "PersistFeatures", false) - defer closeTest(t, datastore, tx) + tx, cleanup := createTestPgSession(t, "TestPersistFeatures") + defer cleanup() - f1 := database.Feature{} - f2 := database.Feature{Name: "n", Version: "v", VersionFormat: "vf"} + invalid := database.Feature{} + valid := *database.NewBinaryPackage("mount", "2.31.1-0.4ubuntu3.1", "dpkg") - // empty - assert.Nil(t, tx.PersistFeatures([]database.Feature{})) // invalid - assert.NotNil(t, tx.PersistFeatures([]database.Feature{f1})) - // duplicated - assert.Nil(t, tx.PersistFeatures([]database.Feature{f2, f2})) + require.NotNil(t, tx.PersistFeatures([]database.Feature{invalid})) // existing - assert.Nil(t, tx.PersistFeatures([]database.Feature{f2})) + require.Nil(t, tx.PersistFeatures([]database.Feature{valid})) + require.Nil(t, tx.PersistFeatures([]database.Feature{valid})) - fs := listFeatures(t, tx) - assert.Len(t, fs, 1) - assert.Equal(t, f2, fs[0]) + features := selectAllFeatures(t, tx) + assert.Equal(t, []database.Feature{valid}, features) } func TestPersistNamespacedFeatures(t *testing.T) { - datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) - defer closeTest(t, datastore, tx) + tx, cleanup := createTestPgSessionWithFixtures(t, "TestPersistNamespacedFeatures") + defer cleanup() // existing features - f1 := database.Feature{ - Name: "ourchat", - Version: "0.5", - VersionFormat: "dpkg", - } - + f1 := database.NewSourcePackage("ourchat", "0.5", "dpkg") // non-existing features - f2 := database.Feature{ - Name: "fake!", - } - - f3 := database.Feature{ - Name: "openssl", - Version: "2.0", - VersionFormat: "dpkg", - } - + f2 := database.NewSourcePackage("fake!", "", "") // exising namespace - n1 := database.Namespace{ - Name: "debian:7", - VersionFormat: "dpkg", - } - - n3 := database.Namespace{ - Name: "debian:8", - VersionFormat: "dpkg", - } - + n1 := database.NewNamespace("debian:7", "dpkg") // non-existing namespace - n2 := database.Namespace{ - Name: "debian:non", - VersionFormat: "dpkg", - } - + n2 := database.NewNamespace("debian:non", "dpkg") // existing namespaced feature - nf1 := database.NamespacedFeature{ - Namespace: n1, - Feature: f1, - } - + nf1 := database.NewNamespacedFeature(n1, f1) // invalid namespaced feature - nf2 := database.NamespacedFeature{ - Namespace: n2, - Feature: f2, - } - - // new namespaced feature affected by vulnerability - nf3 := database.NamespacedFeature{ - Namespace: n3, - Feature: f3, - } - + nf2 := database.NewNamespacedFeature(n2, f2) // namespaced features with namespaces or features not in the database will // generate error. assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{})) - - assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf2})) + assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{*nf1, *nf2})) // valid case: insert nf3 - assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf3})) + assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{*nf1})) all := listNamespacedFeatures(t, tx) - assert.Contains(t, all, nf1) - assert.Contains(t, all, nf3) -} - -func TestVulnerableFeature(t *testing.T) { - datastore, tx := openSessionForTest(t, "VulnerableFeature", true) - defer closeTest(t, datastore, tx) - - f1 := database.Feature{ - Name: "openssl", - Version: "1.3", - VersionFormat: "dpkg", - } - - n1 := database.Namespace{ - Name: "debian:7", - VersionFormat: "dpkg", - } - - nf1 := database.NamespacedFeature{ - Namespace: n1, - Feature: f1, - } - assert.Nil(t, tx.PersistFeatures([]database.Feature{f1})) - assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1})) - assert.Nil(t, tx.CacheAffectedNamespacedFeatures([]database.NamespacedFeature{nf1})) - // ensure the namespaced feature is affected correctly - anf, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{nf1}) - if assert.Nil(t, err) && - assert.Len(t, anf, 1) && - assert.True(t, anf[0].Valid) && - assert.Len(t, anf[0].AffectedBy, 1) { - assert.Equal(t, "CVE-OPENSSL-1-DEB7", anf[0].AffectedBy[0].Name) - } + assert.Contains(t, all, *nf1) } func TestFindAffectedNamespacedFeatures(t *testing.T) { @@ -156,6 +75,7 @@ func TestFindAffectedNamespacedFeatures(t *testing.T) { Name: "openssl", Version: "1.0", VersionFormat: "dpkg", + Type: database.SourcePackage, }, Namespace: database.Namespace{ Name: "debian:7", @@ -173,30 +93,41 @@ func TestFindAffectedNamespacedFeatures(t *testing.T) { } func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFeature { - rows, err := tx.Query(`SELECT f.name, f.version, f.version_format, n.name, n.version_format + types, err := tx.getFeatureTypeMap() + if err != nil { + panic(err) + } + + rows, err := tx.Query(`SELECT f.name, f.version, f.version_format, f.type, n.name, n.version_format FROM feature AS f, namespace AS n, namespaced_feature AS nf WHERE nf.feature_id = f.id AND nf.namespace_id = n.id`) if err != nil { - t.Error(err) - t.FailNow() + panic(err) } nf := []database.NamespacedFeature{} for rows.Next() { f := database.NamespacedFeature{} - err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &f.Namespace.Name, &f.Namespace.VersionFormat) + var typeID int + err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &typeID, &f.Namespace.Name, &f.Namespace.VersionFormat) if err != nil { - t.Error(err) - t.FailNow() + panic(err) } + + f.Type = types.byID[typeID] nf = append(nf, f) } return nf } -func listFeatures(t *testing.T, tx *pgSession) []database.Feature { - rows, err := tx.Query("SELECT name, version, version_format FROM feature") +func selectAllFeatures(t *testing.T, tx *pgSession) []database.Feature { + types, err := tx.getFeatureTypeMap() + if err != nil { + panic(err) + } + + rows, err := tx.Query("SELECT name, version, version_format, type FROM feature") if err != nil { t.FailNow() } @@ -204,7 +135,9 @@ func listFeatures(t *testing.T, tx *pgSession) []database.Feature { fs := []database.Feature{} for rows.Next() { f := database.Feature{} - err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat) + var typeID int + err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &typeID) + f.Type = types.byID[typeID] if err != nil { t.FailNow() } @@ -233,3 +166,33 @@ func assertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFe } return false } + +func TestFindNamespacedFeatureIDs(t *testing.T) { + tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNamespacedFeatureIDs") + defer cleanup() + + features := []database.NamespacedFeature{} + expectedIDs := []int{} + for id, feature := range realNamespacedFeatures { + features = append(features, feature) + expectedIDs = append(expectedIDs, id) + } + + features = append(features, realNamespacedFeatures[1]) // test duplicated + expectedIDs = append(expectedIDs, 1) + + namespace := realNamespaces[1] + features = append(features, *database.NewNamespacedFeature(&namespace, database.NewBinaryPackage("not-found", "1.0", "dpkg"))) // test not found feature + + ids, err := tx.findNamespacedFeatureIDs(features) + require.Nil(t, err) + require.Len(t, ids, len(expectedIDs)+1) + for i, id := range ids { + if i == len(ids)-1 { + require.False(t, id.Valid) + } else { + require.True(t, id.Valid) + require.Equal(t, expectedIDs[i], int(id.Int64)) + } + } +} diff --git a/database/pgsql/layer.go b/database/pgsql/layer.go index 782e541c..a071eb4b 100644 --- a/database/pgsql/layer.go +++ b/database/pgsql/layer.go @@ -37,9 +37,10 @@ const ( SELECT id FROM layer WHERE hash = $1` findLayerFeatures = ` - SELECT f.name, f.version, f.version_format, lf.detector_id - FROM layer_feature AS lf, feature AS f + SELECT f.name, f.version, f.version_format, t.name, lf.detector_id + FROM layer_feature AS lf, feature AS f, feature_type AS t WHERE lf.feature_id = f.id + AND t.id = f.type AND lf.layer_id = $1` findLayerNamespaces = ` @@ -307,7 +308,7 @@ func (tx *pgSession) findLayerFeatures(layerID int64, detectors detectorMap) ([] detectorID int64 feature database.LayerFeature ) - if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &detectorID); err != nil { + if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &feature.Type, &detectorID); err != nil { return nil, handleError("findLayerFeatures", err) } diff --git a/database/pgsql/layer_test.go b/database/pgsql/layer_test.go index 478b2171..5211eb11 100644 --- a/database/pgsql/layer_test.go +++ b/database/pgsql/layer_test.go @@ -43,12 +43,12 @@ var persistLayerTests = []struct { features: []database.LayerFeature{ {realFeatures[1], realDetectors[1]}, }, - err: "database: parameters are not valid", + err: "parameters are not valid", }, { title: "layer with non-existing feature", name: "random-forest", - err: "database: associated immutable entities are missing in the database", + err: "associated immutable entities are missing in the database", by: []database.Detector{realDetectors[2]}, features: []database.LayerFeature{ {fakeFeatures[1], realDetectors[2]}, @@ -57,7 +57,7 @@ var persistLayerTests = []struct { { title: "layer with non-existing namespace", name: "random-forest2", - err: "database: associated immutable entities are missing in the database", + err: "associated immutable entities are missing in the database", by: []database.Detector{realDetectors[1]}, namespaces: []database.LayerNamespace{ {fakeNamespaces[1], realDetectors[1]}, @@ -66,7 +66,7 @@ var persistLayerTests = []struct { { title: "layer with non-existing detector", name: "random-forest3", - err: "database: associated immutable entities are missing in the database", + err: "associated immutable entities are missing in the database", by: []database.Detector{fakeDetector[1]}, }, { diff --git a/database/pgsql/notification_test.go b/database/pgsql/notification_test.go index 0a23abca..da3b3248 100644 --- a/database/pgsql/notification_test.go +++ b/database/pgsql/notification_test.go @@ -211,8 +211,8 @@ func TestInsertVulnerabilityNotifications(t *testing.T) { } func TestFindNewNotification(t *testing.T) { - datastore, tx := openSessionForTest(t, "FindNewNotification", true) - defer closeTest(t, datastore, tx) + tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNewNotification") + defer cleanup() noti, ok, err := tx.FindNewNotification(time.Now()) if assert.Nil(t, err) && assert.True(t, ok) { @@ -229,7 +229,7 @@ func TestFindNewNotification(t *testing.T) { assert.Nil(t, err) assert.False(t, ok) // can find the notified after a period of time - noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000))) + noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(10 * time.Second))) if assert.Nil(t, err) && assert.True(t, ok) { assert.Equal(t, "test", noti.Name) assert.NotEqual(t, time.Time{}, noti.Notified) diff --git a/database/pgsql/pgsql_test.go b/database/pgsql/pgsql_test.go index 863445a5..03eda8c7 100644 --- a/database/pgsql/pgsql_test.go +++ b/database/pgsql/pgsql_test.go @@ -109,6 +109,7 @@ func dropTemplateDatabase(url string, name string) { } } + func TestMain(m *testing.M) { fURL, fName := genTemplateDatabase("fixture", true) nfURL, nfName := genTemplateDatabase("nonfixture", false) diff --git a/database/pgsql/queries.go b/database/pgsql/queries.go index 2d4b7e99..5cd5c3c9 100644 --- a/database/pgsql/queries.go +++ b/database/pgsql/queries.go @@ -52,21 +52,22 @@ func querySearchNotDeletedVulnerabilityID(count int) string { func querySearchFeatureID(featureCount int) string { return fmt.Sprintf(` - SELECT id, name, version, version_format - FROM Feature WHERE (name, version, version_format) IN (%s)`, - queryString(3, featureCount), + SELECT id, name, version, version_format, type + FROM Feature WHERE (name, version, version_format, type) IN (%s)`, + queryString(4, featureCount), ) } func querySearchNamespacedFeature(nsfCount int) string { return fmt.Sprintf(` - SELECT nf.id, f.name, f.version, f.version_format, n.name - FROM namespaced_feature AS nf, feature AS f, namespace AS n + SELECT nf.id, f.name, f.version, f.version_format, t.name, n.name + FROM namespaced_feature AS nf, feature AS f, namespace AS n, feature_type AS t WHERE nf.feature_id = f.id AND nf.namespace_id = n.id AND n.version_format = f.version_format - AND (f.name, f.version, f.version_format, n.name) IN (%s)`, - queryString(4, nsfCount), + AND f.type = t.id + AND (f.name, f.version, f.version_format, t.name, n.name) IN (%s)`, + queryString(5, nsfCount), ) } @@ -110,10 +111,11 @@ func queryInsertNotifications(count int) string { func queryPersistFeature(count int) string { return queryPersist(count, "feature", - "feature_name_version_version_format_key", + "feature_name_version_version_format_type_key", "name", "version", - "version_format") + "version_format", + "type") } func queryPersistLayerFeature(count int) string { diff --git a/database/pgsql/testdata/data.sql b/database/pgsql/testdata/data.sql index e7484209..4c90ae0d 100644 --- a/database/pgsql/testdata/data.sql +++ b/database/pgsql/testdata/data.sql @@ -4,11 +4,12 @@ INSERT INTO namespace (id, name, version_format) VALUES (2, 'debian:8', 'dpkg'), (3, 'fake:1.0', 'rpm'); -INSERT INTO feature (id, name, version, version_format) VALUES - (1, 'ourchat', '0.5', 'dpkg'), - (2, 'openssl', '1.0', 'dpkg'), - (3, 'openssl', '2.0', 'dpkg'), - (4, 'fake', '2.0', 'rpm'); +INSERT INTO feature (id, name, version, version_format, type) VALUES + (1, 'ourchat', '0.5', 'dpkg', 1), + (2, 'openssl', '1.0', 'dpkg', 1), + (3, 'openssl', '2.0', 'dpkg', 1), + (4, 'fake', '2.0', 'rpm', 1), + (5, 'mount', '2.31.1-0.4ubuntu3.1', 'dpkg', 2); INSERT INTO namespaced_feature(id, feature_id, namespace_id) VALUES (1, 1, 1), -- ourchat 0.5, debian:7 @@ -112,9 +113,9 @@ INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) INSERT INTO vulnerability (id, namespace_id, name, description, link, severity, deleted_at) VALUES (3, 1, 'CVE-DELETED', '', '', 'Unknown', '2017-08-08 17:49:31.668483'); -INSERT INTO vulnerability_affected_feature(id, vulnerability_id, feature_name, affected_version, fixedin) VALUES -(1, 1, 'openssl', '2.0', '2.0'), -(2, 1, 'libssl', '1.9-abc', '1.9-abc'); +INSERT INTO vulnerability_affected_feature(id, vulnerability_id, feature_name, affected_version, fixedin, feature_type) VALUES +(1, 1, 'openssl', '2.0', '2.0', 1), +(2, 1, 'libssl', '1.9-abc', '1.9-abc', 1); INSERT INTO vulnerability_affected_namespaced_feature(id, vulnerability_id, namespaced_feature_id, added_by) VALUES (1, 1, 2, 1); diff --git a/database/pgsql/vulnerability.go b/database/pgsql/vulnerability.go index 93518a87..e96d6d47 100644 --- a/database/pgsql/vulnerability.go +++ b/database/pgsql/vulnerability.go @@ -39,15 +39,15 @@ const ( ` insertVulnerabilityAffected = ` - INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, fixedin) - VALUES ($1, $2, $3, $4) + INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, feature_type, fixedin) + VALUES ($1, $2, $3, $4, $5) RETURNING ID ` searchVulnerabilityAffected = ` - SELECT vulnerability_id, feature_name, affected_version, fixedin - FROM vulnerability_affected_feature - WHERE vulnerability_id = ANY($1) + SELECT vulnerability_id, feature_name, affected_version, t.name, fixedin + FROM vulnerability_affected_feature AS vaf, feature_type AS t + WHERE t.id = vaf.feature_type AND vulnerability_id = ANY($1) ` searchVulnerabilityByID = ` @@ -58,7 +58,7 @@ const ( searchVulnerabilityPotentialAffected = ` WITH req AS ( - SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, v.id AS vulnerability_id + SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, vaf.feature_type AS type, v.id AS vulnerability_id FROM vulnerability_affected_feature AS vaf, vulnerability AS v, namespace AS n @@ -69,6 +69,7 @@ const ( SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by FROM feature AS f, namespaced_feature AS nf, req WHERE f.name = req.name + AND f.type = req.type AND nf.namespace_id = req.n_id AND nf.feature_id = f.id` @@ -180,7 +181,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit f database.AffectedFeature ) - err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FixedInVersion) + err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FeatureType, &f.FixedInVersion) if err != nil { return nil, handleError("searchVulnerabilityAffected", err) } @@ -220,6 +221,11 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne affectedID int64 ) + types, err := tx.getFeatureTypeMap() + if err != nil { + return nil, err + } + //TODO(Sida): Change to bulk insert. stmt, err := tx.Prepare(insertVulnerabilityAffected) if err != nil { @@ -231,7 +237,7 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne // affected feature row ID -> affected feature affectedFeatures := map[int64]database.AffectedFeature{} for _, f := range vuln.Affected { - err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, f.FixedInVersion).Scan(&affectedID) + err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, types.byName[f.FeatureType], f.FixedInVersion).Scan(&affectedID) if err != nil { return nil, handleError("insertVulnerabilityAffected", err) } diff --git a/database/pgsql/vulnerability_test.go b/database/pgsql/vulnerability_test.go index bfa465b2..759bfe2f 100644 --- a/database/pgsql/vulnerability_test.go +++ b/database/pgsql/vulnerability_test.go @@ -106,6 +106,7 @@ func TestCachingVulnerable(t *testing.T) { Name: "openssl", Version: "1.0", VersionFormat: dpkg.ParserName, + Type: database.SourcePackage, }, Namespace: ns, } @@ -120,6 +121,7 @@ func TestCachingVulnerable(t *testing.T) { { Namespace: ns, FeatureName: "openssl", + FeatureType: database.SourcePackage, AffectedVersion: "2.0", FixedInVersion: "2.1", }, @@ -136,6 +138,7 @@ func TestCachingVulnerable(t *testing.T) { { Namespace: ns, FeatureName: "openssl", + FeatureType: database.SourcePackage, AffectedVersion: "2.1", FixedInVersion: "2.2", }, @@ -209,12 +212,14 @@ func TestFindVulnerabilities(t *testing.T) { Affected: []database.AffectedFeature{ { FeatureName: "openssl", + FeatureType: database.SourcePackage, AffectedVersion: "2.0", FixedInVersion: "2.0", Namespace: ns, }, { FeatureName: "libssl", + FeatureType: database.SourcePackage, AffectedVersion: "1.9-abc", FixedInVersion: "1.9-abc", Namespace: ns, @@ -318,26 +323,3 @@ func TestFindVulnerabilityIDs(t *testing.T) { } } } - -func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool { - return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected) -} - -func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool { - if assert.Len(t, actual, len(expected)) { - has := map[database.AffectedFeature]bool{} - for _, i := range expected { - has[i] = false - } - for _, i := range actual { - if visited, ok := has[i]; !ok { - return false - } else if visited { - return false - } - has[i] = true - } - return true - } - return false -}