578 lines
16 KiB
Go
578 lines
16 KiB
Go
// Copyright 2017 clair authors
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package vulnerability
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/lib/pq"
|
|
log "github.com/sirupsen/logrus"
|
|
|
|
"github.com/coreos/clair/database"
|
|
"github.com/coreos/clair/database/pgsql/feature"
|
|
"github.com/coreos/clair/database/pgsql/monitoring"
|
|
"github.com/coreos/clair/database/pgsql/page"
|
|
"github.com/coreos/clair/database/pgsql/util"
|
|
"github.com/coreos/clair/ext/versionfmt"
|
|
"github.com/coreos/clair/pkg/pagination"
|
|
)
|
|
|
|
// searchVulnerability finds a live (deleted_at IS NULL) vulnerability by its
// name ($1) and namespace name ($2).
const searchVulnerability = `
	SELECT v.id, v.description, v.link, v.severity, v.metadata, n.version_format
	FROM vulnerability AS v, namespace AS n
	WHERE v.namespace_id = n.id
		AND v.name = $1
		AND n.name = $2
		AND v.deleted_at IS NULL
		`

// searchVulnerabilityByID loads a vulnerability and its namespace by row ID
// ($1). It does not filter on deleted_at, so soft-deleted rows are returned too.
const searchVulnerabilityByID = `
	SELECT v.name, v.description, v.link, v.severity, v.metadata, n.name, n.version_format
	FROM vulnerability AS v, namespace AS n
	WHERE v.namespace_id = n.id
		AND v.id = $1`

// insertVulnerability inserts one vulnerability and returns its new ID.
// Parameters: $1 name, $2 description, $3 link, $4 severity, $5 metadata,
// $6 namespace name, $7 namespace version format. The namespace ID is looked
// up from ($6, $7); when no namespace matches, the subquery yields NULL.
const insertVulnerability = `
	WITH ns AS (
		SELECT id FROM namespace WHERE name = $6 AND version_format = $7
	)
	INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at)
	VALUES((SELECT id FROM ns), $1, $2, $3, $4, $5, CURRENT_TIMESTAMP)
	RETURNING id`

// removeVulnerability soft-deletes the live vulnerability named $2 in the
// namespace named $1 by stamping deleted_at, and returns its ID.
const removeVulnerability = `
	UPDATE Vulnerability
	SET deleted_at = CURRENT_TIMESTAMP
	WHERE namespace_id = (SELECT id FROM Namespace WHERE name = $1)
		AND name = $2
		AND deleted_at IS NULL
	RETURNING id`

// searchNotificationVulnerableAncestry lists distinct ancestries (id, name)
// containing a feature affected by vulnerability $1, starting at ancestry ID
// $2 in ascending ID order, returning at most $3 rows.
const searchNotificationVulnerableAncestry = `
	SELECT DISTINCT ON (a.id)
		a.id, a.name
	FROM vulnerability_affected_namespaced_feature AS vanf,
		ancestry_layer AS al, ancestry_feature AS af, ancestry AS a
	WHERE vanf.vulnerability_id = $1
		AND a.id >= $2
		AND al.ancestry_id = a.id
		AND al.id = af.ancestry_layer_id
		AND af.namespaced_feature_id = vanf.namespaced_feature_id
	ORDER BY a.id ASC
	LIMIT $3;`
