// Copyright 2017 clair authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package pgsql import ( "database/sql" "encoding/json" "errors" "time" "github.com/lib/pq" log "github.com/sirupsen/logrus" "github.com/coreos/clair/database" "github.com/coreos/clair/ext/versionfmt" ) var ( errVulnerabilityNotFound = errors.New("vulnerability is not in database") ) type affectedAncestry struct { name string id int64 } type affectRelation struct { vulnerabilityID int64 namespacedFeatureID int64 addedBy int64 } type affectedFeatureRows struct { rows map[int64]database.AffectedFeature } func (tx *pgSession) ListVulnerabilities() ([]database.NullableVulnerability, error) { defer observeQueryTime("listVulnerabilities", "", time.Now()) vulnIDMap := map[int64][]*database.NullableVulnerability{} stmt, err := tx.Prepare(listVulnerabilities) if err != nil { return nil, err } rows, err := stmt.Query() if err != nil && err != sql.ErrNoRows { stmt.Close() return nil, handleError("listVulnerabilities", err) } defer rows.Close() // load vulnerabilities for rows.Next() { var ( id sql.NullInt64 vuln = database.NullableVulnerability{} ) err := rows.Scan( &id, &vuln.Name, &vuln.Description, &vuln.Link, &vuln.Severity, &vuln.Metadata, &vuln.Namespace.Name, &vuln.Namespace.VersionFormat, ) if err != nil && err != sql.ErrNoRows { stmt.Close() return nil, handleError("searchVulnerability", err) } vuln.Valid = id.Valid if id.Valid { vulnIDMap[id.Int64] = append(vulnIDMap[id.Int64], &vuln) } } if err := stmt.Close(); err != nil { return nil, handleError("listVulnerabilities", err) } toQuery := make([]int64, 0, len(vulnIDMap)) for id := range vulnIDMap { toQuery = append(toQuery, id) } // load vulnerability affected features rows, err = tx.Query(searchVulnerabilityAffected, pq.Array(toQuery)) if err != nil { return nil, handleError("searchVulnerabilityAffected", err) } for rows.Next() { var ( id int64 f database.AffectedFeature ) err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FixedInVersion) if err != nil { return nil, handleError("searchVulnerabilityAffected", err) } for _, vuln := range vulnIDMap[id] { f.Namespace = vuln.Namespace vuln.Affected = append(vuln.Affected, f) } } var resultVuln []database.NullableVulnerability for _, vulns := range vulnIDMap { for _, vuln := range vulns { resultVuln = append(resultVuln, *vuln) } } return resultVuln, nil } func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) { defer observeQueryTime("findVulnerabilities", "", time.Now()) resultVuln := make([]database.NullableVulnerability, len(vulnerabilities)) vulnIDMap := map[int64][]*database.NullableVulnerability{} //TODO(Sida): Change to bulk search. stmt, err := tx.Prepare(searchVulnerability) if err != nil { return nil, err } // load vulnerabilities for i, key := range vulnerabilities { var ( id sql.NullInt64 vuln = database.NullableVulnerability{ VulnerabilityWithAffected: database.VulnerabilityWithAffected{ Vulnerability: database.Vulnerability{ Name: key.Name, Namespace: database.Namespace{ Name: key.Namespace, }, }, }, } ) err := stmt.QueryRow(key.Name, key.Namespace).Scan( &id, &vuln.Description, &vuln.Link, &vuln.Severity, &vuln.Metadata, &vuln.Namespace.VersionFormat, ) if err != nil && err != sql.ErrNoRows { stmt.Close() return nil, handleError("searchVulnerability", err) } vuln.Valid = id.Valid resultVuln[i] = vuln if id.Valid { vulnIDMap[id.Int64] = append(vulnIDMap[id.Int64], &resultVuln[i]) } } if err := stmt.Close(); err != nil { return nil, handleError("searchVulnerability", err) } toQuery := make([]int64, 0, len(vulnIDMap)) for id := range vulnIDMap { toQuery = append(toQuery, id) } // load vulnerability affected features rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery)) if err != nil { return nil, handleError("searchVulnerabilityAffected", err) } for rows.Next() { var ( id int64 f database.AffectedFeature ) err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FixedInVersion) if err != nil { return nil, handleError("searchVulnerabilityAffected", err) } for _, vuln := range vulnIDMap[id] { f.Namespace = vuln.Namespace vuln.Affected = append(vuln.Affected, f) } } return resultVuln, nil } func (tx *pgSession) InsertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) error { defer observeQueryTime("insertVulnerabilities", "all", time.Now()) // bulk insert vulnerabilities vulnIDs, err := tx.insertVulnerabilities(vulnerabilities) if err != nil { return err } // bulk insert vulnerability affected features vulnFeatureMap, err := tx.insertVulnerabilityAffected(vulnIDs, vulnerabilities) if err != nil { return err } return tx.cacheVulnerabiltyAffectedNamespacedFeature(vulnFeatureMap) } // insertVulnerabilityAffected inserts a set of vulnerability affected features for each vulnerability provided. // // i_th vulnerabilityIDs corresponds to i_th vulnerabilities provided. func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) { var ( vulnFeature = map[int64]affectedFeatureRows{} affectedID int64 ) //TODO(Sida): Change to bulk insert. stmt, err := tx.Prepare(insertVulnerabilityAffected) if err != nil { return nil, handleError("insertVulnerabilityAffected", err) } defer stmt.Close() for i, vuln := range vulnerabilities { // affected feature row ID -> affected feature affectedFeatures := map[int64]database.AffectedFeature{} for _, f := range vuln.Affected { err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, f.FixedInVersion).Scan(&affectedID) if err != nil { return nil, handleError("insertVulnerabilityAffected", err) } affectedFeatures[affectedID] = f } vulnFeature[vulnerabilityIDs[i]] = affectedFeatureRows{rows: affectedFeatures} } return vulnFeature, nil } // insertVulnerabilities inserts a set of unique vulnerabilities into database, // under the assumption that all vulnerabilities are valid. func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) { var ( vulnID int64 vulnIDs = make([]int64, 0, len(vulnerabilities)) vulnMap = map[database.VulnerabilityID]struct{}{} ) for _, v := range vulnerabilities { key := database.VulnerabilityID{ Name: v.Name, Namespace: v.Namespace.Name, } // Ensure uniqueness of vulnerability IDs if _, ok := vulnMap[key]; ok { return nil, errors.New("inserting duplicated vulnerabilities is not allowed") } vulnMap[key] = struct{}{} } //TODO(Sida): Change to bulk insert. stmt, err := tx.Prepare(insertVulnerability) if err != nil { return nil, handleError("insertVulnerability", err) } defer stmt.Close() for _, vuln := range vulnerabilities { err := stmt.QueryRow(vuln.Name, vuln.Description, vuln.Link, &vuln.Severity, &vuln.Metadata, vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID) if err != nil { return nil, handleError("insertVulnerability", err) } vulnIDs = append(vulnIDs, vulnID) } return vulnIDs, nil } // castMetadata marshals the given database.MetadataMap and unmarshals it again to make sure that // everything has the interface{} type. // It is required when comparing crafted MetadataMap against MetadataMap that we get from the // database. func castMetadata(m database.MetadataMap) database.MetadataMap { c := make(database.MetadataMap) j, _ := json.Marshal(m) json.Unmarshal(j, &c) return c } func (tx *pgSession) lockFeatureVulnerabilityCache() error { _, err := tx.Exec(lockVulnerabilityAffects) if err != nil { return handleError("lockVulnerabilityAffects", err) } return nil } // cacheVulnerabiltyAffectedNamespacedFeature takes in a map of vulnerability ID // to affected feature rows and caches them. func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int64]affectedFeatureRows) error { // Prevent InsertNamespacedFeatures to modify it. err := tx.lockFeatureVulnerabilityCache() if err != nil { return err } vulnIDs := []int64{} for id := range affected { vulnIDs = append(vulnIDs, id) } rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs)) if err != nil { return handleError("searchVulnerabilityPotentialAffected", err) } defer rows.Close() relation := []affectRelation{} for rows.Next() { var ( vulnID int64 nsfID int64 fVersion string addedBy int64 ) err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy) if err != nil { return handleError("searchVulnerabilityPotentialAffected", err) } candidate, ok := affected[vulnID].rows[addedBy] if !ok { return errors.New("vulnerability affected feature not found") } if in, err := versionfmt.InRange(candidate.Namespace.VersionFormat, fVersion, candidate.AffectedVersion); err == nil { if in { relation = append(relation, affectRelation{ vulnerabilityID: vulnID, namespacedFeatureID: nsfID, addedBy: addedBy, }) } } else { return err } } //TODO(Sida): Change to bulk insert. for _, r := range relation { result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy) if err != nil { return handleError("insertVulnerabilityAffectedNamespacedFeature", err) } if num, err := result.RowsAffected(); err == nil { if num <= 0 { return errors.New("Nothing cached in database") } } else { return err } } log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", len(relation)) return nil } func (tx *pgSession) DeleteVulnerabilities(vulnerabilities []database.VulnerabilityID) error { defer observeQueryTime("DeleteVulnerability", "all", time.Now()) vulnIDs, err := tx.markVulnerabilitiesAsDeleted(vulnerabilities) if err != nil { return err } if err := tx.invalidateVulnerabilityCache(vulnIDs); err != nil { return err } return nil } func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) error { if len(vulnerabilityIDs) == 0 { return nil } // Prevent InsertNamespacedFeatures to modify it. err := tx.lockFeatureVulnerabilityCache() if err != nil { return err } //TODO(Sida): Make a nicer interface for bulk inserting. keys := make([]interface{}, len(vulnerabilityIDs)) for i, id := range vulnerabilityIDs { keys[i] = id } _, err = tx.Exec(queryInvalidateVulnerabilityCache(len(vulnerabilityIDs)), keys...) if err != nil { return handleError("removeVulnerabilityAffectedFeature", err) } return nil } func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.VulnerabilityID) ([]int64, error) { var ( vulnID sql.NullInt64 vulnIDs []int64 ) // mark vulnerabilities deleted stmt, err := tx.Prepare(removeVulnerability) if err != nil { return nil, handleError("removeVulnerability", err) } defer stmt.Close() for _, vuln := range vulnerabilities { err := stmt.QueryRow(vuln.Namespace, vuln.Name).Scan(&vulnID) if err != nil { return nil, handleError("removeVulnerability", err) } if !vulnID.Valid { return nil, handleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database")) } vulnIDs = append(vulnIDs, vulnID.Int64) } return vulnIDs, nil } // findLatestDeletedVulnerabilityIDs requires all elements in vulnIDs are in // database and the order of output array is not guaranteed. func (tx *pgSession) findLatestDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { return tx.findVulnerabilityIDs(vulnIDs, true) } func (tx *pgSession) findNotDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) { return tx.findVulnerabilityIDs(vulnIDs, false) } func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) { if len(vulnIDs) == 0 { return nil, nil } vulnIDMap := map[database.VulnerabilityID]sql.NullInt64{} keys := make([]interface{}, len(vulnIDs)*2) for i, vulnID := range vulnIDs { keys[i*2] = vulnID.Name keys[i*2+1] = vulnID.Namespace vulnIDMap[vulnID] = sql.NullInt64{} } query := "" if withLatestDeleted { query = querySearchLastDeletedVulnerabilityID(len(vulnIDs)) } else { query = querySearchNotDeletedVulnerabilityID(len(vulnIDs)) } rows, err := tx.Query(query, keys...) if err != nil { return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Query", err) } defer rows.Close() var ( id sql.NullInt64 vulnID database.VulnerabilityID ) for rows.Next() { err := rows.Scan(&id, &vulnID.Name, &vulnID.Namespace) if err != nil { return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Scan", err) } vulnIDMap[vulnID] = id } ids := make([]sql.NullInt64, len(vulnIDs)) for i, v := range vulnIDs { ids[i] = vulnIDMap[v] } return ids, nil }