database: add initial vulnerability support

This commit is contained in:
Quentin Machu 2016-01-12 10:40:46 -05:00 committed by Jimmy Zelinskie
parent 3a786ae020
commit 7c70fc1c20
14 changed files with 862 additions and 88 deletions

View File

@ -28,8 +28,9 @@ type Datastore interface {
DeleteLayer(name string) error DeleteLayer(name string) error
// Vulnerability // Vulnerability
// InsertVulnerabilities([]*Vulnerability) InsertVulnerabilities([]Vulnerability) error
// DeleteVulnerability(id string) // DeleteVulnerability(id string) error
FindVulnerability(namespaceName, name string) (Vulnerability, error)
// Notifications // Notifications
// InsertNotifications([]Notification) error // InsertNotifications([]Notification) error

View File

@ -2,6 +2,7 @@ package database
import "github.com/coreos/clair/utils/types" import "github.com/coreos/clair/utils/types"
// ID is only meant to be used by database implementations and should never be used for anything else.
type Model struct { type Model struct {
ID int ID int
} }
@ -46,8 +47,9 @@ type Vulnerability struct {
Description string Description string
Link string Link string
Severity types.Priority Severity types.Priority
// FixedIn map[types.Version]Feature // <<-- WRONG.
Affects []FeatureVersion FixedIn []FeatureVersion
//Affects []FeatureVersion
// For output purposes. Only make sense when the vulnerability // For output purposes. Only make sense when the vulnerability
// is already about a specific Feature/FeatureVersion. // is already about a specific Feature/FeatureVersion.

View File

@ -1,6 +1,8 @@
package pgsql package pgsql
import ( import (
"database/sql"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors" cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types" "github.com/coreos/clair/utils/types"
@ -54,6 +56,7 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
if err != nil { if err != nil {
return 0, err return 0, err
} }
featureVersion.Feature.ID = featureID
// Begin transaction. // Begin transaction.
tx, err := pgSQL.Begin() tx, err := pgSQL.Begin()
@ -62,6 +65,15 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
return 0, handleError("insertFeatureVersion.Begin()", err) return 0, handleError("insertFeatureVersion.Begin()", err)
} }
// Set transaction as SERIALIZABLE.
// This is how we ensure that the data in Vulnerability_Affects_FeatureVersion is always
// consistent.
_, err = tx.Exec(getQuery("set_tx_serializable"))
if err != nil {
tx.Rollback()
return 0, handleError("insertFeatureVersion.set_tx_serializable", err)
}
// Find or create FeatureVersion. // Find or create FeatureVersion.
var newOrExisting string var newOrExisting string
err = tx.QueryRow(getQuery("soi_featureversion"), featureID, &featureVersion.Version). err = tx.QueryRow(getQuery("soi_featureversion"), featureID, &featureVersion.Version).
@ -77,52 +89,15 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
// Link the new FeatureVersion with every vulnerabilities that affect it, by inserting in // Link the new FeatureVersion with every vulnerabilities that affect it, by inserting in
// Vulnerability_Affects_FeatureVersion. // Vulnerability_Affects_FeatureVersion.
err = linkFeatureVersionToVulnerabilities(tx, featureVersion)
// Lock Vulnerability_FixedIn_Feature because we can't let it to be modified while we modify
// Vulnerability_Affects_FeatureVersion.
_, err = tx.Exec(getQuery("l_share_vulnerability_fixedin_feature"))
if err != nil { if err != nil {
tx.Rollback() // tx.Rollback() is done in linkFeatureVersionToVulnerabilities.
return 0, handleError("l_share_vulnerability_fixedin_feature", err) return 0, err
}
// Select every vulnerability and the fixed version that affect this Feature.
rows, err := tx.Query(getQuery("s_vulnerability_fixedin_feature"), featureID)
if err != nil {
tx.Rollback()
return 0, handleError("s_vulnerability_fixedin_feature", err)
}
defer rows.Close()
var fixedInID, vulnerabilityID int
var fixedInVersion types.Version
for rows.Next() {
err := rows.Scan(&fixedInID, &vulnerabilityID, &fixedInVersion)
if err != nil {
tx.Rollback()
return 0, handleError("s_vulnerability_fixedin_feature.Scan()", err)
}
if featureVersion.Version.Compare(fixedInVersion) < 0 {
// The version of the FeatureVersion we are inserting is lower than the fixed version on this
// Vulnerability, thus, this FeatureVersion is affected by it.
// TODO(Quentin-M): Prepare.
_, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), vulnerabilityID,
featureVersion.ID, fixedInID)
if err != nil {
tx.Rollback()
return 0, handleError("i_vulnerability_affects_featureversion", err)
}
}
}
if err = rows.Err(); err != nil {
return 0, handleError("s_vulnerability_fixedin_feature.Rows()", err)
} }
// Commit transaction. // Commit transaction.
err = tx.Commit() err = tx.Commit()
if err != nil { if err != nil {
tx.Rollback()
return 0, handleError("insertFeatureVersion.Commit()", err) return 0, handleError("insertFeatureVersion.Commit()", err)
} }
@ -148,3 +123,42 @@ func (pgSQL *pgSQL) insertFeatureVersions(featureVersions []database.FeatureVers
return IDs, nil return IDs, nil
} }
func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.FeatureVersion) error {
// Select every vulnerability and the fixed version that affect this Feature.
// TODO(Quentin-M): LIMIT
rows, err := tx.Query(getQuery("s_vulnerability_fixedin_feature"), featureVersion.Feature.ID)
if err != nil {
tx.Rollback()
return handleError("s_vulnerability_fixedin_feature", err)
}
defer rows.Close()
var fixedInID, vulnerabilityID int
var fixedInVersion types.Version
for rows.Next() {
err := rows.Scan(&fixedInID, &vulnerabilityID, &fixedInVersion)
if err != nil {
tx.Rollback()
return handleError("s_vulnerability_fixedin_feature.Scan()", err)
}
if featureVersion.Version.Compare(fixedInVersion) < 0 {
// The version of the FeatureVersion we are inserting is lower than the fixed version on this
// Vulnerability, thus, this FeatureVersion is affected by it.
// TODO(Quentin-M): Prepare.
_, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), vulnerabilityID,
featureVersion.ID, fixedInID)
if err != nil {
tx.Rollback()
return handleError("i_vulnerability_affects_featureversion", err)
}
}
}
if err = rows.Err(); err != nil {
tx.Rollback()
return handleError("s_vulnerability_fixedin_feature.Rows()", err)
}
return nil
}

View File

