// Copyright 2017 clair authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package vulnerability import ( "database/sql" "errors" "fmt" "time" "github.com/lib/pq" log "github.com/sirupsen/logrus" "github.com/coreos/clair/database" "github.com/coreos/clair/database/pgsql/feature" "github.com/coreos/clair/database/pgsql/monitoring" "github.com/coreos/clair/database/pgsql/page" "github.com/coreos/clair/database/pgsql/util" "github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/pkg/pagination" ) const ( searchVulnerability = ` SELECT v.id, v.description, v.link, v.severity, v.metadata, n.version_format FROM vulnerability AS v, namespace AS n WHERE v.namespace_id = n.id AND v.name = $1 AND n.name = $2 AND v.deleted_at IS NULL ` searchVulnerabilityByID = ` SELECT v.name, v.description, v.link, v.severity, v.metadata, n.name, n.version_format FROM vulnerability AS v, namespace AS n WHERE v.namespace_id = n.id AND v.id = $1` insertVulnerability = ` WITH ns AS ( SELECT id FROM namespace WHERE name = $6 AND version_format = $7 ) INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at) VALUES((SELECT id FROM ns), $1, $2, $3, $4, $5, CURRENT_TIMESTAMP) RETURNING id` removeVulnerability = ` UPDATE Vulnerability SET deleted_at = CURRENT_TIMESTAMP WHERE namespace_id = (SELECT id FROM Namespace WHERE name = $1) AND name = $2 AND deleted_at IS NULL RETURNING id` searchNotificationVulnerableAncestry = ` SELECT DISTINCT ON (a.id) a.id, a.name FROM vulnerability_affected_namespaced_feature AS vanf, ancestry_layer AS al, ancestry_feature AS af, ancestry AS a WHERE vanf.vulnerability_id = $1 AND a.id >= $2 AND al.ancestry_id = a.id AND al.id = af.ancestry_layer_id AND af.namespaced_feature_id = vanf.namespaced_feature_id ORDER BY a.id ASC LIMIT $3;` ) func queryInvalidateVulnerabilityCache(count int) string { return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature WHERE vulnerability_id IN (%s)`, util.QueryString(1, count)) } // NOTE(Sida): Every search query can only have count less than postgres set // stack depth. IN will be resolved to nested OR_s and the parser might exceed // stack depth. TODO(Sida): Generate different queries for different count: if // count < 5120, use IN; for count > 5120 and < 65536, use temporary table; for // count > 65535, use is expected to split data into batches. func querySearchLastDeletedVulnerabilityID(count int) string { return fmt.Sprintf(` SELECT vid, vname, nname FROM ( SELECT v.id AS vid, v.name AS vname, n.name AS nname, row_number() OVER ( PARTITION by (v.name, n.name) ORDER BY v.deleted_at DESC ) AS rownum FROM vulnerability AS v, namespace AS n WHERE v.namespace_id = n.id AND (v.name, n.name) IN ( %s ) AND v.deleted_at IS NOT NULL ) tmp WHERE rownum <= 1`, util.QueryString(2, count)) } func querySearchNotDeletedVulnerabilityID(count int) string { return fmt.Sprintf(` SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s) AND v.deleted_at IS NULL`, util.QueryString(2, count)) } type affectedAncestry struct { name string id int64 } type affectRelation struct { vulnerabilityID int64 namespacedFeatureID int64 addedBy int64 } type affectedFeatureRows struct { rows map[int64]database.AffectedFeature } func FindVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) { defer monitoring.ObserveQueryTime("findVulnerabilities", "", time.Now()) resultVuln := make([]database.NullableVulnerability, len(vulnerabilities)) vulnIDMap := map[int64][]*database.NullableVulnerability{} //TODO(Sida): Change to bulk search. stmt, err := tx.Prepare(searchVulnerability) if err != nil { return nil, err } // load vulnerabilities for i, key := range vulnerabilities { var ( id sql.NullInt64 vuln = database.NullableVulnerability{ VulnerabilityWithAffected: database.VulnerabilityWithAffected{ Vulnerability: database.Vulnerability{ Name: key.Name, Namespace: database.Namespace{ Name: key.Namespace, }, }, }, } ) err := stmt.QueryRow(key.Name, key.Namespace).Scan( &id, &vuln.Description, &vuln.Link, &vuln.Severity, &vuln.Metadata, &vuln.Namespace.VersionFormat, ) if err != nil && err != sql.ErrNoRows { stmt.Close() return nil, util.HandleError("searchVulnerability", err) } vuln.Valid = id.Valid resultVuln[i] = vuln if id.Valid { vulnIDMap[id.Int64] = append(vulnIDMap[id.Int64], &resultVuln[i]) } } if err := stmt.Close(); err != nil { return nil, util.HandleError("searchVulnerability", err) } toQuery := make([]int64, 0, len(vulnIDMap)) for id := range vulnIDMap { toQuery = append(toQuery, id) } // load vulnerability affected features rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery)) if err != nil { return nil, util.HandleError("searchVulnerabilityAffected", err) } for rows.Next() { var ( id int64 f database.AffectedFeature ) err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FeatureType, &f.FixedInVersion) if err != nil { return nil, util.HandleError("searchVulnerabilityAffected", err) } for _, vuln := range vulnIDMap[id] { f.Namespace = vuln.Namespace vuln.Affected = append(vuln.Affected, f) } } return resultVuln, nil } func InsertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) error { defer monitoring.ObserveQueryTime("insertVulnerabilities", "all", time.Now()) // bulk insert vulnerabilities vulnIDs, err := insertVulnerabilities(tx, vulnerabilities) if err != nil { return err } // bulk insert vulnerability affected features vulnFeatureMap, err := InsertVulnerabilityAffected(tx, vulnIDs, vulnerabilities) if err != nil { return err } return CacheVulnerabiltyAffectedNamespacedFeature(tx, vulnFeatureMap) } // insertVulnerabilityAffected inserts a set of vulnerability affected features for each vulnerability provided. // // i_th vulnerabilityIDs corresponds to i_th vulnerabilities provided. func InsertVulnerabilityAffected(tx *sql.Tx, vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) { var ( vulnFeature = map[int64]affectedFeatureRows{} affectedID int64 ) types, err := feature.GetFeatureTypeMap(tx) if err != nil { return nil, err } stmt, err := tx.Prepare(insertVulnerabilityAffected) if err != nil { return nil, util.HandleError("insertVulnerabilityAffected", err) } defer stmt.Close() for i, vuln := range vulnerabilities { // affected feature row ID -> affected feature affectedFeatures := map[int64]database.AffectedFeature{} for _, f := range vuln.Affected { err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, types.ByName[f.FeatureType], f.FixedInVersion).Scan(&affectedID) if err != nil { return nil, util.HandleError("insertVulnerabilityAffected", err) } affectedFeatures[affectedID] = f } vulnFeature[vulnerabilityIDs[i]] = affectedFeatureRows{rows: affectedFeatures} } return vulnFeature, nil } // insertVulnerabilities inserts a set of unique vulnerabilities into database, // under the assumption that all vulnerabilities are valid. func insertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) { var ( vulnID int64 vulnIDs = make([]int64, 0, len(vulnerabilities)) vulnMap = map[database.VulnerabilityID]struct{}{} ) for _, v := range vulnerabilities { key := database.VulnerabilityID{ Name: v.Name, Namespace: v.Namespace.Name, } // Ensure uniqueness of vulnerability IDs if _, ok := vulnMap[key]; ok { return nil, errors.New("inserting duplicated vulnerabilities is not allowed") } vulnMap[key] = struct{}{} } //TODO(Sida): Change to bulk insert. stmt, err := tx.Prepare(insertVulnerability) if err != nil { return nil, util.HandleError("insertVulnerability", err) } defer stmt.Close() for _, vuln := range vulnerabilities { err := stmt.QueryRow(vuln.Name, vuln.Description, vuln.Link, &vuln.Severity, &vuln.Metadata, vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID) if err != nil { return nil, util.HandleError("insertVulnerability", err) } vulnIDs = append(vulnIDs, vulnID) } return vulnIDs, nil } func LockFeatureVulnerabilityCache(tx *sql.Tx) error { _, err := tx.Exec(lockVulnerabilityAffects) if err != nil { return util.HandleError("lockVulnerabilityAffects", err) } return nil } // cacheVulnerabiltyAffectedNamespacedFeature takes in a map of vulnerability ID // to affected feature rows and caches them. func CacheVulnerabiltyAffectedNamespacedFeature(tx *sql.Tx, affected map[int64]affectedFeatureRows) error { // Prevent InsertNamespacedFeatures to modify it. err := LockFeatureVulnerabilityCache(tx) if err != nil { return err } vulnIDs := []int64{} for id := range affected { vulnIDs = append(vulnIDs, id) } rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs)) if err != nil { return util.HandleError("searchVulnerabilityPotentialAffected", err) } defer rows.Close() relation := []affectRelation{} for rows.Next() { var ( vulnID int64 nsfID int64 fVersion string addedBy int64 ) err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy) if err != nil { return util.HandleError("searchVulnerabilityPotentialAffected", err) } candidate, ok := affected[vulnID].rows[addedBy] if !ok { return errors.New("vulnerability affected feature not found") } if in, err := versionfmt.InRange(candidate.Namespace.VersionFormat, fVersion, candidate.AffectedVersion); err == nil { if in { relation = append(relation, affectRelation{ vulnerabilityID: vulnID, namespacedFeatureID: nsfID, addedBy: addedBy, }) } } else { return err } } //TODO(Sida): Change to bulk insert. for _, r := range relation { result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy) if err != nil { return util.HandleError("insertVulnerabilityAffectedNamespacedFeature", err) } if num, err := result.RowsAffected(); err == nil { if num <= 0 { return errors.New("Nothing cached in database") } } else { return err } } log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", len(relation)) return nil } func DeleteVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) error { defer monitoring.ObserveQueryTime("DeleteVulnerability", "all", time.Now()) vulnIDs, err := MarkVulnerabilitiesAsDeleted(tx, vulnerabilities) if err != nil { return err } if err := InvalidateVulnerabilityCache(tx, vulnIDs); err != nil { return err } return nil } func InvalidateVulnerabilityCache(tx *sql.Tx, vulnerabilityIDs []int64) error { if len(vulnerabilityIDs) == 0 { return nil } // Prevent InsertNamespacedFeatures to modify it. err := LockFeatureVulnerabilityCache(tx) if err != nil { return err } //TODO(Sida): Make a nicer interface for bulk inserting. keys := make([]interface{}, len(vulnerabilityIDs)) for i, id := range vulnerabilityIDs { keys[i] = id } _, err = tx.Exec(queryInvalidateVulnerabilityCache(len(vulnerabilityIDs)), keys...) if err != nil { return util.HandleError("removeVulnerabilityAffectedFeature", err) } return nil } func MarkVulnerabilitiesAsDeleted(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]int64, error) { var ( vulnID sql.NullInt64 vulnIDs []int64 ) // mark vulnerabilities deleted stmt, err := tx.Prepare(removeVulnerability) if err != nil { return nil, util.HandleError("removeVulnerability", err) } defer stmt.Close() for _, vuln := range vulnerabilities { err := stmt.QueryRow(vuln.Namespace, vuln.Name).Scan(&vulnID) if err != nil { return nil, util.HandleError("removeVulnerability", err) } if !vulnID.Valid { return nil, util.HandleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database")) } vulnIDs = append(vulnIDs, vulnID.Int64) } return vulnIDs, nil } // findLatestDeletedVulnerabilityIDs requires all elements in vulnIDs are in // database and the order of output array is not guaranteed. func FindLatestDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { return FindVulnerabilityIDs(tx, vulnIDs, true) } func FindNotDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { return FindVulnerabilityIDs(tx, vulnIDs, false) } func FindVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) { if len(vulnIDs) == 0 { return nil, nil } vulnIDMap := map[database.VulnerabilityID]sql.NullInt64{} keys := make([]interface{}, len(vulnIDs)*2) for i, vulnID := range vulnIDs { keys[i*2] = vulnID.Name keys[i*2+1] = vulnID.Namespace vulnIDMap[vulnID] = sql.NullInt64{} } query := "" if withLatestDeleted { query = querySearchLastDeletedVulnerabilityID(len(vulnIDs)) } else { query = querySearchNotDeletedVulnerabilityID(len(vulnIDs)) } rows, err := tx.Query(query, keys...) if err != nil { return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Query", err) } defer rows.Close() var ( id sql.NullInt64 vulnID database.VulnerabilityID ) for rows.Next() { err := rows.Scan(&id, &vulnID.Name, &vulnID.Namespace) if err != nil { return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Scan", err) } vulnIDMap[vulnID] = id } ids := make([]sql.NullInt64, len(vulnIDs)) for i, v := range vulnIDs { ids[i] = vulnIDMap[v] } return ids, nil } func FindPagedVulnerableAncestries(tx *sql.Tx, vulnID int64, limit int, currentToken pagination.Token, key pagination.Key) (database.PagedVulnerableAncestries, error) { vulnPage := database.PagedVulnerableAncestries{Limit: limit} currentPage := page.Page{0} if currentToken != pagination.FirstPageToken { if err := key.UnmarshalToken(currentToken, ¤tPage); err != nil { return vulnPage, err } } if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan( &vulnPage.Name, &vulnPage.Description, &vulnPage.Link, &vulnPage.Severity, &vulnPage.Metadata, &vulnPage.Namespace.Name, &vulnPage.Namespace.VersionFormat, ); err != nil { return vulnPage, util.HandleError("searchVulnerabilityByID", err) } // the last result is used for the next page's startID rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1) if err != nil { return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err) } defer rows.Close() ancestries := []affectedAncestry{} for rows.Next() { var ancestry affectedAncestry err := rows.Scan(&ancestry.id, &ancestry.name) if err != nil { return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err) } ancestries = append(ancestries, ancestry) } lastIndex := 0 if len(ancestries)-1 < limit { lastIndex = len(ancestries) vulnPage.End = true } else { // Use the last ancestry's ID as the next page. lastIndex = len(ancestries) - 1 vulnPage.Next, err = key.MarshalToken(page.Page{ancestries[len(ancestries)-1].id}) if err != nil { return vulnPage, err } } vulnPage.Affected = map[int]string{} for _, ancestry := range ancestries[0:lastIndex] { vulnPage.Affected[int(ancestry.id)] = ancestry.name } vulnPage.Current, err = key.MarshalToken(currentPage) if err != nil { return vulnPage, err } return vulnPage, nil }