439 lines
12 KiB
Go
439 lines
12 KiB
Go
// Copyright 2017 clair authors
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package pgsql
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"time"
|
|
|
|
"github.com/lib/pq"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/coreos/clair/database"
|
|
"github.com/coreos/clair/ext/versionfmt"
|
|
)
|
|
|
|
// errVulnerabilityNotFound is a sentinel error whose message states that a
// vulnerability is not in the database.
var errVulnerabilityNotFound = errors.New("vulnerability is not in database")
|
|
|
|
// affectedAncestry pairs an ancestry name with an int64 id.
// NOTE(review): neither field is used in this section; presumably id is the
// ancestry's database row ID — confirm against its users elsewhere in the
// package.
type affectedAncestry struct {
	name string
	id   int64
}
|
|
|
|
// affectRelation is one entry destined for the vulnerability/namespaced-feature
// cache. It records that vulnerability vulnerabilityID affects the namespaced
// feature namespacedFeatureID. addedBy is the ID of the affected-feature row
// whose version range produced the match.
type affectRelation struct {
	vulnerabilityID, namespacedFeatureID, addedBy int64
}
|
|
|
|
// affectedFeatureRows holds the affected features inserted for a single
// vulnerability. Each feature is keyed by the row ID returned when it was
// inserted.
type affectedFeatureRows struct {
	rows map[int64]database.AffectedFeature
}
|
|
|
|
func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) {
|
|
defer observeQueryTime("findVulnerabilities", "", time.Now())
|
|
resultVuln := make([]database.NullableVulnerability, len(vulnerabilities))
|
|
vulnIDMap := map[int64][]*database.NullableVulnerability{}
|
|
|
|
//TODO(Sida): Change to bulk search.
|
|
stmt, err := tx.Prepare(searchVulnerability)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// load vulnerabilities
|
|
for i, key := range vulnerabilities {
|
|
var (
|
|
id sql.NullInt64
|
|
vuln = database.NullableVulnerability{
|
|
VulnerabilityWithAffected: database.VulnerabilityWithAffected{
|
|
Vulnerability: database.Vulnerability{
|
|
Name: key.Name,
|
|
Namespace: database.Namespace{
|
|
Name: key.Namespace,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
)
|
|
|
|
err := stmt.QueryRow(key.Name, key.Namespace).Scan(
|
|
&id,
|
|
&vuln.Description,
|
|
&vuln.Link,
|
|
&vuln.Severity,
|
|
&vuln.Metadata,
|
|
&vuln.Namespace.VersionFormat,
|
|
)
|
|
|
|
if err != nil && err != sql.ErrNoRows {
|
|
stmt.Close()
|
|
return nil, handleError("searchVulnerability", err)
|
|
}
|
|
vuln.Valid = id.Valid
|
|
resultVuln[i] = vuln
|
|
if id.Valid {
|
|
vulnIDMap[id.Int64] = append(vulnIDMap[id.Int64], &resultVuln[i])
|
|
}
|
|
}
|
|
|
|
if err := stmt.Close(); err != nil {
|
|
return nil, handleError("searchVulnerability", err)
|
|
}
|
|
|
|
toQuery := make([]int64, 0, len(vulnIDMap))
|
|
for id := range vulnIDMap {
|
|
toQuery = append(toQuery, id)
|
|
}
|
|
|
|
// load vulnerability affected features
|
|
rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery))
|
|
if err != nil {
|
|
return nil, handleError("searchVulnerabilityAffected", err)
|
|
}
|
|
|
|
for rows.Next() {
|
|
var (
|
|
id int64
|
|
f database.AffectedFeature
|
|
)
|
|
|
|
err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FixedInVersion)
|
|
if err != nil {
|
|
return nil, handleError("searchVulnerabilityAffected", err)
|
|
}
|
|
|
|
for _, vuln := range vulnIDMap[id] {
|
|
f.Namespace = vuln.Namespace
|
|
vuln.Affected = append(vuln.Affected, f)
|
|
}
|
|
}
|
|
|
|
return resultVuln, nil
|
|
}
|
|
|
|
func (tx *pgSession) InsertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) error {
|
|
defer observeQueryTime("insertVulnerabilities", "all", time.Now())
|
|
// bulk insert vulnerabilities
|
|
vulnIDs, err := tx.insertVulnerabilities(vulnerabilities)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// bulk insert vulnerability affected features
|
|
vulnFeatureMap, err := tx.insertVulnerabilityAffected(vulnIDs, vulnerabilities)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return tx.cacheVulnerabiltyAffectedNamespacedFeature(vulnFeatureMap)
|
|
}
|
|
|
|
// insertVulnerabilityAffected inserts a set of vulnerability affected features for each vulnerability provided.
|
|
//
|
|
// i_th vulnerabilityIDs corresponds to i_th vulnerabilities provided.
|
|
func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) {
|
|
var (
|
|
vulnFeature = map[int64]affectedFeatureRows{}
|
|
affectedID int64
|
|
)
|
|
|
|
//TODO(Sida): Change to bulk insert.
|
|
stmt, err := tx.Prepare(insertVulnerabilityAffected)
|
|
if err != nil {
|
|
return nil, handleError("insertVulnerabilityAffected", err)
|
|
}
|
|
|
|
defer stmt.Close()
|
|
for i, vuln := range vulnerabilities {
|
|
// affected feature row ID -> affected feature
|
|
affectedFeatures := map[int64]database.AffectedFeature{}
|
|
for _, f := range vuln.Affected {
|
|
err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, f.FixedInVersion).Scan(&affectedID)
|
|
if err != nil {
|
|
return nil, handleError("insertVulnerabilityAffected", err)
|
|
}
|
|
affectedFeatures[affectedID] = f
|
|
}
|
|
vulnFeature[vulnerabilityIDs[i]] = affectedFeatureRows{rows: affectedFeatures}
|
|
}
|
|
|
|
return vulnFeature, nil
|
|
}
|
|
|
|
// insertVulnerabilities inserts a set of unique vulnerabilities into database,
|
|
// under the assumption that all vulnerabilities are valid.
|
|
func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) {
|
|
var (
|
|
vulnID int64
|
|
vulnIDs = make([]int64, 0, len(vulnerabilities))
|
|
vulnMap = map[database.VulnerabilityID]struct{}{}
|
|
)
|
|
|
|
for _, v := range vulnerabilities {
|
|
key := database.VulnerabilityID{
|
|
Name: v.Name,
|
|
Namespace: v.Namespace.Name,
|
|
}
|
|
|
|
// Ensure uniqueness of vulnerability IDs
|
|
if _, ok := vulnMap[key]; ok {
|
|
return nil, errors.New("inserting duplicated vulnerabilities is not allowed")
|
|
}
|
|
vulnMap[key] = struct{}{}
|
|
}
|
|
|
|
//TODO(Sida): Change to bulk insert.
|
|
stmt, err := tx.Prepare(insertVulnerability)
|
|
if err != nil {
|
|
return nil, handleError("insertVulnerability", err)
|
|
}
|
|
|
|
defer stmt.Close()
|
|
for _, vuln := range vulnerabilities {
|
|
err := stmt.QueryRow(vuln.Name, vuln.Description,
|
|
vuln.Link, &vuln.Severity, &vuln.Metadata,
|
|
vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID)
|
|
if err != nil {
|
|
return nil, handleError("insertVulnerability", err)
|
|
}
|
|
|
|
vulnIDs = append(vulnIDs, vulnID)
|
|
}
|
|
|
|
return vulnIDs, nil
|
|
}
|
|
|
|
// castMetadata marshals the given database.MetadataMap and unmarshals it again to make sure that
|
|
// everything has the interface{} type.
|
|
// It is required when comparing crafted MetadataMap against MetadataMap that we get from the
|
|
// database.
|
|
func castMetadata(m database.MetadataMap) database.MetadataMap {
|
|
c := make(database.MetadataMap)
|
|
j, _ := json.Marshal(m)
|
|
json.Unmarshal(j, &c)
|
|
return c
|
|
}
|
|
|
|
func (tx *pgSession) lockFeatureVulnerabilityCache() error {
|
|
_, err := tx.Exec(lockVulnerabilityAffects)
|
|
if err != nil {
|
|
return handleError("lockVulnerabilityAffects", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// cacheVulnerabiltyAffectedNamespacedFeature takes in a map of vulnerability ID
|
|
// to affected feature rows and caches them.
|
|
func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int64]affectedFeatureRows) error {
|
|
// Prevent InsertNamespacedFeatures to modify it.
|
|
err := tx.lockFeatureVulnerabilityCache()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
vulnIDs := []int64{}
|
|
for id := range affected {
|
|
vulnIDs = append(vulnIDs, id)
|
|
}
|
|
|
|
rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs))
|
|
if err != nil {
|
|
return handleError("searchVulnerabilityPotentialAffected", err)
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
relation := []affectRelation{}
|
|
for rows.Next() {
|
|
var (
|
|
vulnID int64
|
|
nsfID int64
|
|
fVersion string
|
|
addedBy int64
|
|
)
|
|
|
|
err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy)
|
|
if err != nil {
|
|
return handleError("searchVulnerabilityPotentialAffected", err)
|
|
}
|
|
|
|
candidate, ok := affected[vulnID].rows[addedBy]
|
|
|
|
if !ok {
|
|
return errors.New("vulnerability affected feature not found")
|
|
}
|
|
|
|
if in, err := versionfmt.InRange(candidate.Namespace.VersionFormat,
|
|
fVersion,
|
|
candidate.AffectedVersion); err == nil {
|
|
if in {
|
|
relation = append(relation,
|
|
affectRelation{
|
|
vulnerabilityID: vulnID,
|
|
namespacedFeatureID: nsfID,
|
|
addedBy: addedBy,
|
|
})
|
|
}
|
|
} else {
|
|
return err
|
|
}
|
|
}
|
|
|
|
//TODO(Sida): Change to bulk insert.
|
|
for _, r := range relation {
|
|
result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy)
|
|
if err != nil {
|
|
return handleError("insertVulnerabilityAffectedNamespacedFeature", err)
|
|
}
|
|
|
|
if num, err := result.RowsAffected(); err == nil {
|
|
if num <= 0 {
|
|
return errors.New("Nothing cached in database")
|
|
}
|
|
} else {
|
|
return err
|
|
}
|
|
}
|
|
|
|
log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", len(relation))
|
|
return nil
|
|
}
|
|
|
|
func (tx *pgSession) DeleteVulnerabilities(vulnerabilities []database.VulnerabilityID) error {
|
|
defer observeQueryTime("DeleteVulnerability", "all", time.Now())
|
|
|
|
vulnIDs, err := tx.markVulnerabilitiesAsDeleted(vulnerabilities)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := tx.invalidateVulnerabilityCache(vulnIDs); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) error {
|
|
if len(vulnerabilityIDs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Prevent InsertNamespacedFeatures to modify it.
|
|
err := tx.lockFeatureVulnerabilityCache()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
//TODO(Sida): Make a nicer interface for bulk inserting.
|
|
keys := make([]interface{}, len(vulnerabilityIDs))
|
|
for i, id := range vulnerabilityIDs {
|
|
keys[i] = id
|
|
}
|
|
|
|
_, err = tx.Exec(queryInvalidateVulnerabilityCache(len(vulnerabilityIDs)), keys...)
|
|
if err != nil {
|
|
return handleError("removeVulnerabilityAffectedFeature", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.VulnerabilityID) ([]int64, error) {
|
|
var (
|
|
vulnID sql.NullInt64
|
|
vulnIDs []int64
|
|
)
|
|
|
|
// mark vulnerabilities deleted
|
|
stmt, err := tx.Prepare(removeVulnerability)
|
|
if err != nil {
|
|
return nil, handleError("removeVulnerability", err)
|
|
}
|
|
|
|
defer stmt.Close()
|
|
for _, vuln := range vulnerabilities {
|
|
err := stmt.QueryRow(vuln.Namespace, vuln.Name).Scan(&vulnID)
|
|
if err != nil {
|
|
return nil, handleError("removeVulnerability", err)
|
|
}
|
|
if !vulnID.Valid {
|
|
return nil, handleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database"))
|
|
}
|
|
vulnIDs = append(vulnIDs, vulnID.Int64)
|
|
}
|
|
return vulnIDs, nil
|
|
}
|
|
|
|
// findLatestDeletedVulnerabilityIDs requires all elements in vulnIDs are in
|
|
// database and the order of output array is not guaranteed.
|
|
func (tx *pgSession) findLatestDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
|
|
return tx.findVulnerabilityIDs(vulnIDs, true)
|
|
}
|
|
|
|
func (tx *pgSession) findNotDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
|
|
return tx.findVulnerabilityIDs(vulnIDs, false)
|
|
}
|
|
|
|
func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) {
|
|
if len(vulnIDs) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
vulnIDMap := map[database.VulnerabilityID]sql.NullInt64{}
|
|
keys := make([]interface{}, len(vulnIDs)*2)
|
|
for i, vulnID := range vulnIDs {
|
|
keys[i*2] = vulnID.Name
|
|
keys[i*2+1] = vulnID.Namespace
|
|
vulnIDMap[vulnID] = sql.NullInt64{}
|
|
}
|
|
|
|
query := ""
|
|
if withLatestDeleted {
|
|
query = querySearchLastDeletedVulnerabilityID(len(vulnIDs))
|
|
} else {
|
|
query = querySearchNotDeletedVulnerabilityID(len(vulnIDs))
|
|
}
|
|
|
|
rows, err := tx.Query(query, keys...)
|
|
if err != nil {
|
|
return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Query", err)
|
|
}
|
|
|
|
defer rows.Close()
|
|
var (
|
|
id sql.NullInt64
|
|
vulnID database.VulnerabilityID
|
|
)
|
|
for rows.Next() {
|
|
err := rows.Scan(&id, &vulnID.Name, &vulnID.Namespace)
|
|
if err != nil {
|
|
return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Scan", err)
|
|
}
|
|
vulnIDMap[vulnID] = id
|
|
}
|
|
|
|
ids := make([]sql.NullInt64, len(vulnIDs))
|
|
for i, v := range vulnIDs {
|
|
ids[i] = vulnIDMap[v]
|
|
}
|
|
|
|
return ids, nil
|
|
}
|