@ -77,7 +77,7 @@ func (pgSQL *pgSQL) getLayerFeatureVersions(layerID int, idOnly bool) ([]databas
// Query // Query
rows, err := pgSQL.Query(query, layerID) rows, err := pgSQL.Query(query, layerID)
if err != nil && err != sql.ErrNoRows { if err != nil {
return featureVersions, handleError(query, err) return featureVersions, handleError(query, err)
} }
defer rows.Close() defer rows.Close()
@ -201,14 +201,23 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error {
if err != nil && err != cerrors.ErrNotFound { if err != nil && err != cerrors.ErrNotFound {
return err return err
} else if err == nil { } else if err == nil {
if existingLayer.EngineVersion >= layer.EngineVersion {
// The layer exists and has an equal or higher engine verison, do nothing.
return nil
}
layer.ID = existingLayer.ID layer.ID = existingLayer.ID
} }
// Begin transaction. // Get parent ID.
tx, err := pgSQL.Begin() var parentID zero.Int
if err != nil { if layer.Parent != nil {
tx.Rollback() if layer.Parent.ID == 0 {
return handleError("InsertLayer.Begin()", err) log.Warning("Parent is expected to be retrieved from database when inserting a layer.")
return cerrors.NewBadRequestError("Parent is expected to be retrieved from database when inserting a layer.")
}
parentID = zero.IntFrom(int64(layer.Parent.ID))
} }
// Find or insert namespace if provided. // Find or insert namespace if provided.
@ -216,7 +225,6 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error {
if layer.Namespace != nil { if layer.Namespace != nil {
n, err := pgSQL.insertNamespace(*layer.Namespace) n, err := pgSQL.insertNamespace(*layer.Namespace)
if err != nil { if err != nil {
tx.Rollback()
return err return err
} }
namespaceID = zero.IntFrom(int64(n)) namespaceID = zero.IntFrom(int64(n))
@ -227,18 +235,15 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error {
} }
} }
// Begin transaction.
tx, err := pgSQL.Begin()
if err != nil {
tx.Rollback()
return handleError("InsertLayer.Begin()", err)
}
if layer.ID == 0 { if layer.ID == 0 {
// Insert a new layer. // Insert a new layer.
var parentID zero.Int
if layer.Parent != nil {
if layer.Parent.ID == 0 {
log.Warning("Parent is expected to be retrieved from database when inserting a layer.")
return cerrors.NewBadRequestError("Parent is expected to be retrieved from database when inserting a layer.")
}
parentID = zero.IntFrom(int64(layer.Parent.ID))
}
err = tx.QueryRow(getQuery("i_layer"), layer.Name, layer.EngineVersion, parentID, namespaceID). err = tx.QueryRow(getQuery("i_layer"), layer.Name, layer.EngineVersion, parentID, namespaceID).
Scan(&layer.ID) Scan(&layer.ID)
if err != nil { if err != nil {
@ -246,11 +251,6 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error {
return handleError("i_layer", err) return handleError("i_layer", err)
} }
} else { } else {
if existingLayer.EngineVersion >= layer.EngineVersion {
// The layer exists and has an equal or higher engine verison, do nothing.
return nil
}
// Update an existing layer. // Update an existing layer.
_, err = tx.Exec(getQuery("u_layer"), layer.ID, layer.EngineVersion, namespaceID) _, err = tx.Exec(getQuery("u_layer"), layer.ID, layer.EngineVersion, namespaceID)
if err != nil { if err != nil {
@ -269,6 +269,7 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error {
// Update Layer_diff_FeatureVersion now. // Update Layer_diff_FeatureVersion now.
err = pgSQL.updateDiffFeatureVersions(tx, &layer, &existingLayer) err = pgSQL.updateDiffFeatureVersions(tx, &layer, &existingLayer)
if err != nil { if err != nil {
tx.Rollback()
return err return err
} }
@ -293,7 +294,7 @@ func (pgSQL *pgSQL) updateDiffFeatureVersions(tx *sql.Tx, layer, existingLayer *
} else if layer.Parent != nil { } else if layer.Parent != nil {
// There is a parent, we need to diff the Features with it. // There is a parent, we need to diff the Features with it.
// Build name:version strctures. // Build name:version structures.
layerFeaturesMapNV, layerFeaturesNV := createNV(layer.Features) layerFeaturesMapNV, layerFeaturesNV := createNV(layer.Features)
parentLayerFeaturesMapNV, parentLayerFeaturesNV := createNV(layer.Parent.Features) parentLayerFeaturesMapNV, parentLayerFeaturesNV := createNV(layer.Parent.Features)

View File

@ -73,7 +73,7 @@ CREATE TABLE IF NOT EXISTS Vulnerability (
name VARCHAR(128) NOT NULL, name VARCHAR(128) NOT NULL,
description TEXT NULL, description TEXT NULL,
link VARCHAR(128) NULL, link VARCHAR(128) NULL,
severity severity NULL, severity severity NOT NULL,
UNIQUE (namespace_id, name)); UNIQUE (namespace_id, name));
@ -137,6 +137,6 @@ DROP TABLE IF EXISTS Namespace,
Vulnerability, Vulnerability,
Vulnerability_FixedIn_Feature, Vulnerability_FixedIn_Feature,
Vulnerability_Affects_FeatureVersion, Vulnerability_Affects_FeatureVersion,
KeyValue KeyValue,
Lock Lock
CASCADE; CASCADE;

View File

@ -120,7 +120,7 @@ func dropDatabase(dataSource, databaseName string) error {
// Drop database. // Drop database.
_, err = db.Exec("DROP DATABASE " + databaseName + ";") _, err = db.Exec("DROP DATABASE " + databaseName + ";")
if err != nil { if err != nil {
return fmt.Errorf("could not create database: %v", err) return fmt.Errorf("could not drop database: %v", err)
} }
return nil return nil
@ -185,7 +185,7 @@ func handleError(desc string, err error) error {
return database.ErrBackendException return database.ErrBackendException
} else if err == sql.ErrNoRows { } else if err == sql.ErrNoRows {
return cerrors.ErrNotFound return cerrors.ErrNotFound
} else if err == sql.ErrTxDone { } else if err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") {
return database.ErrBackendException return database.ErrBackendException
} }

View File

@ -10,6 +10,8 @@ var queries map[string]string
func init() { func init() {
queries = make(map[string]string) queries = make(map[string]string)
queries["set_tx_serializable"] = `SET TRANSACTION ISOLATION LEVEL SERIALIZABLE`
// keyvalue.go // keyvalue.go
queries["u_keyvalue"] = `UPDATE KeyValue SET value = $1 WHERE key = $2` queries["u_keyvalue"] = `UPDATE KeyValue SET value = $1 WHERE key = $2`
queries["i_keyvalue"] = `INSERT INTO KeyValue(key, value) VALUES($1, $2)` queries["i_keyvalue"] = `INSERT INTO KeyValue(key, value) VALUES($1, $2)`
@ -39,10 +41,6 @@ func init() {
UNION UNION
SELECT id FROM new_feature` SELECT id FROM new_feature`
queries["l_share_vulnerability_fixedin_feature"] = `
LOCK Vulnerability_FixedIn_Feature IN SHARE MODE
`
queries["soi_featureversion"] = ` queries["soi_featureversion"] = `
WITH new_featureversion AS ( WITH new_featureversion AS (
INSERT INTO FeatureVersion(feature_id, version) INSERT INTO FeatureVersion(feature_id, version)
@ -142,6 +140,41 @@ func init() {
queries["r_lock"] = `DELETE FROM Lock WHERE name = $1 AND owner = $2` queries["r_lock"] = `DELETE FROM Lock WHERE name = $1 AND owner = $2`
queries["r_lock_expired"] = `DELETE FROM LOCK WHERE until < CURRENT_TIMESTAMP` queries["r_lock_expired"] = `DELETE FROM LOCK WHERE until < CURRENT_TIMESTAMP`
// vulnerability.go
queries["f_vulnerability"] = `
SELECT v.id, n.id, v.description, v.link, v.severity, vfif.version, f.id, f.Name
FROM Vulnerability v
JOIN Namespace n ON v.namespace_id = n.id
LEFT JOIN Vulnerability_FixedIn_Feature vfif ON v.id = vfif.vulnerability_id
LEFT JOIN Feature f ON vfif.feature_id = f.id
WHERE n.Name = $1 AND v.Name = $2`
queries["i_vulnerability"] = `
INSERT INTO Vulnerability(namespace_id, name, description, link, severity)
VALUES($1, $2, $3, $4, $5)
RETURNING id`
queries["u_vulnerability"] = `
UPDATE Vulnerability SET description = $2, link = $3, severity = $4 WHERE id = $1`
queries["i_vulnerability_fixedin_feature"] = `
INSERT INTO Vulnerability_FixedIn_Feature(vulnerability_id, feature_id, version)
VALUES($1, $2, $3)
RETURNING id`
queries["u_vulnerability_fixedin_feature"] = `
UPDATE Vulnerability_FixedIn_Feature
SET version = $3
WHERE vulnerability_id = $1 AND feature_id = $2
RETURNING id`
queries["r_vulnerability_affects_featureversion"] = `
DELETE FROM Vulnerability_Affects_FeatureVersion
WHERE fixedin_id = $1`
queries["f_featureversion_by_feature"] = `
SELECT id, version FROM FeatureVersion WHERE feature_id = $1`
} }
func getQuery(name string) string { func getQuery(name string) string {

View File

@ -3,6 +3,7 @@ INSERT INTO namespace (id, name) VALUES (2, 'debian:8');
INSERT INTO feature (id, namespace_id, name) VALUES (1, 1, 'wechat'); INSERT INTO feature (id, namespace_id, name) VALUES (1, 1, 'wechat');
INSERT INTO feature (id, namespace_id, name) VALUES (2, 1, 'openssl'); INSERT INTO feature (id, namespace_id, name) VALUES (2, 1, 'openssl');
INSERT INTO feature (id, namespace_id, name) VALUES (4, 1, 'libssl');
INSERT INTO feature (id, namespace_id, name) VALUES (3, 2, 'openssl'); INSERT INTO feature (id, namespace_id, name) VALUES (3, 2, 'openssl');
INSERT INTO featureversion (id, feature_id, version) VALUES (1, 1, '0.5'); INSERT INTO featureversion (id, feature_id, version) VALUES (1, 1, '0.5');
INSERT INTO featureversion (id, feature_id, version) VALUES (2, 2, '1.0'); INSERT INTO featureversion (id, feature_id, version) VALUES (2, 2, '1.0');
@ -23,8 +24,9 @@ INSERT INTO layer_diff_featureversion (id, layer_id, featureversion_id, modifica
INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES (1, 1, 'CVE-OPENSSL-1-DEB7', 'A vulnerability affecting OpenSSL < 2.0 on Debian 7.0', 'http://google.com/#q=CVE-OPENSSL-1-DEB7', 'High'); INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES (1, 1, 'CVE-OPENSSL-1-DEB7', 'A vulnerability affecting OpenSSL < 2.0 on Debian 7.0', 'http://google.com/#q=CVE-OPENSSL-1-DEB7', 'High');
INSERT INTO vulnerability_fixedin_feature (id, vulnerability_id, feature_id, version) VALUES (1, 1, 2, '2.0'); INSERT INTO vulnerability_fixedin_feature (id, vulnerability_id, feature_id, version) VALUES (1, 1, 2, '2.0');
INSERT INTO vulnerability_fixedin_feature (id, vulnerability_id, feature_id, version) VALUES (2, 1, 4, '1.9-abc');
INSERT INTO vulnerability_affects_featureversion (id, vulnerability_id, featureversion_id, fixedin_id) VALUES (1, 1, 2, 1); -- CVE-OPENSSL-1-DEB7 affects Debian:7 OpenSSL 1.0 INSERT INTO vulnerability_affects_featureversion (id, vulnerability_id, featureversion_id, fixedin_id) VALUES (1, 1, 2, 1); -- CVE-OPENSSL-1-DEB7 affects Debian:7 OpenSSL 1.0
INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES (2, 1, 'CVE-NOPE', 'A vulnerability affecting nothing', 'http://google.com/#q=NOPE', 'Negligible'); INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES (2, 1, 'CVE-NOPE', 'A vulnerability affecting nothing', '', 'Unknown');
SELECT pg_catalog.setval(pg_get_serial_sequence('namespace', 'id'), (SELECT MAX(id) FROM namespace)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('namespace', 'id'), (SELECT MAX(id) FROM namespace)+1);
SELECT pg_catalog.setval(pg_get_serial_sequence('feature', 'id'), (SELECT MAX(id) FROM feature)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('feature', 'id'), (SELECT MAX(id) FROM feature)+1);

View File

@ -0,0 +1,306 @@
package pgsql
import (
"database/sql"
"github.com/coreos/clair/database"
"github.com/coreos/clair/utils"
cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types"
"github.com/guregu/null/zero"
)
func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vulnerability, error) {
vulnerability := database.Vulnerability{
Name: name,
Namespace: database.Namespace{
Name: namespaceName,
},
}
// Find Vulnerability.
rows, err := pgSQL.Query(getQuery("f_vulnerability"), namespaceName, name)
if err != nil {
return vulnerability, handleError("f_vulnerability", err)
}
defer rows.Close()
// Iterate to scan the Vulnerability and its FixedIn FeatureVersions.
for rows.Next() {
var featureVersionID zero.Int
var featureVersionVersion zero.String
var featureVersionFeatureName zero.String
err := rows.Scan(&vulnerability.ID, &vulnerability.Namespace.ID, &vulnerability.Description,
&vulnerability.Link, &vulnerability.Severity, &featureVersionVersion, &featureVersionID,
&featureVersionFeatureName)
if err != nil {
return vulnerability, handleError("f_vulnerability.Scan()", err)
}
if !featureVersionID.IsZero() {
// Note that the ID we fill in featureVersion is actually a Feature ID, and not
// a FeatureVersion ID.
featureVersion := database.FeatureVersion{
Model: database.Model{ID: int(featureVersionID.Int64)},
Feature: database.Feature{
Model: database.Model{ID: int(featureVersionID.Int64)},
Namespace: vulnerability.Namespace,
Name: featureVersionFeatureName.String,
},
Version: types.NewVersionUnsafe(featureVersionVersion.String),
}
vulnerability.FixedIn = append(vulnerability.FixedIn, featureVersion)
}
}
if err = rows.Err(); err != nil {
return vulnerability, handleError("s_featureversions_vulnerabilities.Rows()", err)
}
if vulnerability.ID == 0 {
return vulnerability, cerrors.ErrNotFound
}
return vulnerability, nil
}
// FixedIn.Namespace are not necessary, they are overwritten by the vuln.
func (pgSQL *pgSQL) InsertVulnerabilities(vulnerabilities []database.Vulnerability) error {
for _, vulnerability := range vulnerabilities {
err := pgSQL.insertVulnerability(vulnerability)
if err != nil {
return err
}
}
return nil
}
func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability) error {
// Verify parameters
if vulnerability.Name == "" || len(vulnerability.FixedIn) == 0 ||
vulnerability.Namespace.Name == "" || !vulnerability.Severity.IsValid() {
log.Warning("could not insert an invalid vulnerability")
return cerrors.NewBadRequestError("could not insert an invalid vulnerability")
}
for _, fixedInFeatureVersion := range vulnerability.FixedIn {
if fixedInFeatureVersion.Feature.Namespace.Name != "" &&
fixedInFeatureVersion.Feature.Namespace.Name != vulnerability.Namespace.Name {
msg := "could not insert an invalid vulnerability: FixedIn FeatureVersion must be in the " +
"same namespace as the Vulnerability"
log.Warning(msg)
return cerrors.NewBadRequestError(msg)
}
}
// Find or insert Vulnerability's Namespace.
namespaceID, err := pgSQL.insertNamespace(vulnerability.Namespace)
if err != nil {
return err
}
// Find vulnerability and its Vulnerability_FixedIn_Features.
existingVulnerability, err := pgSQL.FindVulnerability(vulnerability.Namespace.Name,
vulnerability.Name)
if err != nil && err != cerrors.ErrNotFound {
return err
}
// Compute new/updated FixedIn FeatureVersions.
var newFixedInFeatureVersions []database.FeatureVersion
var updatedFixedInFeatureVersions []database.FeatureVersion
if existingVulnerability.ID == 0 {
newFixedInFeatureVersions = vulnerability.FixedIn
} else {
newFixedInFeatureVersions, updatedFixedInFeatureVersions = diffFixedIn(vulnerability,
existingVulnerability)
}
if len(newFixedInFeatureVersions) == 0 && len(updatedFixedInFeatureVersions) == 0 {
// Nothing to do.
return nil
}
// Insert or find the new FeatureVersions.
// We already have the Feature IDs in updatedFixedInFeatureVersions because diffFixedIn fills them
// in using the existing vulnerability's FixedIn FeatureVersions. Note that even if FixedIn
// is type FeatureVersion, the actual stored ID in these structs are the Feature IDs.
//
// Also, we enforce the namespace of the FeatureVersion in case it was empty. There is a test
// above to ensure that the passed Namespace is either the same as the vulnerability or empty.
for i := 0; i < len(newFixedInFeatureVersions); i++ {
newFixedInFeatureVersions[i].Feature.Namespace.Name = vulnerability.Namespace.Name
newFixedInFeatureVersions[i].ID, err = pgSQL.insertFeatureVersion(newFixedInFeatureVersions[i])
if err != nil {
return err
}
}
// Begin transaction.
tx, err := pgSQL.Begin()
if err != nil {
tx.Rollback()
return handleError("insertVulnerability.Begin()", err)
}
// Set transaction as SERIALIZABLE.
// This is how we ensure that the data in Vulnerability_Affects_FeatureVersion is always
// consistent.
_, err = tx.Exec(getQuery("set_tx_serializable"))
if err != nil {
tx.Rollback()
return handleError("insertFeatureVersion.set_tx_serializable", err)
}
if existingVulnerability.ID == 0 {
// Insert new vulnerability.
err = tx.QueryRow(getQuery("i_vulnerability"), namespaceID, vulnerability.Name,
vulnerability.Description, vulnerability.Link, &vulnerability.Severity).Scan(&vulnerability.ID)
if err != nil {
tx.Rollback()
return handleError("i_vulnerability", err)
}
} else {
// Update vulnerability
_, err = tx.Exec(getQuery("u_vulnerability"), existingVulnerability.ID,
vulnerability.Description, vulnerability.Link, &vulnerability.Severity)
if err != nil {
tx.Rollback()
return handleError("u_vulnerability", err)
}
vulnerability.ID = existingVulnerability.ID
}
// Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now.
err = pgSQL.updateVulnerabilityFeatureVersions(tx, &vulnerability, &existingVulnerability, newFixedInFeatureVersions, updatedFixedInFeatureVersions)
if err != nil {
tx.Rollback()
return err
}
// Commit transaction.
err = tx.Commit()
if err != nil {
tx.Rollback()
return handleError("insertVulnerability.Commit()", err)
}
return nil
}
func diffFixedIn(vulnerability, existingVulnerability database.Vulnerability) (newFixedIn, updatedFixedIn []database.FeatureVersion) {
// Build FeatureVersion.Feature.Namespace.Name:FeatureVersion.Feature.Name (NaN) structures.
vulnerabilityFixedInNameMap, vulnerabilityFixedInNameSlice := createFeatureVersionNameMap(vulnerability.FixedIn)
existingFixedInMapNameMap, existingFixedInNameSlice := createFeatureVersionNameMap(existingVulnerability.FixedIn)
// Calculate the new FixedIn FeatureVersion NaN and updated ones.
newFixedInName := utils.CompareStringLists(vulnerabilityFixedInNameSlice,
existingFixedInNameSlice)
updatedFixedInName := utils.CompareStringListsInBoth(vulnerabilityFixedInNameSlice,
existingFixedInNameSlice)
for _, nan := range newFixedInName {
newFixedIn = append(newFixedIn, vulnerabilityFixedInNameMap[nan])
}
for _, nan := range updatedFixedInName {
fv := existingFixedInMapNameMap[nan]
fv.Version = vulnerabilityFixedInNameMap[nan].Version
updatedFixedIn = append(updatedFixedIn, fv)
}
return
}
func createFeatureVersionNameMap(features []database.FeatureVersion) (map[string]database.FeatureVersion, []string) {
m := make(map[string]database.FeatureVersion, 0)
s := make([]string, 0, len(features))
for i := 0; i < len(features); i++ {
featureVersion := features[i]
m[featureVersion.Feature.Name] = featureVersion
s = append(s, featureVersion.Feature.Name)
}
return m, s
}
// TODO(Quentin-M): Add support for removing Vulnerability_FixedIn_Feature when Version = MinVersion.
// We should then update the vulnerability fetcher to do it.
// Also maybe we would delete a Vulnerability if it hasn't any FixedIn.
// --> And affects
func (pgSQL *pgSQL) updateVulnerabilityFeatureVersions(tx *sql.Tx, vulnerability, existingVulnerability *database.Vulnerability, newFixedInFeatureVersions, updatedFixedInFeatureVersions []database.FeatureVersion) error {
var fixedInID int
for _, fv := range newFixedInFeatureVersions {
// Insert Vulnerability_FixedIn_Feature.
err := tx.QueryRow(getQuery("i_vulnerability_fixedin_feature"), vulnerability.ID, fv.ID,
&fv.Version).Scan(&fixedInID)
if err != nil {
return handleError("i_vulnerability_fixedin_feature", err)
}
// Insert Vulnerability_Affects_FeatureVersion.
err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerability.ID, fv.ID, fv.Version)
if err != nil {
return err
}
}
for _, fv := range updatedFixedInFeatureVersions {
// Update Vulnerability_FixedIn_Feature.
err := tx.QueryRow(getQuery("u_vulnerability_fixedin_feature"), vulnerability.ID, fv.ID,
&fv.Version).Scan(&fixedInID)
if err != nil {
return handleError("u_vulnerability_fixedin_feature", err)
}
// Drop all old Vulnerability_Affects_FeatureVersion.
_, err = tx.Exec(getQuery("r_vulnerability_affects_featureversion"), fixedInID)
if err != nil {
return handleError("r_vulnerability_affects_featureversion", err)
}
// Insert Vulnerability_Affects_FeatureVersion.
err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerability.ID, fv.ID, fv.Version)
if err != nil {
return err
}
}
return nil
}
func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, fixedInVersion types.Version) error {
// Find every FeatureVersions of the Feature we want to affect.
// TODO(Quentin-M): LIMIT
rows, err := tx.Query(getQuery("f_featureversion_by_feature"), featureID)
if err == sql.ErrNoRows {
return nil
}
if err != nil {
return handleError("f_featureversion_by_feature", err)
}
defer rows.Close()
var featureVersionID int
var featureVersionVersion types.Version
for rows.Next() {
err := rows.Scan(&featureVersionID, &featureVersionVersion)
if err != nil {
return handleError("f_featureversion_by_feature.Scan()", err)
}
if featureVersionVersion.Compare(fixedInVersion) < 0 {
_, err := tx.Exec("i_vulnerability_affects_featureversion", vulnerabilityID, featureVersionID,
fixedInID)
if err != nil {
return handleError("i_vulnerability_affects_featureversion", err)
}
}
}
if err = rows.Err(); err != nil {
return handleError("f_featureversion_by_feature.Rows()", err)
}
return nil
}

View File

@ -0,0 +1,387 @@
package pgsql
import (
"testing"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
)
func TestFindVulnerability(t *testing.T) {
datastore, err := OpenForTest("FindVulnerability", true)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Find a vulnerability that does not exist.
_, err = datastore.FindVulnerability("", "")
assert.Equal(t, cerrors.ErrNotFound, err)
// Find a normal vulnerability.
v1 := database.Vulnerability{
Name: "CVE-OPENSSL-1-DEB7",
Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7",
Severity: types.High,
FixedIn: []database.FeatureVersion{
database.FeatureVersion{
Feature: database.Feature{Name: "openssl"},
Version: types.NewVersionUnsafe("2.0"),
},
database.FeatureVersion{
Feature: database.Feature{Name: "libssl"},
Version: types.NewVersionUnsafe("1.9-abc"),
},
},
}
v1f, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7")
if assert.Nil(t, err) {
equalsVuln(t, &v1, &v1f)
}
// Find a vulnerability that has no link, no severity and no FixedIn.
v2 := database.Vulnerability{
Name: "CVE-OPENSSL-1-DEB7",
Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
}
v2f, err := datastore.FindVulnerability("debian:7", "CVE-NOPE")
if assert.Nil(t, err) {
equalsVuln(t, &v2, &v2f)
}
}
func TestInsertVulnerability(t *testing.T) {
datastore, err := OpenForTest("InsertVulnerability", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Create some data.
n1 := database.Namespace{Name: "TestInsertVulnerabilityNamespace1"}
n2 := database.Namespace{Name: "TestInsertVulnerabilityNamespace2"}
f1 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion1",
Namespace: n1,
},
Version: types.NewVersionUnsafe("1.0"),
}
f2 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion1",
Namespace: n2,
},
Version: types.NewVersionUnsafe("1.0"),
}
f3 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion2",
},
Version: types.MaxVersion,
}
f4 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion2",
},
Version: types.NewVersionUnsafe("1.4"),
}
f5 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion3",
},
Version: types.NewVersionUnsafe("1.5"),
}
f6 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion4",
},
Version: types.NewVersionUnsafe("0.1"),
}
// Insert invalid vulnerabilities.
for _, vulnerability := range []database.Vulnerability{
database.Vulnerability{
Name: "",
Namespace: n1,
FixedIn: []database.FeatureVersion{f1},
Severity: types.Unknown,
},
database.Vulnerability{
Name: "TestInsertVulnerability0",
Namespace: database.Namespace{},
FixedIn: []database.FeatureVersion{f1},
Severity: types.Unknown,
},
database.Vulnerability{
Name: "TestInsertVulnerability0-",
Namespace: database.Namespace{},
FixedIn: []database.FeatureVersion{f1},
},
database.Vulnerability{
Name: "TestInsertVulnerability0",
Namespace: n1,
FixedIn: []database.FeatureVersion{f1},
Severity: types.Priority(""),
},
database.Vulnerability{
Name: "TestInsertVulnerability0",
Namespace: n1,
FixedIn: []database.FeatureVersion{f2},
Severity: types.Unknown,
},
} {
err := datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability})
assert.Error(t, err)
}
// Insert a simple vulnerability and find it.
v1 := database.Vulnerability{
Name: "TestInsertVulnerability1",
Namespace: n1,
FixedIn: []database.FeatureVersion{f1, f3, f6},
Severity: types.Low,
Description: "TestInsertVulnerabilityDescription1",
Link: "TestInsertVulnerabilityLink1",
}
err = datastore.InsertVulnerabilities([]database.Vulnerability{v1})
if assert.Nil(t, err) {
v1f, err := datastore.FindVulnerability(n1.Name, v1.Name)
if assert.Nil(t, err) {
equalsVuln(t, &v1, &v1f)
}
}
// Update vulnerability.
v1.Description = "TestInsertVulnerabilityLink2"
v1.Link = "TestInsertVulnerabilityLink2"
v1.Severity = types.High
// Update f3 by f4, add fixed by f5, add fixed by f6 which already exists.
// TODO(Quentin-M): Remove FixedIn.
v1.FixedIn = []database.FeatureVersion{f4, f5, f6}
err = datastore.InsertVulnerabilities([]database.Vulnerability{v1})
if assert.Nil(t, err) {
v1f, err := datastore.FindVulnerability(n1.Name, v1.Name)
if assert.Nil(t, err) {
// We already had f1 before the update.
// Add it to the struct for comparison.
v1.FixedIn = append(v1.FixedIn, f1)
equalsVuln(t, &v1, &v1f)
}
}
}
func equalsVuln(t *testing.T, expected, actual *database.Vulnerability) {
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.Namespace.Name, actual.Namespace.Name)
assert.Equal(t, expected.Description, actual.Description)
assert.Equal(t, expected.Link, actual.Link)
assert.Equal(t, expected.Severity, actual.Severity)
if assert.Len(t, actual.FixedIn, len(expected.FixedIn)) {
for _, actualFeatureVersion := range actual.FixedIn {
found := false
for _, expectedFeatureVersion := range expected.FixedIn {
if expectedFeatureVersion.Feature.Name == actualFeatureVersion.Feature.Name {
found = true
assert.Equal(t, expected.Namespace.Name, actualFeatureVersion.Feature.Namespace.Name)
assert.Equal(t, expectedFeatureVersion.Version, actualFeatureVersion.Version)
}
}
if !found {
t.Errorf("unexpected package %s in %s", actualFeatureVersion.Feature.Name, expected.Name)
}
}
}
}
// TODO Test Affects in Feature_Version and here.
//
// // Some data
// vuln1 := &database.Vulnerability{ID: "test1", Link: "link1", Priority: types.Medium, Description: "testDescription1", FixedInNodes: []string{"pkg1"}}
// vuln2 := &database.Vulnerability{ID: "test2", Link: "link2", Priority: types.High, Description: "testDescription2", FixedInNodes: []string{"pkg1", "pkg2"}}
// vuln3 := &database.Vulnerability{ID: "test3", Link: "link3", Priority: types.High, FixedInNodes: []string{"pkg3"}} // Empty description
//
// // Insert some vulnerabilities
// _, err := InsertVulnerabilities([]*database.Vulnerability{vuln1, vuln2, vuln3})
// if assert.Nil(t, err) {
// // Find one of the vulnerabilities we just inserted and verify its content
// v1, err := FindOnedatabase.Vulnerability(vuln1.ID, Fielddatabase.VulnerabilityAll)
// if assert.Nil(t, err) && assert.NotNil(t, v1) {
// assert.Equal(t, vuln1.ID, v1.ID)
// assert.Equal(t, vuln1.Link, v1.Link)
// assert.Equal(t, vuln1.Priority, v1.Priority)
// assert.Equal(t, vuln1.Description, v1.Description)
// if assert.Len(t, v1.FixedInNodes, 1) {
// assert.Equal(t, vuln1.FixedInNodes[0], v1.FixedInNodes[0])
// }
// }
// }
//
// // Update a database.Vulnerability and verify its new content
// pkg1 := &Package{OS: "testOS", Name: "testpkg1", Version: types.NewVersionUnsafe("1.0")}
// InsertPackages([]*Package{pkg1})
// vuln5 := &database.Vulnerability{ID: "test5", Link: "link5", Priority: types.Medium, Description: "testDescription5", FixedInNodes: []string{pkg1.Node}}
//
// _, err = InsertVulnerabilities([]*database.Vulnerability{vuln5})
// if assert.Nil(t, err) {
// // Partial updates
// // # Just a field update
// vuln5b := &database.Vulnerability{ID: "test5", Priority: types.High}
// _, err := InsertVulnerabilities([]*database.Vulnerability{vuln5b})
// if assert.Nil(t, err) {
// v5b, err := FindOnedatabase.Vulnerability(vuln5b.ID, Fielddatabase.VulnerabilityAll)
// if assert.Nil(t, err) && assert.NotNil(t, v5b) {
// assert.Equal(t, vuln5b.ID, v5b.ID)
// assert.Equal(t, vuln5b.Priority, v5b.Priority)
//
// if assert.Len(t, v5b.FixedInNodes, 1) {
// assert.Contains(t, v5b.FixedInNodes, pkg1.Node)
// }
// }
// }
//
// // # Just a field update, twice in the same transaction
// vuln5b1 := &database.Vulnerability{ID: "test5", Link: "http://foo.bar"}
// vuln5b2 := &database.Vulnerability{ID: "test5", Link: "http://bar.foo"}
// _, err = InsertVulnerabilities([]*database.Vulnerability{vuln5b1, vuln5b2})
// if assert.Nil(t, err) {
// v5b2, err := FindOnedatabase.Vulnerability(vuln5b2.ID, Fielddatabase.VulnerabilityAll)
// if assert.Nil(t, err) && assert.NotNil(t, v5b2) {
// assert.Equal(t, vuln5b2.Link, v5b2.Link)
// }
// }
//
// // # All fields except fixedIn update
// vuln5c := &database.Vulnerability{ID: "test5", Link: "link5c", Priority: types.Critical, Description: "testDescription5c"}
// _, err = InsertVulnerabilities([]*database.Vulnerability{vuln5c})
// if assert.Nil(t, err) {
// v5c, err := FindOnedatabase.Vulnerability(vuln5c.ID, Fielddatabase.VulnerabilityAll)
// if assert.Nil(t, err) && assert.NotNil(t, v5c) {
// assert.Equal(t, vuln5c.ID, v5c.ID)
// assert.Equal(t, vuln5c.Link, v5c.Link)
// assert.Equal(t, vuln5c.Priority, v5c.Priority)
// assert.Equal(t, vuln5c.Description, v5c.Description)
//
// if assert.Len(t, v5c.FixedInNodes, 1) {
// assert.Contains(t, v5c.FixedInNodes, pkg1.Node)
// }
// }
// }
//
// // Complete update
// pkg2 := &Package{OS: "testOS", Name: "testpkg1", Version: types.NewVersionUnsafe("1.1")}
// pkg3 := &Package{OS: "testOS", Name: "testpkg2", Version: types.NewVersionUnsafe("1.0")}
// InsertPackages([]*Package{pkg2, pkg3})
// vuln5d := &database.Vulnerability{ID: "test5", Link: "link5d", Priority: types.Low, Description: "testDescription5d", FixedInNodes: []string{pkg2.Node, pkg3.Node}}
//
// _, err = InsertVulnerabilities([]*database.Vulnerability{vuln5d})
// if assert.Nil(t, err) {
// v5d, err := FindOnedatabase.Vulnerability(vuln5d.ID, Fielddatabase.VulnerabilityAll)
// if assert.Nil(t, err) && assert.NotNil(t, v5d) {
// assert.Equal(t, vuln5d.ID, v5d.ID)
// assert.Equal(t, vuln5d.Link, v5d.Link)
// assert.Equal(t, vuln5d.Priority, v5d.Priority)
// assert.Equal(t, vuln5d.Description, v5d.Description)
//
// // Here, we ensure that a database.Vulnerability can only be fixed by one package of a given branch at a given time
// // And that we can add new fixed packages as well
// if assert.Len(t, v5d.FixedInNodes, 2) {
// assert.NotContains(t, v5d.FixedInNodes, pkg1.Node)
// }
// }
// }
// }
//
// // Create and update a database.Vulnerability's packages (and from the same branch) in the same batch
// pkg1 = &Package{OS: "testOS", Name: "testpkg1", Version: types.NewVersionUnsafe("1.0")}
// pkg1b := &Package{OS: "testOS", Name: "testpkg1", Version: types.NewVersionUnsafe("1.1")}
// InsertPackages([]*Package{pkg1, pkg1b})
//
// // # Two updates of the same database.Vulnerability in the same batch with packages of the same branch
// pkg0 := &Package{OS: "testOS", Name: "testpkg0", Version: types.NewVersionUnsafe("1.0")}
// InsertPackages([]*Package{pkg0})
// _, err = InsertVulnerabilities([]*database.Vulnerability{&database.Vulnerability{ID: "test7", Link: "link7", Priority: types.Medium, Description: "testDescription7", FixedInNodes: []string{pkg0.Node}}})
// if assert.Nil(t, err) {
// vuln7b := &database.Vulnerability{ID: "test7", FixedInNodes: []string{pkg1.Node}}
// vuln7c := &database.Vulnerability{ID: "test7", FixedInNodes: []string{pkg1b.Node}}
// _, err = InsertVulnerabilities([]*database.Vulnerability{vuln7b, vuln7c})
// if assert.Nil(t, err) {
// v7, err := FindOnedatabase.Vulnerability("test7", Fielddatabase.VulnerabilityAll)
// if assert.Nil(t, err) && assert.Len(t, v7.FixedInNodes, 2) {
// assert.Contains(t, v7.FixedInNodes, pkg0.Node)
// assert.NotContains(t, v7.FixedInNodes, pkg1.Node)
// assert.Contains(t, v7.FixedInNodes, pkg1b.Node)
// }
// }
// }
// func TestInsertVulnerabilityNotifications(t *testing.T) {
// Open(&config.DatabaseConfig{Type: "memstore"})
// defer Close()
//
// pkg1 := &Package{OS: "testOS", Name: "testpkg1", Version: types.NewVersionUnsafe("1.0")}
// pkg1b := &Package{OS: "testOS", Name: "testpkg1", Version: types.NewVersionUnsafe("1.2")}
// pkg2 := &Package{OS: "testOS", Name: "testpkg2", Version: types.NewVersionUnsafe("1.0")}
// InsertPackages([]*Package{pkg1, pkg1b, pkg2})
//
// // Newdatabase.VulnerabilityNotification
// vuln1 := &database.Vulnerability{ID: "test1", Link: "link1", Priority: types.Medium, Description: "testDescription1", FixedInNodes: []string{pkg1.Node}}
// vuln2 := &database.Vulnerability{ID: "test2", Link: "link2", Priority: types.High, Description: "testDescription2", FixedInNodes: []string{pkg1.Node, pkg2.Node}}
// vuln1b := &database.Vulnerability{ID: "test1", Priority: types.High, FixedInNodes: []string{"pkg3"}}
// notifications, err := InsertVulnerabilities([]*database.Vulnerability{vuln1, vuln2, vuln1b})
// if assert.Nil(t, err) {
// // We should only have two Newdatabase.VulnerabilityNotification notifications: one for test1 and one for test2
// // We should not have a database.VulnerabilityPriorityIncreasedNotification or a database.VulnerabilityPackageChangedNotification
// // for test1 because it is in the same batch
// if assert.Len(t, notifications, 2) {
// for _, n := range notifications {
// _, ok := n.(*Newdatabase.VulnerabilityNotification)
// assert.True(t, ok)
// }
// }
// }
//
// // database.VulnerabilityPriorityIncreasedNotification
// vuln1c := &database.Vulnerability{ID: "test1", Priority: types.Critical}
// notifications, err = InsertVulnerabilities([]*database.Vulnerability{vuln1c})
// if assert.Nil(t, err) {
// if assert.Len(t, notifications, 1) {
// if nn, ok := notifications[0].(*database.VulnerabilityPriorityIncreasedNotification); assert.True(t, ok) {
// assert.Equal(t, vuln1b.Priority, nn.OldPriority)
// assert.Equal(t, vuln1c.Priority, nn.NewPriority)
// }
// }
// }
//
// notifications, err = InsertVulnerabilities([]*database.Vulnerability{&database.Vulnerability{ID: "test1", Priority: types.Low}})
// assert.Nil(t, err)
// assert.Len(t, notifications, 0)
//
// // database.VulnerabilityPackageChangedNotification
// vuln1e := &database.Vulnerability{ID: "test1", FixedInNodes: []string{pkg1b.Node}}
// vuln1f := &database.Vulnerability{ID: "test1", FixedInNodes: []string{pkg2.Node}}
// notifications, err = InsertVulnerabilities([]*database.Vulnerability{vuln1e, vuln1f})
// if assert.Nil(t, err) {
// if assert.Len(t, notifications, 1) {
// if nn, ok := notifications[0].(*database.VulnerabilityPackageChangedNotification); assert.True(t, ok) {
// // Here, we say that pkg1b fixes the database.Vulnerability, but as pkg1b is in
// // the same branch as pkg1, pkg1 should be removed and pkg1b added
// // We also add pkg2 as fixed
// assert.Contains(t, nn.AddedFixedInNodes, pkg1b.Node)
// assert.Contains(t, nn.RemovedFixedInNodes, pkg1.Node)
//
// assert.Contains(t, nn.AddedFixedInNodes, pkg2.Node)
// }
// }
// }