|
|
|
|
func queryInvalidateVulnerabilityCache(count int) string {
|
|
return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature
|
|
WHERE vulnerability_id IN (%s)`,
|
|
util.QueryString(1, count))
|
|
}
|
|
|
|
// NOTE(Sida): Every search query can only have count less than postgres set
|
|
// stack depth. IN will be resolved to nested OR_s and the parser might exceed
|
|
// stack depth. TODO(Sida): Generate different queries for different count: if
|
|
// count < 5120, use IN; for count > 5120 and < 65536, use temporary table; for
|
|
// count > 65535, use is expected to split data into batches.
|
|
func querySearchLastDeletedVulnerabilityID(count int) string {
|
|
return fmt.Sprintf(`
|
|
SELECT vid, vname, nname FROM (
|
|
SELECT v.id AS vid, v.name AS vname, n.name AS nname,
|
|
row_number() OVER (
|
|
PARTITION by (v.name, n.name)
|
|
ORDER BY v.deleted_at DESC
|
|
) AS rownum
|
|
FROM vulnerability AS v, namespace AS n
|
|
WHERE v.namespace_id = n.id
|
|
AND (v.name, n.name) IN ( %s )
|
|
AND v.deleted_at IS NOT NULL
|
|
) tmp WHERE rownum <= 1`,
|
|
util.QueryString(2, count))
|
|
}
|
|
|
|
func querySearchNotDeletedVulnerabilityID(count int) string {
|
|
return fmt.Sprintf(`
|
|
SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n
|
|
WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s)
|
|
AND v.deleted_at IS NULL`,
|
|
util.QueryString(2, count))
|
|
}
|
|
|
|
// affectedAncestry is one row of searchNotificationVulnerableAncestry: an
// ancestry affected by a vulnerability.
type affectedAncestry struct {
	name string // ancestry name
	id   int64  // ancestry row ID, also used as the pagination cursor
}
|
|
|
|
// affectRelation is a (vulnerability, namespaced feature) pair to be cached,
// together with the ID of the affected-feature row that produced the match.
type affectRelation struct {
	vulnerabilityID, namespacedFeatureID, addedBy int64
}
|
|
|
|
type affectedFeatureRows struct {
|
|
rows map[int64]database.AffectedFeature
|
|
}
|
|
|
|
func FindVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) {
|
|
defer monitoring.ObserveQueryTime("findVulnerabilities", "", time.Now())
|
|
resultVuln := make([]database.NullableVulnerability, len(vulnerabilities))
|
|
vulnIDMap := map[int64][]*database.NullableVulnerability{}
|
|
|
|
//TODO(Sida): Change to bulk search.
|
|
stmt, err := tx.Prepare(searchVulnerability)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// load vulnerabilities
|
|
for i, key := range vulnerabilities {
|
|
var (
|
|
id sql.NullInt64
|
|
vuln = database.NullableVulnerability{
|
|
VulnerabilityWithAffected: database.VulnerabilityWithAffected{
|
|
Vulnerability: database.Vulnerability{
|
|
Name: key.Name,
|
|
Namespace: database.Namespace{
|
|
Name: key.Namespace,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
)
|
|
|
|
err := stmt.QueryRow(key.Name, key.Namespace).Scan(
|
|
&id,
|
|
&vuln.Description,
|
|
&vuln.Link,
|
|
&vuln.Severity,
|
|
&vuln.Metadata,
|
|
&vuln.Namespace.VersionFormat,
|
|
)
|
|
|
|
if err != nil && err != sql.ErrNoRows {
|
|
stmt.Close()
|
|
return nil, util.HandleError("searchVulnerability", err)
|
|
}
|
|
vuln.Valid = id.Valid
|
|
resultVuln[i] = vuln
|
|
if id.Valid {
|
|
vulnIDMap[id.Int64] = append(vulnIDMap[id.Int64], &resultVuln[i])
|
|
}
|
|
}
|
|
|
|
if err := stmt.Close(); err != nil {
|
|
return nil, util.HandleError("searchVulnerability", err)
|
|
}
|
|
|
|
toQuery := make([]int64, 0, len(vulnIDMap))
|
|
for id := range vulnIDMap {
|
|
toQuery = append(toQuery, id)
|
|
}
|
|
|
|
// load vulnerability affected features
|
|
rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery))
|
|
if err != nil {
|
|
return nil, util.HandleError("searchVulnerabilityAffected", err)
|
|
}
|
|
|
|
for rows.Next() {
|
|
var (
|
|
id int64
|
|
f database.AffectedFeature
|
|
)
|
|
|
|
err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FeatureType, &f.FixedInVersion)
|
|
if err != nil {
|
|
return nil, util.HandleError("searchVulnerabilityAffected", err)
|
|
}
|
|
|
|
for _, vuln := range vulnIDMap[id] {
|
|
f.Namespace = vuln.Namespace
|
|
vuln.Affected = append(vuln.Affected, f)
|
|
}
|
|
}
|
|
|
|
return resultVuln, nil
|
|
}
|
|
|
|
func InsertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) error {
|
|
defer monitoring.ObserveQueryTime("insertVulnerabilities", "all", time.Now())
|
|
// bulk insert vulnerabilities
|
|
vulnIDs, err := insertVulnerabilities(tx, vulnerabilities)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// bulk insert vulnerability affected features
|
|
vulnFeatureMap, err := InsertVulnerabilityAffected(tx, vulnIDs, vulnerabilities)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return CacheVulnerabiltyAffectedNamespacedFeature(tx, vulnFeatureMap)
|
|
}
|
|
|
|
// insertVulnerabilityAffected inserts a set of vulnerability affected features for each vulnerability provided.
|
|
//
|
|
// i_th vulnerabilityIDs corresponds to i_th vulnerabilities provided.
|
|
func InsertVulnerabilityAffected(tx *sql.Tx, vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) {
|
|
var (
|
|
vulnFeature = map[int64]affectedFeatureRows{}
|
|
affectedID int64
|
|
)
|
|
|
|
types, err := feature.GetFeatureTypeMap(tx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
stmt, err := tx.Prepare(insertVulnerabilityAffected)
|
|
if err != nil {
|
|
return nil, util.HandleError("insertVulnerabilityAffected", err)
|
|
}
|
|
|
|
defer stmt.Close()
|
|
for i, vuln := range vulnerabilities {
|
|
// affected feature row ID -> affected feature
|
|
affectedFeatures := map[int64]database.AffectedFeature{}
|
|
for _, f := range vuln.Affected {
|
|
err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, types.ByName[f.FeatureType], f.FixedInVersion).Scan(&affectedID)
|
|
if err != nil {
|
|
return nil, util.HandleError("insertVulnerabilityAffected", err)
|
|
}
|
|
affectedFeatures[affectedID] = f
|
|
}
|
|
vulnFeature[vulnerabilityIDs[i]] = affectedFeatureRows{rows: affectedFeatures}
|
|
}
|
|
|
|
return vulnFeature, nil
|
|
}
|
|
|
|
// insertVulnerabilities inserts a set of unique vulnerabilities into database,
|
|
// under the assumption that all vulnerabilities are valid.
|
|
func insertVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) {
|
|
var (
|
|
vulnID int64
|
|
vulnIDs = make([]int64, 0, len(vulnerabilities))
|
|
vulnMap = map[database.VulnerabilityID]struct{}{}
|
|
)
|
|
|
|
for _, v := range vulnerabilities {
|
|
key := database.VulnerabilityID{
|
|
Name: v.Name,
|
|
Namespace: v.Namespace.Name,
|
|
}
|
|
|
|
// Ensure uniqueness of vulnerability IDs
|
|
if _, ok := vulnMap[key]; ok {
|
|
return nil, errors.New("inserting duplicated vulnerabilities is not allowed")
|
|
}
|
|
vulnMap[key] = struct{}{}
|
|
}
|
|
|
|
//TODO(Sida): Change to bulk insert.
|
|
stmt, err := tx.Prepare(insertVulnerability)
|
|
if err != nil {
|
|
return nil, util.HandleError("insertVulnerability", err)
|
|
}
|
|
|
|
defer stmt.Close()
|
|
for _, vuln := range vulnerabilities {
|
|
err := stmt.QueryRow(vuln.Name, vuln.Description,
|
|
vuln.Link, &vuln.Severity, &vuln.Metadata,
|
|
vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID)
|
|
if err != nil {
|
|
return nil, util.HandleError("insertVulnerability", err)
|
|
}
|
|
|
|
vulnIDs = append(vulnIDs, vulnID)
|
|
}
|
|
|
|
return vulnIDs, nil
|
|
}
|
|
|
|
func LockFeatureVulnerabilityCache(tx *sql.Tx) error {
|
|
_, err := tx.Exec(lockVulnerabilityAffects)
|
|
if err != nil {
|
|
return util.HandleError("lockVulnerabilityAffects", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// cacheVulnerabiltyAffectedNamespacedFeature takes in a map of vulnerability ID
|
|
// to affected feature rows and caches them.
|
|
func CacheVulnerabiltyAffectedNamespacedFeature(tx *sql.Tx, affected map[int64]affectedFeatureRows) error {
|
|
// Prevent InsertNamespacedFeatures to modify it.
|
|
err := LockFeatureVulnerabilityCache(tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
vulnIDs := []int64{}
|
|
for id := range affected {
|
|
vulnIDs = append(vulnIDs, id)
|
|
}
|
|
|
|
rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs))
|
|
if err != nil {
|
|
return util.HandleError("searchVulnerabilityPotentialAffected", err)
|
|
}
|
|
|
|
defer rows.Close()
|
|
|
|
relation := []affectRelation{}
|
|
for rows.Next() {
|
|
var (
|
|
vulnID int64
|
|
nsfID int64
|
|
fVersion string
|
|
addedBy int64
|
|
)
|
|
|
|
err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy)
|
|
if err != nil {
|
|
return util.HandleError("searchVulnerabilityPotentialAffected", err)
|
|
}
|
|
|
|
candidate, ok := affected[vulnID].rows[addedBy]
|
|
|
|
if !ok {
|
|
return errors.New("vulnerability affected feature not found")
|
|
}
|
|
|
|
if in, err := versionfmt.InRange(candidate.Namespace.VersionFormat,
|
|
fVersion,
|
|
candidate.AffectedVersion); err == nil {
|
|
if in {
|
|
relation = append(relation,
|
|
affectRelation{
|
|
vulnerabilityID: vulnID,
|
|
namespacedFeatureID: nsfID,
|
|
addedBy: addedBy,
|
|
})
|
|
}
|
|
} else {
|
|
return err
|
|
}
|
|
}
|
|
|
|
//TODO(Sida): Change to bulk insert.
|
|
for _, r := range relation {
|
|
result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy)
|
|
if err != nil {
|
|
return util.HandleError("insertVulnerabilityAffectedNamespacedFeature", err)
|
|
}
|
|
|
|
if num, err := result.RowsAffected(); err == nil {
|
|
if num <= 0 {
|
|
return errors.New("Nothing cached in database")
|
|
}
|
|
} else {
|
|
return err
|
|
}
|
|
}
|
|
|
|
log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", len(relation))
|
|
return nil
|
|
}
|
|
|
|
func DeleteVulnerabilities(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) error {
|
|
defer monitoring.ObserveQueryTime("DeleteVulnerability", "all", time.Now())
|
|
|
|
vulnIDs, err := MarkVulnerabilitiesAsDeleted(tx, vulnerabilities)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if err := InvalidateVulnerabilityCache(tx, vulnIDs); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func InvalidateVulnerabilityCache(tx *sql.Tx, vulnerabilityIDs []int64) error {
|
|
if len(vulnerabilityIDs) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// Prevent InsertNamespacedFeatures to modify it.
|
|
err := LockFeatureVulnerabilityCache(tx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
//TODO(Sida): Make a nicer interface for bulk inserting.
|
|
keys := make([]interface{}, len(vulnerabilityIDs))
|
|
for i, id := range vulnerabilityIDs {
|
|
keys[i] = id
|
|
}
|
|
|
|
_, err = tx.Exec(queryInvalidateVulnerabilityCache(len(vulnerabilityIDs)), keys...)
|
|
if err != nil {
|
|
return util.HandleError("removeVulnerabilityAffectedFeature", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func MarkVulnerabilitiesAsDeleted(tx *sql.Tx, vulnerabilities []database.VulnerabilityID) ([]int64, error) {
|
|
var (
|
|
vulnID sql.NullInt64
|
|
vulnIDs []int64
|
|
)
|
|
|
|
// mark vulnerabilities deleted
|
|
stmt, err := tx.Prepare(removeVulnerability)
|
|
if err != nil {
|
|
return nil, util.HandleError("removeVulnerability", err)
|
|
}
|
|
|
|
defer stmt.Close()
|
|
for _, vuln := range vulnerabilities {
|
|
err := stmt.QueryRow(vuln.Namespace, vuln.Name).Scan(&vulnID)
|
|
if err != nil {
|
|
return nil, util.HandleError("removeVulnerability", err)
|
|
}
|
|
if !vulnID.Valid {
|
|
return nil, util.HandleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database"))
|
|
}
|
|
vulnIDs = append(vulnIDs, vulnID.Int64)
|
|
}
|
|
return vulnIDs, nil
|
|
}
|
|
|
|
// findLatestDeletedVulnerabilityIDs requires all elements in vulnIDs are in
|
|
// database and the order of output array is not guaranteed.
|
|
func FindLatestDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
|
|
return FindVulnerabilityIDs(tx, vulnIDs, true)
|
|
}
|
|
|
|
func FindNotDeletedVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
|
|
return FindVulnerabilityIDs(tx, vulnIDs, false)
|
|
}
|
|
|
|
func FindVulnerabilityIDs(tx *sql.Tx, vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) {
|
|
if len(vulnIDs) == 0 {
|
|
return nil, nil
|
|
}
|
|
|
|
vulnIDMap := map[database.VulnerabilityID]sql.NullInt64{}
|
|
keys := make([]interface{}, len(vulnIDs)*2)
|
|
for i, vulnID := range vulnIDs {
|
|
keys[i*2] = vulnID.Name
|
|
keys[i*2+1] = vulnID.Namespace
|
|
vulnIDMap[vulnID] = sql.NullInt64{}
|
|
}
|
|
|
|
query := ""
|
|
if withLatestDeleted {
|
|
query = querySearchLastDeletedVulnerabilityID(len(vulnIDs))
|
|
} else {
|
|
query = querySearchNotDeletedVulnerabilityID(len(vulnIDs))
|
|
}
|
|
|
|
rows, err := tx.Query(query, keys...)
|
|
if err != nil {
|
|
return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Query", err)
|
|
}
|
|
|
|
defer rows.Close()
|
|
var (
|
|
id sql.NullInt64
|
|
vulnID database.VulnerabilityID
|
|
)
|
|
for rows.Next() {
|
|
err := rows.Scan(&id, &vulnID.Name, &vulnID.Namespace)
|
|
if err != nil {
|
|
return nil, util.HandleError("querySearchVulnerabilityID.LatestDeleted.Scan", err)
|
|
}
|
|
vulnIDMap[vulnID] = id
|
|
}
|
|
|
|
ids := make([]sql.NullInt64, len(vulnIDs))
|
|
for i, v := range vulnIDs {
|
|
ids[i] = vulnIDMap[v]
|
|
}
|
|
|
|
return ids, nil
|
|
}
|
|
|
|
func FindPagedVulnerableAncestries(tx *sql.Tx, vulnID int64, limit int, currentToken pagination.Token, key pagination.Key) (database.PagedVulnerableAncestries, error) {
|
|
vulnPage := database.PagedVulnerableAncestries{Limit: limit}
|
|
currentPage := page.Page{0}
|
|
if currentToken != pagination.FirstPageToken {
|
|
if err := key.UnmarshalToken(currentToken, ¤tPage); err != nil {
|
|
return vulnPage, err
|
|
}
|
|
}
|
|
|
|
if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan(
|
|
&vulnPage.Name,
|
|
&vulnPage.Description,
|
|
&vulnPage.Link,
|
|
&vulnPage.Severity,
|
|
&vulnPage.Metadata,
|
|
&vulnPage.Namespace.Name,
|
|
&vulnPage.Namespace.VersionFormat,
|
|
); err != nil {
|
|
return vulnPage, util.HandleError("searchVulnerabilityByID", err)
|
|
}
|
|
|
|
// the last result is used for the next page's startID
|
|
rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1)
|
|
if err != nil {
|
|
return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
ancestries := []affectedAncestry{}
|
|
for rows.Next() {
|
|
var ancestry affectedAncestry
|
|
err := rows.Scan(&ancestry.id, &ancestry.name)
|
|
if err != nil {
|
|
return vulnPage, util.HandleError("searchNotificationVulnerableAncestry", err)
|
|
}
|
|
ancestries = append(ancestries, ancestry)
|
|
}
|
|
|
|
lastIndex := 0
|
|
if len(ancestries)-1 < limit {
|
|
lastIndex = len(ancestries)
|
|
vulnPage.End = true
|
|
} else {
|
|
// Use the last ancestry's ID as the next page.
|
|
lastIndex = len(ancestries) - 1
|
|
vulnPage.Next, err = key.MarshalToken(page.Page{ancestries[len(ancestries)-1].id})
|
|
if err != nil {
|
|
return vulnPage, err
|
|
}
|
|
}
|
|
|
|
vulnPage.Affected = map[int]string{}
|
|
for _, ancestry := range ancestries[0:lastIndex] {
|
|
vulnPage.Affected[int(ancestry.id)] = ancestry.name
|
|
}
|
|
|
|
vulnPage.Current, err = key.MarshalToken(currentPage)
|
|
if err != nil {
|
|
return vulnPage, err
|
|
}
|
|
|
|
return vulnPage, nil
|
|
}
|