From 921acb26fe875ed18c95b2f62a73fa3e1a8aa355 Mon Sep 17 00:00:00 2001 From: Sida Chen Date: Wed, 6 Mar 2019 16:34:58 -0500 Subject: [PATCH] pgsql: Split vulnerability.go to files in vulnerability module --- .../{ => vulnerability}/vulnerability.go | 246 ++++++++++++------ .../vulnerability_affected_feature.go | 118 +++++++++ ...lnerability_affected_namespaced_feature.go | 142 ++++++++++ .../{ => vulnerability}/vulnerability_test.go | 223 +++++++++++++--- 4 files changed, 608 insertions(+), 121 deletions(-) rename database/pgsql/{ => vulnerability}/vulnerability.go (55%) create mode 100644 database/pgsql/vulnerability/vulnerability_affected_feature.go create mode 100644 database/pgsql/vulnerability/vulnerability_affected_namespaced_feature.go rename database/pgsql/{ => vulnerability}/vulnerability_test.go (50%) diff --git a/database/pgsql/vulnerability.go b/database/pgsql/vulnerability/vulnerability.go similarity index 55% rename from database/pgsql/vulnerability.go rename to database/pgsql/vulnerability/vulnerability.go index e96d6d47..4245299a 100644 --- a/database/pgsql/vulnerability.go +++ b/database/pgsql/vulnerability/vulnerability.go @@ -12,23 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package vulnerability import ( "database/sql" "errors" + "fmt" "time" "github.com/lib/pq" log "github.com/sirupsen/logrus" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/feature" + "github.com/coreos/clair/database/pgsql/monitoring" + "github.com/coreos/clair/database/pgsql/page" + "github.com/coreos/clair/database/pgsql/util" "github.com/coreos/clair/ext/versionfmt" + "github.com/coreos/clair/pkg/pagination" ) const ( - lockVulnerabilityAffects = `LOCK vulnerability_affected_namespaced_feature IN SHARE ROW EXCLUSIVE MODE` - searchVulnerability = ` SELECT v.id, v.description, v.link, v.severity, v.metadata, n.version_format FROM vulnerability AS v, namespace AS n @@ -38,45 +42,12 @@ const ( AND v.deleted_at IS NULL ` - insertVulnerabilityAffected = ` - INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, feature_type, fixedin) - VALUES ($1, $2, $3, $4, $5) - RETURNING ID - ` - - searchVulnerabilityAffected = ` - SELECT vulnerability_id, feature_name, affected_version, t.name, fixedin - FROM vulnerability_affected_feature AS vaf, feature_type AS t - WHERE t.id = vaf.feature_type AND vulnerability_id = ANY($1) - ` - searchVulnerabilityByID = ` SELECT v.name, v.description, v.link, v.severity, v.metadata, n.name, n.version_format FROM vulnerability AS v, namespace AS n WHERE v.namespace_id = n.id AND v.id = $1` - searchVulnerabilityPotentialAffected = ` - WITH req AS ( - SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, vaf.feature_type AS type, v.id AS vulnerability_id - FROM vulnerability_affected_feature AS vaf, - vulnerability AS v, - namespace AS n - WHERE vaf.vulnerability_id = ANY($1) - AND v.id = vaf.vulnerability_id - AND n.id = v.namespace_id - ) - SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by - FROM feature AS f, namespaced_feature AS nf, req - WHERE f.name = req.name - AND f.type = req.type - AND nf.namespace_id = req.n_id - AND nf.feature_id = f.id` - - insertVulnerabilityAffectedNamespacedFeature = ` - INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by) - VALUES ($1, $2, $3)` - insertVulnerability = ` WITH ns AS ( SELECT id FROM namespace WHERE name = $6 AND version_format = $7 @@ -92,12 +63,56 @@ const ( AND name = $2 AND deleted_at IS NULL RETURNING id` -) -var ( - errVulnerabilityNotFound = errors.New("vulnerability is not in database") + searchNotificationVulnerableAncestry = ` + SELECT DISTINCT ON (a.id) + a.id, a.name + FROM vulnerability_affected_namespaced_feature AS vanf, + ancestry_layer AS al, ancestry_feature AS af, ancestry AS a + WHERE vanf.vulnerability_id = $1 + AND a.id >= $2 + AND al.ancestry_id = a.id + AND al.id = af.ancestry_layer_id + AND af.namespaced_feature_id = vanf.namespaced_feature_id + ORDER BY a.id ASC + LIMIT $3;` ) +func queryInvalidateVulnerabilityCache(count int) string { + return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature + WHERE vulnerability_id IN (%s)`, + util.QueryString(1, count)) +} + +// NOTE(Sida): Every search query can only have count less than postgres set +// stack depth. IN will be resolved to nested OR_s and the parser might exceed +// stack depth. TODO(Sida): Generate different queries for different count: if +// count < 5120, use IN; for count > 5120 and < 65536, use temporary table; for +// count > 65535, use is expected to split data into batches. +func querySearchLastDeletedVulnerabilityID(count int) string { + return fmt.Sprintf(` + SELECT vid, vname, nname FROM ( + SELECT v.id AS vid, v.name AS vname, n.name AS nname, + row_number() OVER ( + PARTITION by (v.name, n.name) + ORDER BY v.deleted_at DESC + ) AS rownum + FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id + AND (v.name, n.name) IN ( %s ) + AND v.deleted_at IS NOT NULL + ) tmp WHERE rownum <= 1`, + util.QueryString(2, count)) +} + +func querySearchNotDeletedVulnerabilityID(count int) string { + return fmt.Sprintf(` + SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n + WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s) + AND v.deleted_at IS NULL`, + util.QueryString(2, count)) +} + type affectedAncestry struct { name string id int64 @@ -113,8 +128,8 @@ type affectedFeatureRows struct { rows map[int64]database.AffectedFeature } -func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) { - defer observeQueryTime("findVulnerabilities", "", time.Now()) +func FindVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) { + defer monitoring.ObserveQueryTime("findVulnerabilities", "", time.Now()) resultVuln := make([]database.NullableVulnerability, len(vulnerabilities)) vulnIDMap := map[int64][]*database.NullableVulnerability{} @@ -151,7 +166,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit if err != nil && err != sql.ErrNoRows { stmt.Close() - return nil, handleError("searchVulnerability", err) + return nil, util.HandleError("searchVulnerability", err) } vuln.Valid = id.Valid resultVuln[i] = vuln @@ -161,7 +176,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit } if err := stmt.Close(); err != nil { - return nil, handleError("searchVulnerability", err) + return nil, util.HandleError("searchVulnerability", err) } toQuery := make([]int64, 0, len(vulnIDMap)) @@ -172,7 +187,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit // load vulnerability affected features rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery)) if err != nil { - return nil, handleError("searchVulnerabilityAffected", err) + return nil, util.HandleError("searchVulnerabilityAffected", err) } for rows.Next() { @@ -183,7 +198,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FeatureType, &f.FixedInVersion) if err != nil { - return nil, handleError("searchVulnerabilityAffected", err) + return nil, util.HandleError("searchVulnerabilityAffected", err) } for _, vuln := range vulnIDMap[id] { @@ -195,41 +210,40 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit return resultVuln, nil } -func (tx *pgSession) InsertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) error { - defer observeQueryTime("insertVulnerabilities", "all", time.Now()) +func InsertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) error { + defer monitoring.ObserveQueryTime("insertVulnerabilities", "all", time.Now()) // bulk insert vulnerabilities - vulnIDs, err := tx.insertVulnerabilities(vulnerabilities) + vulnIDs, err := insertVulnerabilities(tx, vulnerabilities) if err != nil { return err } // bulk insert vulnerability affected features - vulnFeatureMap, err := tx.insertVulnerabilityAffected(vulnIDs, vulnerabilities) + vulnFeatureMap, err := InsertVulnerabilityAffected(tx, vulnIDs, vulnerabilities) if err != nil { return err } - return tx.cacheVulnerabiltyAffectedNamespacedFeature(vulnFeatureMap) + return CacheVulnerabiltyAffectedNamespacedFeature(tx, vulnFeatureMap) } // insertVulnerabilityAffected inserts a set of vulnerability affected features for each vulnerability provided. // // i_th vulnerabilityIDs corresponds to i_th vulnerabilities provided. -func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) { +func InsertVulnerabilityAffected(tx *sql.Tx, vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) { var ( vulnFeature = map[int64]affectedFeatureRows{} affectedID int64 ) - types, err := tx.getFeatureTypeMap() + types, err := feature.GetFeatureTypeMap(tx) if err != nil { return nil, err } - //TODO(Sida): Change to bulk insert. stmt, err := tx.Prepare(insertVulnerabilityAffected) if err != nil { - return nil, handleError("insertVulnerabilityAffected", err) + return nil, util.HandleError("insertVulnerabilityAffected", err) } defer stmt.Close() @@ -237,9 +251,9 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne // affected feature row ID -> affected feature affectedFeatures := map[int64]database.AffectedFeature{} for _, f := range vuln.Affected { - err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, types.byName[f.FeatureType], f.FixedInVersion).Scan(&affectedID) + err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, types.ByName[f.FeatureType], f.FixedInVersion).Scan(&affectedID) if err != nil { - return nil, handleError("insertVulnerabilityAffected", err) + return nil, util.HandleError("insertVulnerabilityAffected", err) } affectedFeatures[affectedID] = f } @@ -251,7 +265,7 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne // insertVulnerabilities inserts a set of unique vulnerabilities into database, // under the assumption that all vulnerabilities are valid. -func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) { +func insertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) { var ( vulnID int64 vulnIDs = make([]int64, 0, len(vulnerabilities)) @@ -274,7 +288,7 @@ func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.Vulnerabil //TODO(Sida): Change to bulk insert. stmt, err := tx.Prepare(insertVulnerability) if err != nil { - return nil, handleError("insertVulnerability", err) + return nil, util.HandleError("insertVulnerability", err) } defer stmt.Close() @@ -283,7 +297,7 @@ func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.Vulnerabil vuln.Link, &vuln.Severity, &vuln.Metadata, vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID) if err != nil { - return nil, handleError("insertVulnerability", err) + return nil, util.HandleError("insertVulnerability", err) } vulnIDs = append(vulnIDs, vulnID) @@ -292,19 +306,19 @@ func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.Vulnerabil return vulnIDs, nil } -func (tx *pgSession) lockFeatureVulnerabilityCache() error { +func LockFeatureVulnerabilityCache(tx *sql.Tx) error { _, err := tx.Exec(lockVulnerabilityAffects) if err != nil { - return handleError("lockVulnerabilityAffects", err) + return util.HandleError("lockVulnerabilityAffects", err) } return nil } // cacheVulnerabiltyAffectedNamespacedFeature takes in a map of vulnerability ID // to affected feature rows and caches them. -func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int64]affectedFeatureRows) error { +func CacheVulnerabiltyAffectedNamespacedFeature(tx *sql.Tx, affected map[int64]affectedFeatureRows) error { // Prevent InsertNamespacedFeatures to modify it. - err := tx.lockFeatureVulnerabilityCache() + err := LockFeatureVulnerabilityCache(tx) if err != nil { return err } @@ -316,7 +330,7 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs)) if err != nil { - return handleError("searchVulnerabilityPotentialAffected", err) + return util.HandleError("searchVulnerabilityPotentialAffected", err) } defer rows.Close() @@ -332,7 +346,7 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy) if err != nil { - return handleError("searchVulnerabilityPotentialAffected", err) + return util.HandleError("searchVulnerabilityPotentialAffected", err) } candidate, ok := affected[vulnID].rows[addedBy] @@ -361,7 +375,7 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int for _, r := range relation { result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy) if err != nil { - return handleError("insertVulnerabilityAffectedNamespacedFeature", err) + return util.HandleError("insertVulnerabilityAffectedNamespacedFeature", err) } if num, err := result.RowsAffected(); err == nil { @@ -377,27 +391,27 @@ func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int return nil } -func (tx *pgSession) DeleteVulnerabilities(vulnerabilities []database.VulnerabilityID) error { - defer observeQueryTime("DeleteVulnerability", "all", time.Now()) +func DeleteVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) error { + defer monitoring.ObserveQueryTime("DeleteVulnerability", "all", time.Now()) - vulnIDs, err := tx.markVulnerabilitiesAsDeleted(vulnerabilities) + vulnIDs, err := MarkVulnerabilitiesAsDeleted(tx, vulnerabilities) if err != nil { return err } - if err := tx.invalidateVulnerabilityCache(vulnIDs); err != nil { + if err := InvalidateVulnerabilityCache(tx, vulnIDs); err != nil { return err } return nil } -func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) error { +func InvalidateVulnerabilityCache(tx *sql.Tx, vulnerabilityIDs []int64) error { if len(vulnerabilityIDs) == 0 { return nil } // Prevent InsertNamespacedFeatures to modify it. - err := tx.lockFeatureVulnerabilityCache() + err := LockFeatureVulnerabilityCache(tx) if err != nil { return err } @@ -410,13 +424,13 @@ func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) erro _, err = tx.Exec(queryInvalidateVulnerabilityCache(len(vulnerabilityIDs)), keys...) if err != nil { - return handleError("removeVulnerabilityAffectedFeature", err) + return util.HandleError("removeVulnerabilityAffectedFeature", err) } return nil } -func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.VulnerabilityID) ([]int64, error) { +func MarkVulnerabilitiesAsDeleted(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]int64, error) { var ( vulnID sql.NullInt64 vulnIDs []int64 @@ -425,17 +439,17 @@ func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.Vul // mark vulnerabilities deleted stmt, err := tx.Prepare(removeVulnerability) if err != nil { - return nil, handleError("removeVulnerability", err) + return nil, util.HandleError("removeVulnerability", err) } defer stmt.Close() for _, vuln := range vulnerabilities { err := stmt.QueryRow(vuln.Namespace, vuln.Name).Scan(&vulnID) if err != nil { - return nil, handleError("removeVulnerability", err) + return nil, util.HandleError("removeVulnerability", err) } if !vulnID.Valid { - return nil, handleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database")) + return nil, util.HandleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database")) } vulnIDs = append(vulnIDs, vulnID.Int64) } @@ -444,15 +458,15 @@ func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.Vul // findLatestDeletedVulnerabilityIDs requires all elements in vulnIDs are in // database and the order of output array is not guaranteed. -func (tx *pgSession) findLatestDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { - return tx.findVulnerabilityIDs(vulnIDs, true) +func FindLatestDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { + return FindVulnerabilityIDs(tx, vulnIDs, true) } -func (tx *pgSession) findNotDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { - return tx.findVulnerabilityIDs(vulnIDs, false) +func FindNotDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { + return FindVulnerabilityIDs(tx, vulnIDs, false) } -func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) { +func FindVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) { if len(vulnIDs) == 0 { return nil, nil } @@ -474,7 +488,7 @@ func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, wi rows, err := tx.Query(query, keys...) if err != nil { - return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Query", err) + return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Query", err) } defer rows.Close() @@ -485,7 +499,7 @@ func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, wi for rows.Next() { err := rows.Scan(&id, &vulnID.Name, &vulnID.Namespace) if err != nil { - return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Scan", err) + return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Scan", err) } vulnIDMap[vulnID] = id } @@ -497,3 +511,67 @@ func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, wi return ids, nil } + +func FindPagedVulnerableAncestries(tx *sql.Tx, vulnID int64, limit int, currentToken pagination.Token, key pagination.Key) (database.PagedVulnerableAncestries, error) { + vulnPage := database.PagedVulnerableAncestries{Limit: limit} + currentPage := page.Page{0} + if currentToken != pagination.FirstPageToken { + if err := key.UnmarshalToken(currentToken, ¤tPage); err != nil { + return vulnPage, err + } + } + + if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan( + &vulnPage.Name, + &vulnPage.Description, + &vulnPage.Link, + &vulnPage.Severity, + &vulnPage.Metadata, + &vulnPage.Namespace.Name, + &vulnPage.Namespace.VersionFormat, + ); err != nil { + return vulnPage, util.HandleError("searchVulnerabilityByID", err) + } + + // the last result is used for the next page's startID + rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1) + if err != nil { + return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err) + } + defer rows.Close() + + ancestries := []affectedAncestry{} + for rows.Next() { + var ancestry affectedAncestry + err := rows.Scan(&ancestry.id, &ancestry.name) + if err != nil { + return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err) + } + ancestries = append(ancestries, ancestry) + } + + lastIndex := 0 + if len(ancestries)-1 < limit { + lastIndex = len(ancestries) + vulnPage.End = true + } else { + // Use the last ancestry's ID as the next page. + lastIndex = len(ancestries) - 1 + vulnPage.Next, err = key.MarshalToken(page.Page{ancestries[len(ancestries)-1].id}) + if err != nil { + return vulnPage, err + } + } + + vulnPage.Affected = map[int]string{} + for _, ancestry := range ancestries[0:lastIndex] { + vulnPage.Affected[int(ancestry.id)] = ancestry.name + } + + vulnPage.Current, err = key.MarshalToken(currentPage) + if err != nil { + return vulnPage, err + } + + return vulnPage, nil +} diff --git a/database/pgsql/vulnerability/vulnerability_affected_feature.go b/database/pgsql/vulnerability/vulnerability_affected_feature.go new file mode 100644 index 00000000..97716dd6 --- /dev/null +++ b/database/pgsql/vulnerability/vulnerability_affected_feature.go @@ -0,0 +1,118 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vulnerability + +import ( + "database/sql" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/feature" + "github.com/coreos/clair/database/pgsql/util" + "github.com/coreos/clair/ext/versionfmt" + "github.com/lib/pq" +) + +const ( + searchPotentialAffectingVulneraibilities = ` + SELECT nf.id, v.id, vaf.affected_version, vaf.id + FROM vulnerability_affected_feature AS vaf, vulnerability AS v, + namespaced_feature AS nf, feature AS f + WHERE nf.id = ANY($1) + AND nf.feature_id = f.id + AND nf.namespace_id = v.namespace_id + AND vaf.feature_name = f.name + AND vaf.feature_type = f.type + AND vaf.vulnerability_id = v.id + AND v.deleted_at IS NULL` + insertVulnerabilityAffected = ` + INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, feature_type, fixedin) + VALUES ($1, $2, $3, $4, $5) + RETURNING ID + ` + searchVulnerabilityAffected = ` + SELECT vulnerability_id, feature_name, affected_version, t.name, fixedin + FROM vulnerability_affected_feature AS vaf, feature_type AS t + WHERE t.id = vaf.feature_type AND vulnerability_id = ANY($1) + ` + + searchVulnerabilityPotentialAffected = ` + WITH req AS ( + SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, vaf.feature_type AS type, v.id AS vulnerability_id + FROM vulnerability_affected_feature AS vaf, + vulnerability AS v, + namespace AS n + WHERE vaf.vulnerability_id = ANY($1) + AND v.id = vaf.vulnerability_id + AND n.id = v.namespace_id + ) + SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by + FROM feature AS f, namespaced_feature AS nf, req + WHERE f.name = req.name + AND f.type = req.type + AND nf.namespace_id = req.n_id + AND nf.feature_id = f.id` +) + +type vulnerabilityCache struct { + nsFeatureID int64 + vulnID int64 + vulnAffectingID int64 +} + +func SearchAffectingVulnerabilities(tx *sql.Tx, features []database.NamespacedFeature) ([]vulnerabilityCache, error) { + if len(features) == 0 { + return nil, nil + } + + ids, err := feature.FindNamespacedFeatureIDs(tx, features) + if err != nil { + return nil, err + } + + fMap := map[int64]database.NamespacedFeature{} + for i, f := range features { + if !ids[i].Valid { + return nil, database.ErrMissingEntities + } + fMap[ids[i].Int64] = f + } + + cacheTable := []vulnerabilityCache{} + rows, err := tx.Query(searchPotentialAffectingVulneraibilities, pq.Array(ids)) + if err != nil { + return nil, util.HandleError("searchPotentialAffectingVulneraibilities", err) + } + + defer rows.Close() + for rows.Next() { + var ( + cache vulnerabilityCache + affected string + ) + + err := rows.Scan(&cache.nsFeatureID, &cache.vulnID, &affected, &cache.vulnAffectingID) + if err != nil { + return nil, err + } + + if ok, err := versionfmt.InRange(fMap[cache.nsFeatureID].VersionFormat, fMap[cache.nsFeatureID].Version, affected); err != nil { + return nil, err + } else if ok { + cacheTable = append(cacheTable, cache) + } + } + + return cacheTable, nil +} diff --git a/database/pgsql/vulnerability/vulnerability_affected_namespaced_feature.go b/database/pgsql/vulnerability/vulnerability_affected_namespaced_feature.go new file mode 100644 index 00000000..2a09fe5b --- /dev/null +++ b/database/pgsql/vulnerability/vulnerability_affected_namespaced_feature.go @@ -0,0 +1,142 @@ +// Copyright 2019 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vulnerability + +import ( + "database/sql" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/feature" + "github.com/coreos/clair/database/pgsql/util" + "github.com/lib/pq" + log "github.com/sirupsen/logrus" +) + +const ( + searchNamespacedFeaturesVulnerabilities = ` + SELECT vanf.namespaced_feature_id, v.name, v.description, v.link, + v.severity, v.metadata, vaf.fixedin, n.name, n.version_format + FROM vulnerability_affected_namespaced_feature AS vanf, + Vulnerability AS v, + vulnerability_affected_feature AS vaf, + namespace AS n + WHERE vanf.namespaced_feature_id = ANY($1) + AND vaf.id = vanf.added_by + AND v.id = vanf.vulnerability_id + AND n.id = v.namespace_id + AND v.deleted_at IS NULL` + + lockVulnerabilityAffects = `LOCK vulnerability_affected_namespaced_feature IN SHARE ROW EXCLUSIVE MODE` + + insertVulnerabilityAffectedNamespacedFeature = ` + INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by) + VALUES ($1, $2, $3)` +) + +func queryPersistVulnerabilityAffectedNamespacedFeature(count int) string { + return util.QueryPersist(count, "vulnerability_affected_namespaced_feature", + "vulnerability_affected_namesp_vulnerability_id_namespaced_f_key", + "vulnerability_id", + "namespaced_feature_id", + "added_by") +} + +// FindAffectedNamespacedFeatures retrieves vulnerabilities associated with the +// feature. +func FindAffectedNamespacedFeatures(tx *sql.Tx, features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) { + if len(features) == 0 { + return nil, nil + } + + vulnerableFeatures := make([]database.NullableAffectedNamespacedFeature, len(features)) + featureIDs, err := feature.FindNamespacedFeatureIDs(tx, features) + if err != nil { + return nil, err + } + + for i, id := range featureIDs { + if id.Valid { + vulnerableFeatures[i].Valid = true + vulnerableFeatures[i].NamespacedFeature = features[i] + } + } + + rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(featureIDs)) + if err != nil { + return nil, util.HandleError("searchNamespacedFeaturesVulnerabilities", err) + } + defer rows.Close() + + for rows.Next() { + var ( + featureID int64 + vuln database.VulnerabilityWithFixedIn + ) + + err := rows.Scan(&featureID, + &vuln.Name, + &vuln.Description, + &vuln.Link, + &vuln.Severity, + &vuln.Metadata, + &vuln.FixedInVersion, + &vuln.Namespace.Name, + &vuln.Namespace.VersionFormat, + ) + + if err != nil { + return nil, util.HandleError("searchNamespacedFeaturesVulnerabilities", err) + } + + for i, id := range featureIDs { + if id.Valid && id.Int64 == featureID { + vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy = append(vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy, vuln) + } + } + } + + return vulnerableFeatures, nil +} + +func CacheAffectedNamespacedFeatures(tx *sql.Tx, features []database.NamespacedFeature) error { + if len(features) == 0 { + return nil + } + + _, err := tx.Exec(lockVulnerabilityAffects) + if err != nil { + return util.HandleError("lockVulnerabilityAffects", err) + } + + cache, err := SearchAffectingVulnerabilities(tx, features) + + keys := make([]interface{}, 0, len(cache)*3) + for _, c := range cache { + keys = append(keys, c.vulnID, c.nsFeatureID, c.vulnAffectingID) + } + + if len(cache) == 0 { + return nil + } + + affected, err := tx.Exec(queryPersistVulnerabilityAffectedNamespacedFeature(len(cache)), keys...) + if err != nil { + return util.HandleError("persistVulnerabilityAffectedNamespacedFeature", err) + } + if count, err := affected.RowsAffected(); err != nil { + log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", count) + } + return nil +} diff --git a/database/pgsql/vulnerability_test.go b/database/pgsql/vulnerability/vulnerability_test.go similarity index 50% rename from database/pgsql/vulnerability_test.go rename to database/pgsql/vulnerability/vulnerability_test.go index 759bfe2f..d911adc6 100644 --- a/database/pgsql/vulnerability_test.go +++ b/database/pgsql/vulnerability/vulnerability_test.go @@ -12,19 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package vulnerability import ( + "database/sql" + "math/rand" + "strconv" "testing" + "github.com/pborman/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/feature" + "github.com/coreos/clair/database/pgsql/namespace" + "github.com/coreos/clair/database/pgsql/testutil" + "github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/ext/versionfmt/dpkg" + "github.com/coreos/clair/pkg/strutil" ) func TestInsertVulnerabilities(t *testing.T) { - store, tx := openSessionForTest(t, "InsertVulnerabilities", true) + store, cleanup := testutil.CreateTestDBWithFixture(t, "InsertVulnerabilities") + defer cleanup() ns1 := database.Namespace{ Name: "name", @@ -56,45 +67,48 @@ func TestInsertVulnerabilities(t *testing.T) { Vulnerability: v2, } + tx, err := store.Begin() + require.Nil(t, err) + // empty - err := tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{}) + err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{}) assert.Nil(t, err) // invalid content: vwa1 is invalid - err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa1, vwa2}) + err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa1, vwa2}) assert.NotNil(t, err) - tx = restartSession(t, store, tx, false) + tx = testutil.RestartTransaction(store, tx, false) // invalid content: duplicated input - err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2, vwa2}) + err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa2, vwa2}) assert.NotNil(t, err) - tx = restartSession(t, store, tx, false) + tx = testutil.RestartTransaction(store, tx, false) // valid content - err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2}) + err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa2}) assert.Nil(t, err) - tx = restartSession(t, store, tx, true) + tx = testutil.RestartTransaction(store, tx, true) // ensure the content is in database - vulns, err := tx.FindVulnerabilities([]database.VulnerabilityID{{Name: "valid", Namespace: "debian:7"}}) + vulns, err := FindVulnerabilities(tx, []database.VulnerabilityID{{Name: "valid", Namespace: "debian:7"}}) if assert.Nil(t, err) && assert.Len(t, vulns, 1) { assert.True(t, vulns[0].Valid) } - tx = restartSession(t, store, tx, false) + tx = testutil.RestartTransaction(store, tx, false) // valid content: vwa2 removed and inserted - err = tx.DeleteVulnerabilities([]database.VulnerabilityID{{Name: vwa2.Name, Namespace: vwa2.Namespace.Name}}) + err = DeleteVulnerabilities(tx, []database.VulnerabilityID{{Name: vwa2.Name, Namespace: vwa2.Namespace.Name}}) assert.Nil(t, err) - err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2}) + err = InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vwa2}) assert.Nil(t, err) - closeTest(t, store, tx) + require.Nil(t, tx.Rollback()) } func TestCachingVulnerable(t *testing.T) { - datastore, tx := openSessionForTest(t, "CachingVulnerable", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "CachingVulnerable") + defer cleanup() ns := database.Namespace{ Name: "debian:8", @@ -163,11 +177,8 @@ func TestCachingVulnerable(t *testing.T) { FixedInVersion: "2.2", } - if !assert.Nil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vuln, vuln2})) { - t.FailNow() - } - - r, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{f}) + require.Nil(t, InsertVulnerabilities(tx, []database.VulnerabilityWithAffected{vuln, vuln2})) + r, err := FindAffectedNamespacedFeatures(tx, []database.NamespacedFeature{f}) assert.Nil(t, err) assert.Len(t, r, 1) for _, anf := range r { @@ -186,10 +197,10 @@ func TestCachingVulnerable(t *testing.T) { } func TestFindVulnerabilities(t *testing.T) { - datastore, tx := openSessionForTest(t, "FindVulnerabilities", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindVulnerabilities") + defer cleanup() - vuln, err := tx.FindVulnerabilities([]database.VulnerabilityID{ + vuln, err := FindVulnerabilities(tx, []database.VulnerabilityID{ {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, {Name: "CVE-NOPE", Namespace: "debian:7"}, {Name: "CVE-NOT HERE"}, @@ -255,7 +266,7 @@ func TestFindVulnerabilities(t *testing.T) { expected, ok := expectedExistingMap[key] if assert.True(t, ok, "vulnerability not found: "+key.Name+":"+key.Namespace) { - assertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected) + testutil.AssertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected) } } else if !assert.Equal(t, nonexisting, v.VulnerabilityWithAffected) { t.FailNow() @@ -264,7 +275,7 @@ func TestFindVulnerabilities(t *testing.T) { } // same vulnerability - r, err := tx.FindVulnerabilities([]database.VulnerabilityID{ + r, err := FindVulnerabilities(tx, []database.VulnerabilityID{ {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, {Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}, }) @@ -273,22 +284,22 @@ func TestFindVulnerabilities(t *testing.T) { for _, vuln := range r { if assert.True(t, vuln.Valid) { expected, _ := expectedExistingMap[database.VulnerabilityID{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}] - assertVulnerabilityWithAffectedEqual(t, expected, vuln.VulnerabilityWithAffected) + testutil.AssertVulnerabilityWithAffectedEqual(t, expected, vuln.VulnerabilityWithAffected) } } } } func TestDeleteVulnerabilities(t *testing.T) { - datastore, tx := openSessionForTest(t, "DeleteVulnerabilities", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "DeleteVulnerabilities") + defer cleanup() remove := []database.VulnerabilityID{} // empty case - assert.Nil(t, tx.DeleteVulnerabilities(remove)) + assert.Nil(t, DeleteVulnerabilities(tx, remove)) // invalid case remove = append(remove, database.VulnerabilityID{}) - assert.NotNil(t, tx.DeleteVulnerabilities(remove)) + assert.NotNil(t, DeleteVulnerabilities(tx, remove)) // valid case validRemove := []database.VulnerabilityID{ @@ -296,8 +307,8 @@ func TestDeleteVulnerabilities(t *testing.T) { {Name: "CVE-NOPE", Namespace: "debian:7"}, } - assert.Nil(t, tx.DeleteVulnerabilities(validRemove)) - vuln, err := tx.FindVulnerabilities(validRemove) + assert.Nil(t, DeleteVulnerabilities(tx, validRemove)) + vuln, err := FindVulnerabilities(tx, validRemove) if assert.Nil(t, err) { for _, v := range vuln { assert.False(t, v.Valid) @@ -306,20 +317,158 @@ func TestDeleteVulnerabilities(t *testing.T) { } func TestFindVulnerabilityIDs(t *testing.T) { - store, tx := openSessionForTest(t, "FindVulnerabilityIDs", true) - defer closeTest(t, store, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindVulnerabilityIDs") + defer cleanup() - ids, err := tx.findLatestDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}}) + ids, err := FindLatestDeletedVulnerabilityIDs(tx, []database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}}) if assert.Nil(t, err) { if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 3, int(ids[0].Int64))) { assert.Fail(t, "") } } - ids, err = tx.findNotDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}}) + ids, err = FindNotDeletedVulnerabilityIDs(tx, []database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}}) if assert.Nil(t, err) { if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 2, int(ids[0].Int64))) { assert.Fail(t, "") } } } + +func TestFindAffectedNamespacedFeatures(t *testing.T) { + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "FindAffectedNamespacedFeatures") + defer cleanup() + + ns := database.NamespacedFeature{ + Feature: database.Feature{ + Name: "openssl", + Version: "1.0", + VersionFormat: "dpkg", + Type: database.SourcePackage, + }, + Namespace: database.Namespace{ + Name: "debian:7", + VersionFormat: "dpkg", + }, + } + + ans, err := FindAffectedNamespacedFeatures(tx, []database.NamespacedFeature{ns}) + if assert.Nil(t, err) && + assert.Len(t, ans, 1) && + assert.True(t, ans[0].Valid) && + assert.Len(t, ans[0].AffectedBy, 1) { + assert.Equal(t, "CVE-OPENSSL-1-DEB7", ans[0].AffectedBy[0].Name) + } +} + +func genRandomVulnerabilityAndNamespacedFeature(t *testing.T, store *sql.DB) ([]database.NamespacedFeature, []database.VulnerabilityWithAffected) { + tx, err := store.Begin() + if err != nil { + panic(err) + } + + numFeatures := 100 + numVulnerabilities := 100 + + featureName := "TestFeature" + featureVersionFormat := dpkg.ParserName + // Insert the namespace on which we'll work. + ns := database.Namespace{ + Name: "TestRaceAffectsFeatureNamespace1", + VersionFormat: dpkg.ParserName, + } + + if !assert.Nil(t, namespace.PersistNamespaces(tx, []database.Namespace{ns})) { + t.FailNow() + } + + // Generate Distinct random features + features := make([]database.Feature, numFeatures) + nsFeatures := make([]database.NamespacedFeature, numFeatures) + for i := 0; i < numFeatures; i++ { + version := rand.Intn(numFeatures) + + features[i] = *database.NewSourcePackage(featureName, strconv.Itoa(version), featureVersionFormat) + nsFeatures[i] = database.NamespacedFeature{ + Namespace: ns, + Feature: features[i], + } + } + + if !assert.Nil(t, feature.PersistFeatures(tx, features)) { + t.FailNow() + } + + // Generate vulnerabilities. + vulnerabilities := []database.VulnerabilityWithAffected{} + for i := 0; i < numVulnerabilities; i++ { + // any version less than this is vulnerable + version := rand.Intn(numFeatures) + 1 + + vulnerability := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: uuid.New(), + Namespace: ns, + Severity: database.UnknownSeverity, + }, + Affected: []database.AffectedFeature{ + { + Namespace: ns, + FeatureName: featureName, + FeatureType: database.SourcePackage, + AffectedVersion: strconv.Itoa(version), + FixedInVersion: strconv.Itoa(version), + }, + }, + } + + vulnerabilities = append(vulnerabilities, vulnerability) + } + tx.Commit() + + return nsFeatures, vulnerabilities +} + +func TestVulnChangeAffectsVulnerableFeatures(t *testing.T) { + db, cleanup := testutil.CreateTestDB(t, "caching") + defer cleanup() + + nsFeatures, vulnerabilities := genRandomVulnerabilityAndNamespacedFeature(t, db) + tx, err := db.Begin() + require.Nil(t, err) + + require.Nil(t, feature.PersistNamespacedFeatures(tx, nsFeatures)) + require.Nil(t, tx.Commit()) + + tx, err = db.Begin() + require.Nil(t, InsertVulnerabilities(tx, vulnerabilities)) + require.Nil(t, tx.Commit()) + + tx, err = db.Begin() + require.Nil(t, err) + defer tx.Rollback() + + affected, err := FindAffectedNamespacedFeatures(tx, nsFeatures) + require.Nil(t, err) + + for _, ansf := range affected { + require.True(t, ansf.Valid) + + expectedAffectedNames := []string{} + for _, vuln := range vulnerabilities { + if ok, err := versionfmt.InRange(dpkg.ParserName, ansf.Version, vuln.Affected[0].AffectedVersion); err == nil { + if ok { + expectedAffectedNames = append(expectedAffectedNames, vuln.Name) + } + } + } + + actualAffectedNames := []string{} + for _, s := range ansf.AffectedBy { + actualAffectedNames = append(actualAffectedNames, s.Name) + } + + require.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0, "\nvulns: %#v\nfeature:%#v\nexpected:%#v\nactual:%#v", vulnerabilities, ansf.NamespacedFeature, expectedAffectedNames, actualAffectedNames) + require.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0) + } +}