View File

@ -26,7 +26,7 @@ import (
"github.com/coreos/clair/worker" "github.com/coreos/clair/worker"
) )
// MaxPostSize is the maximum number of bytes that ParseHTTPBody reads from an http.Request.Body. // MaxBodySize is the maximum number of bytes that ParseHTTPBody reads from an http.Request.Body.
const MaxBodySize int64 = 1048576 const MaxBodySize int64 = 1048576
// WriteHTTP writes a JSON-encoded object to a http.ResponseWriter, as well as // WriteHTTP writes a JSON-encoded object to a http.ResponseWriter, as well as
@ -54,7 +54,7 @@ func WriteHTTPError(w http.ResponseWriter, httpStatus int, err error) {
switch err { switch err {
case cerrors.ErrNotFound: case cerrors.ErrNotFound:
httpStatus = http.StatusNotFound httpStatus = http.StatusNotFound
case database.ErrTransaction, database.ErrBackendException: case database.ErrBackendException:
httpStatus = http.StatusServiceUnavailable httpStatus = http.StatusServiceUnavailable
case worker.ErrParentUnknown, worker.ErrUnsupported, utils.ErrCouldNotExtract, utils.ErrExtractedFileTooBig: case worker.ErrParentUnknown, worker.ErrUnsupported, utils.ErrCouldNotExtract, utils.ErrExtractedFileTooBig:
httpStatus = http.StatusBadRequest httpStatus = http.StatusBadRequest

View File

@ -22,7 +22,7 @@ import (
var urlParametersRegexp = regexp.MustCompile(`(\?|\&)([^=]+)\=([^ &]+)`) var urlParametersRegexp = regexp.MustCompile(`(\?|\&)([^=]+)\=([^ &]+)`)
// Hash returns an unique hash of the given string // Hash returns an unique hash of the given string.
func Hash(str string) string { func Hash(str string) string {
h := sha1.New() h := sha1.New()
h.Write([]byte(str)) h.Write([]byte(str))
@ -30,13 +30,13 @@ func Hash(str string) string {
return hex.EncodeToString(bs) return hex.EncodeToString(bs)
} }
// CleanURL removes all parameters from an URL // CleanURL removes all parameters from an URL.
func CleanURL(str string) string { func CleanURL(str string) string {
return urlParametersRegexp.ReplaceAllString(str, "") return urlParametersRegexp.ReplaceAllString(str, "")
} }
// Contains looks for a string into an array of strings and returns whether // Contains looks for a string into an array of strings and returns whether
// the string exists // the string exists.
func Contains(needle string, haystack []string) bool { func Contains(needle string, haystack []string) bool {
for _, h := range haystack { for _, h := range haystack {
if h == needle { if h == needle {
@ -46,22 +46,41 @@ func Contains(needle string, haystack []string) bool {
return false return false
} }
// CompareStringLists returns the strings which are present in X but not in Y // CompareStringLists returns the strings that are present in X but not in Y.
func CompareStringLists(X, Y []string) []string { func CompareStringLists(X, Y []string) []string {
m := make(map[string]int) m := make(map[string]bool)
for _, y := range Y { for _, y := range Y {
m[y] = 1 m[y] = true
} }
diff := []string{} diff := []string{}
for _, x := range X { for _, x := range X {
if m[x] > 0 { if m[x] {
continue continue
} }
diff = append(diff, x) diff = append(diff, x)
m[x] = 1 m[x] = true
}
return diff
}
// CompareStringListsInBoth returns the strings that are present in both X and Y.
func CompareStringListsInBoth(X, Y []string) []string {
m := make(map[string]struct{})
for _, y := range Y {
m[y] = struct{}{}
}
diff := []string{}
for _, x := range X {
if _, e := m[x]; e {
diff = append(diff, x)
delete(m, x)
}
} }
return diff return diff

View File

@ -106,5 +106,5 @@ func (p *Priority) Scan(value interface{}) error {
} }
func (p *Priority) Value() (driver.Value, error) { func (p *Priority) Value() (driver.Value, error) {
return p, nil return string(*p), nil
} }

View File

@ -29,7 +29,16 @@ const fileToDownload = "http://www.google.com/robots.txt"
// TestDiff tests the diff.go source file // TestDiff tests the diff.go source file
func TestDiff(t *testing.T) { func TestDiff(t *testing.T) {
assert.NotContains(t, CompareStringLists([]string{"a", "b", "a"}, []string{"a", "c"}), "a") cmp := CompareStringLists([]string{"a", "b", "b", "a"}, []string{"a", "c"})
assert.Len(t, cmp, 1)
assert.NotContains(t, cmp, "a")
assert.Contains(t, cmp, "b")
cmp = CompareStringListsInBoth([]string{"a", "a", "b", "c"}, []string{"a", "c", "c"})
assert.Len(t, cmp, 2)
assert.NotContains(t, cmp, "b")
assert.Contains(t, cmp, "a")
assert.Contains(t, cmp, "c")
} }
// TestExec tests the exec.go source file // TestExec tests the exec.go source file