From 7c70fc1c205caa45926ae1435d74d162abf13d54 Mon Sep 17 00:00:00 2001 From: Quentin Machu Date: Tue, 12 Jan 2016 10:40:46 -0500 Subject: [PATCH] database: add initial vulnerability support --- database/database.go | 5 +- database/models.go | 6 +- database/pgsql/feature.go | 94 +++-- database/pgsql/layer.go | 47 +-- .../migrations/20151222113213_Initial.sql | 4 +- database/pgsql/pgsql.go | 4 +- database/pgsql/queries.go | 41 +- database/pgsql/testdata/data.sql | 4 +- database/pgsql/vulnerability.go | 306 ++++++++++++++ database/pgsql/vulnerability_test.go | 387 ++++++++++++++++++ utils/http/http.go | 4 +- utils/string.go | 35 +- utils/types/priority.go | 2 +- utils/utils_test.go | 11 +- 14 files changed, 862 insertions(+), 88 deletions(-) create mode 100644 database/pgsql/vulnerability.go create mode 100644 database/pgsql/vulnerability_test.go diff --git a/database/database.go b/database/database.go index 3a31aca9..be508aa2 100644 --- a/database/database.go +++ b/database/database.go @@ -28,8 +28,9 @@ type Datastore interface { DeleteLayer(name string) error // Vulnerability - // InsertVulnerabilities([]*Vulnerability) - // DeleteVulnerability(id string) + InsertVulnerabilities([]Vulnerability) error + // DeleteVulnerability(id string) error + FindVulnerability(namespaceName, name string) (Vulnerability, error) // Notifications // InsertNotifications([]Notification) error diff --git a/database/models.go b/database/models.go index 95e84542..a7bef10d 100644 --- a/database/models.go +++ b/database/models.go @@ -2,6 +2,7 @@ package database import "github.com/coreos/clair/utils/types" +// ID is only meant to be used by database implementations and should never be used for anything else. type Model struct { ID int } @@ -46,8 +47,9 @@ type Vulnerability struct { Description string Link string Severity types.Priority - // FixedIn map[types.Version]Feature // <<-- WRONG. - Affects []FeatureVersion + + FixedIn []FeatureVersion + //Affects []FeatureVersion // For output purposes. Only make sense when the vulnerability // is already about a specific Feature/FeatureVersion. diff --git a/database/pgsql/feature.go b/database/pgsql/feature.go index 33d81ab8..8414540b 100644 --- a/database/pgsql/feature.go +++ b/database/pgsql/feature.go @@ -1,6 +1,8 @@ package pgsql import ( + "database/sql" + "github.com/coreos/clair/database" cerrors "github.com/coreos/clair/utils/errors" "github.com/coreos/clair/utils/types" @@ -54,6 +56,7 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) if err != nil { return 0, err } + featureVersion.Feature.ID = featureID // Begin transaction. tx, err := pgSQL.Begin() @@ -62,6 +65,15 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) return 0, handleError("insertFeatureVersion.Begin()", err) } + // Set transaction as SERIALIZABLE. + // This is how we ensure that the data in Vulnerability_Affects_FeatureVersion is always + // consistent. + _, err = tx.Exec(getQuery("set_tx_serializable")) + if err != nil { + tx.Rollback() + return 0, handleError("insertFeatureVersion.set_tx_serializable", err) + } + // Find or create FeatureVersion. var newOrExisting string err = tx.QueryRow(getQuery("soi_featureversion"), featureID, &featureVersion.Version). @@ -77,52 +89,15 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) // Link the new FeatureVersion with every vulnerabilities that affect it, by inserting in // Vulnerability_Affects_FeatureVersion. - - // Lock Vulnerability_FixedIn_Feature because we can't let it to be modified while we modify - // Vulnerability_Affects_FeatureVersion. - _, err = tx.Exec(getQuery("l_share_vulnerability_fixedin_feature")) + err = linkFeatureVersionToVulnerabilities(tx, featureVersion) if err != nil { - tx.Rollback() - return 0, handleError("l_share_vulnerability_fixedin_feature", err) - } - - // Select every vulnerability and the fixed version that affect this Feature. - rows, err := tx.Query(getQuery("s_vulnerability_fixedin_feature"), featureID) - if err != nil { - tx.Rollback() - return 0, handleError("s_vulnerability_fixedin_feature", err) - } - defer rows.Close() - - var fixedInID, vulnerabilityID int - var fixedInVersion types.Version - for rows.Next() { - err := rows.Scan(&fixedInID, &vulnerabilityID, &fixedInVersion) - if err != nil { - tx.Rollback() - return 0, handleError("s_vulnerability_fixedin_feature.Scan()", err) - } - - if featureVersion.Version.Compare(fixedInVersion) < 0 { - // The version of the FeatureVersion we are inserting is lower than the fixed version on this - // Vulnerability, thus, this FeatureVersion is affected by it. - // TODO(Quentin-M): Prepare. - _, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), vulnerabilityID, - featureVersion.ID, fixedInID) - if err != nil { - tx.Rollback() - return 0, handleError("i_vulnerability_affects_featureversion", err) - } - } - } - if err = rows.Err(); err != nil { - return 0, handleError("s_vulnerability_fixedin_feature.Rows()", err) + // tx.Rollback() is done in linkFeatureVersionToVulnerabilities. + return 0, err } // Commit transaction. err = tx.Commit() if err != nil { - tx.Rollback() return 0, handleError("insertFeatureVersion.Commit()", err) } @@ -148,3 +123,42 @@ func (pgSQL *pgSQL) insertFeatureVersions(featureVersions []database.FeatureVers return IDs, nil } + +func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.FeatureVersion) error { + // Select every vulnerability and the fixed version that affect this Feature. + // TODO(Quentin-M): LIMIT + rows, err := tx.Query(getQuery("s_vulnerability_fixedin_feature"), featureVersion.Feature.ID) + if err != nil { + tx.Rollback() + return handleError("s_vulnerability_fixedin_feature", err) + } + defer rows.Close() + + var fixedInID, vulnerabilityID int + var fixedInVersion types.Version + for rows.Next() { + err := rows.Scan(&fixedInID, &vulnerabilityID, &fixedInVersion) + if err != nil { + tx.Rollback() + return handleError("s_vulnerability_fixedin_feature.Scan()", err) + } + + if featureVersion.Version.Compare(fixedInVersion) < 0 { + // The version of the FeatureVersion we are inserting is lower than the fixed version on this + // Vulnerability, thus, this FeatureVersion is affected by it. + // TODO(Quentin-M): Prepare. + _, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), vulnerabilityID, + featureVersion.ID, fixedInID) + if err != nil { + tx.Rollback() + return handleError("i_vulnerability_affects_featureversion", err) + } + } + } + if err = rows.Err(); err != nil { + tx.Rollback() + return handleError("s_vulnerability_fixedin_feature.Rows()", err) + } + + return nil +} diff --git a/database/pgsql/layer.go b/database/pgsql/layer.go index d2f941f5..645ce2f4 100644 --- a/database/pgsql/layer.go +++ b/database/pgsql/layer.go @@ -77,7 +77,7 @@ func (pgSQL *pgSQL) getLayerFeatureVersions(layerID int, idOnly bool) ([]databas // Query rows, err := pgSQL.Query(query, layerID) - if err != nil && err != sql.ErrNoRows { + if err != nil { return featureVersions, handleError(query, err) } defer rows.Close() @@ -201,14 +201,23 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { if err != nil && err != cerrors.ErrNotFound { return err } else if err == nil { + if existingLayer.EngineVersion >= layer.EngineVersion { + // The layer exists and has an equal or higher engine verison, do nothing. + return nil + } + layer.ID = existingLayer.ID } - // Begin transaction. - tx, err := pgSQL.Begin() - if err != nil { - tx.Rollback() - return handleError("InsertLayer.Begin()", err) + // Get parent ID. + var parentID zero.Int + if layer.Parent != nil { + if layer.Parent.ID == 0 { + log.Warning("Parent is expected to be retrieved from database when inserting a layer.") + return cerrors.NewBadRequestError("Parent is expected to be retrieved from database when inserting a layer.") + } + + parentID = zero.IntFrom(int64(layer.Parent.ID)) } // Find or insert namespace if provided. @@ -216,7 +225,6 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { if layer.Namespace != nil { n, err := pgSQL.insertNamespace(*layer.Namespace) if err != nil { - tx.Rollback() return err } namespaceID = zero.IntFrom(int64(n)) @@ -227,18 +235,15 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { } } + // Begin transaction. + tx, err := pgSQL.Begin() + if err != nil { + tx.Rollback() + return handleError("InsertLayer.Begin()", err) + } + if layer.ID == 0 { // Insert a new layer. - var parentID zero.Int - if layer.Parent != nil { - if layer.Parent.ID == 0 { - log.Warning("Parent is expected to be retrieved from database when inserting a layer.") - return cerrors.NewBadRequestError("Parent is expected to be retrieved from database when inserting a layer.") - } - - parentID = zero.IntFrom(int64(layer.Parent.ID)) - } - err = tx.QueryRow(getQuery("i_layer"), layer.Name, layer.EngineVersion, parentID, namespaceID). Scan(&layer.ID) if err != nil { @@ -246,11 +251,6 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { return handleError("i_layer", err) } } else { - if existingLayer.EngineVersion >= layer.EngineVersion { - // The layer exists and has an equal or higher engine verison, do nothing. - return nil - } - // Update an existing layer. _, err = tx.Exec(getQuery("u_layer"), layer.ID, layer.EngineVersion, namespaceID) if err != nil { @@ -269,6 +269,7 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { // Update Layer_diff_FeatureVersion now. err = pgSQL.updateDiffFeatureVersions(tx, &layer, &existingLayer) if err != nil { + tx.Rollback() return err } @@ -293,7 +294,7 @@ func (pgSQL *pgSQL) updateDiffFeatureVersions(tx *sql.Tx, layer, existingLayer * } else if layer.Parent != nil { // There is a parent, we need to diff the Features with it. - // Build name:version strctures. + // Build name:version structures. layerFeaturesMapNV, layerFeaturesNV := createNV(layer.Features) parentLayerFeaturesMapNV, parentLayerFeaturesNV := createNV(layer.Parent.Features) diff --git a/database/pgsql/migrations/20151222113213_Initial.sql b/database/pgsql/migrations/20151222113213_Initial.sql index 14f3092a..bdf8e9f8 100644 --- a/database/pgsql/migrations/20151222113213_Initial.sql +++ b/database/pgsql/migrations/20151222113213_Initial.sql @@ -73,7 +73,7 @@ CREATE TABLE IF NOT EXISTS Vulnerability ( name VARCHAR(128) NOT NULL, description TEXT NULL, link VARCHAR(128) NULL, - severity severity NULL, + severity severity NOT NULL, UNIQUE (namespace_id, name)); @@ -137,6 +137,6 @@ DROP TABLE IF EXISTS Namespace, Vulnerability, Vulnerability_FixedIn_Feature, Vulnerability_Affects_FeatureVersion, - KeyValue + KeyValue, Lock CASCADE; diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index 476eca0f..d0ff887e 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -120,7 +120,7 @@ func dropDatabase(dataSource, databaseName string) error { // Drop database. _, err = db.Exec("DROP DATABASE " + databaseName + ";") if err != nil { - return fmt.Errorf("could not create database: %v", err) + return fmt.Errorf("could not drop database: %v", err) } return nil @@ -185,7 +185,7 @@ func handleError(desc string, err error) error { return database.ErrBackendException } else if err == sql.ErrNoRows { return cerrors.ErrNotFound - } else if err == sql.ErrTxDone { + } else if err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") { return database.ErrBackendException } diff --git a/database/pgsql/queries.go b/database/pgsql/queries.go index 112813a8..e1088765 100644 --- a/database/pgsql/queries.go +++ b/database/pgsql/queries.go @@ -10,6 +10,8 @@ var queries map[string]string func init() { queries = make(map[string]string) + queries["set_tx_serializable"] = `SET TRANSACTION ISOLATION LEVEL SERIALIZABLE` + // keyvalue.go queries["u_keyvalue"] = `UPDATE KeyValue SET value = $1 WHERE key = $2` queries["i_keyvalue"] = `INSERT INTO KeyValue(key, value) VALUES($1, $2)` @@ -39,10 +41,6 @@ func init() { UNION SELECT id FROM new_feature` - queries["l_share_vulnerability_fixedin_feature"] = ` - LOCK Vulnerability_FixedIn_Feature IN SHARE MODE - ` - queries["soi_featureversion"] = ` WITH new_featureversion AS ( INSERT INTO FeatureVersion(feature_id, version) @@ -142,6 +140,41 @@ func init() { queries["r_lock"] = `DELETE FROM Lock WHERE name = $1 AND owner = $2` queries["r_lock_expired"] = `DELETE FROM LOCK WHERE until < CURRENT_TIMESTAMP` + + // vulnerability.go + queries["f_vulnerability"] = ` + SELECT v.id, n.id, v.description, v.link, v.severity, vfif.version, f.id, f.Name + FROM Vulnerability v + JOIN Namespace n ON v.namespace_id = n.id + LEFT JOIN Vulnerability_FixedIn_Feature vfif ON v.id = vfif.vulnerability_id + LEFT JOIN Feature f ON vfif.feature_id = f.id + WHERE n.Name = $1 AND v.Name = $2` + + queries["i_vulnerability"] = ` + INSERT INTO Vulnerability(namespace_id, name, description, link, severity) + VALUES($1, $2, $3, $4, $5) + RETURNING id` + + queries["u_vulnerability"] = ` + UPDATE Vulnerability SET description = $2, link = $3, severity = $4 WHERE id = $1` + + queries["i_vulnerability_fixedin_feature"] = ` + INSERT INTO Vulnerability_FixedIn_Feature(vulnerability_id, feature_id, version) + VALUES($1, $2, $3) + RETURNING id` + + queries["u_vulnerability_fixedin_feature"] = ` + UPDATE Vulnerability_FixedIn_Feature + SET version = $3 + WHERE vulnerability_id = $1 AND feature_id = $2 + RETURNING id` + + queries["r_vulnerability_affects_featureversion"] = ` + DELETE FROM Vulnerability_Affects_FeatureVersion + WHERE fixedin_id = $1` + + queries["f_featureversion_by_feature"] = ` + SELECT id, version FROM FeatureVersion WHERE feature_id = $1` } func getQuery(name string) string { diff --git a/database/pgsql/testdata/data.sql b/database/pgsql/testdata/data.sql index a555ecf2..3033533c 100644 --- a/database/pgsql/testdata/data.sql +++ b/database/pgsql/testdata/data.sql @@ -3,6 +3,7 @@ INSERT INTO namespace (id, name) VALUES (2, 'debian:8'); INSERT INTO feature (id, namespace_id, name) VALUES (1, 1, 'wechat'); INSERT INTO feature (id, namespace_id, name) VALUES (2, 1, 'openssl'); +INSERT INTO feature (id, namespace_id, name) VALUES (4, 1, 'libssl'); INSERT INTO feature (id, namespace_id, name) VALUES (3, 2, 'openssl'); INSERT INTO featureversion (id, feature_id, version) VALUES (1, 1, '0.5'); INSERT INTO featureversion (id, feature_id, version) VALUES (2, 2, '1.0'); @@ -23,8 +24,9 @@ INSERT INTO layer_diff_featureversion (id, layer_id, featureversion_id, modifica INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES (1, 1, 'CVE-OPENSSL-1-DEB7', 'A vulnerability affecting OpenSSL < 2.0 on Debian 7.0', 'http://google.com/#q=CVE-OPENSSL-1-DEB7', 'High'); INSERT INTO vulnerability_fixedin_feature (id, vulnerability_id, feature_id, version) VALUES (1, 1, 2, '2.0'); +INSERT INTO vulnerability_fixedin_feature (id, vulnerability_id, feature_id, version) VALUES (2, 1, 4, '1.9-abc'); INSERT INTO vulnerability_affects_featureversion (id, vulnerability_id, featureversion_id, fixedin_id) VALUES (1, 1, 2, 1); -- CVE-OPENSSL-1-DEB7 affects Debian:7 OpenSSL 1.0 -INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES (2, 1, 'CVE-NOPE', 'A vulnerability affecting nothing', 'http://google.com/#q=NOPE', 'Negligible'); +INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES (2, 1, 'CVE-NOPE', 'A vulnerability affecting nothing', '', 'Unknown'); SELECT pg_catalog.setval(pg_get_serial_sequence('namespace', 'id'), (SELECT MAX(id) FROM namespace)+1); SELECT pg_catalog.setval(pg_get_serial_sequence('feature', 'id'), (SELECT MAX(id) FROM feature)+1); diff --git a/database/pgsql/vulnerability.go b/database/pgsql/vulnerability.go new file mode 100644 index 00000000..cc7d4f85 --- /dev/null +++ b/database/pgsql/vulnerability.go @@ -0,0 +1,306 @@ +package pgsql + +import ( + "database/sql" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/utils" + cerrors "github.com/coreos/clair/utils/errors" + "github.com/coreos/clair/utils/types" + "github.com/guregu/null/zero" +) + +func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vulnerability, error) { + vulnerability := database.Vulnerability{ + Name: name, + Namespace: database.Namespace{ + Name: namespaceName, + }, + } + + // Find Vulnerability. + rows, err := pgSQL.Query(getQuery("f_vulnerability"), namespaceName, name) + if err != nil { + return vulnerability, handleError("f_vulnerability", err) + } + defer rows.Close() + + // Iterate to scan the Vulnerability and its FixedIn FeatureVersions. + for rows.Next() { + var featureVersionID zero.Int + var featureVersionVersion zero.String + var featureVersionFeatureName zero.String + + err := rows.Scan(&vulnerability.ID, &vulnerability.Namespace.ID, &vulnerability.Description, + &vulnerability.Link, &vulnerability.Severity, &featureVersionVersion, &featureVersionID, + &featureVersionFeatureName) + if err != nil { + return vulnerability, handleError("f_vulnerability.Scan()", err) + } + + if !featureVersionID.IsZero() { + // Note that the ID we fill in featureVersion is actually a Feature ID, and not + // a FeatureVersion ID. + featureVersion := database.FeatureVersion{ + Model: database.Model{ID: int(featureVersionID.Int64)}, + Feature: database.Feature{ + Model: database.Model{ID: int(featureVersionID.Int64)}, + Namespace: vulnerability.Namespace, + Name: featureVersionFeatureName.String, + }, + Version: types.NewVersionUnsafe(featureVersionVersion.String), + } + vulnerability.FixedIn = append(vulnerability.FixedIn, featureVersion) + } + } + if err = rows.Err(); err != nil { + return vulnerability, handleError("s_featureversions_vulnerabilities.Rows()", err) + } + if vulnerability.ID == 0 { + return vulnerability, cerrors.ErrNotFound + } + + return vulnerability, nil +} + +// FixedIn.Namespace are not necessary, they are overwritten by the vuln. +func (pgSQL *pgSQL) InsertVulnerabilities(vulnerabilities []database.Vulnerability) error { + for _, vulnerability := range vulnerabilities { + err := pgSQL.insertVulnerability(vulnerability) + if err != nil { + return err + } + } + return nil +} + +func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability) error { + // Verify parameters + if vulnerability.Name == "" || len(vulnerability.FixedIn) == 0 || + vulnerability.Namespace.Name == "" || !vulnerability.Severity.IsValid() { + log.Warning("could not insert an invalid vulnerability") + return cerrors.NewBadRequestError("could not insert an invalid vulnerability") + } + + for _, fixedInFeatureVersion := range vulnerability.FixedIn { + if fixedInFeatureVersion.Feature.Namespace.Name != "" && + fixedInFeatureVersion.Feature.Namespace.Name != vulnerability.Namespace.Name { + msg := "could not insert an invalid vulnerability: FixedIn FeatureVersion must be in the " + + "same namespace as the Vulnerability" + log.Warning(msg) + return cerrors.NewBadRequestError(msg) + } + } + + // Find or insert Vulnerability's Namespace. + namespaceID, err := pgSQL.insertNamespace(vulnerability.Namespace) + if err != nil { + return err + } + + // Find vulnerability and its Vulnerability_FixedIn_Features. + existingVulnerability, err := pgSQL.FindVulnerability(vulnerability.Namespace.Name, + vulnerability.Name) + if err != nil && err != cerrors.ErrNotFound { + return err + } + + // Compute new/updated FixedIn FeatureVersions. + var newFixedInFeatureVersions []database.FeatureVersion + var updatedFixedInFeatureVersions []database.FeatureVersion + if existingVulnerability.ID == 0 { + newFixedInFeatureVersions = vulnerability.FixedIn + } else { + newFixedInFeatureVersions, updatedFixedInFeatureVersions = diffFixedIn(vulnerability, + existingVulnerability) + } + + if len(newFixedInFeatureVersions) == 0 && len(updatedFixedInFeatureVersions) == 0 { + // Nothing to do. + return nil + } + + // Insert or find the new FeatureVersions. + // We already have the Feature IDs in updatedFixedInFeatureVersions because diffFixedIn fills them + // in using the existing vulnerability's FixedIn FeatureVersions. Note that even if FixedIn + // is type FeatureVersion, the actual stored ID in these structs are the Feature IDs. + // + // Also, we enforce the namespace of the FeatureVersion in case it was empty. There is a test + // above to ensure that the passed Namespace is either the same as the vulnerability or empty. + for i := 0; i < len(newFixedInFeatureVersions); i++ { + newFixedInFeatureVersions[i].Feature.Namespace.Name = vulnerability.Namespace.Name + newFixedInFeatureVersions[i].ID, err = pgSQL.insertFeatureVersion(newFixedInFeatureVersions[i]) + if err != nil { + return err + } + } + + // Begin transaction. + tx, err := pgSQL.Begin() + if err != nil { + tx.Rollback() + return handleError("insertVulnerability.Begin()", err) + } + + // Set transaction as SERIALIZABLE. + // This is how we ensure that the data in Vulnerability_Affects_FeatureVersion is always + // consistent. + _, err = tx.Exec(getQuery("set_tx_serializable")) + if err != nil { + tx.Rollback() + return handleError("insertFeatureVersion.set_tx_serializable", err) + } + + if existingVulnerability.ID == 0 { + // Insert new vulnerability. + err = tx.QueryRow(getQuery("i_vulnerability"), namespaceID, vulnerability.Name, + vulnerability.Description, vulnerability.Link, &vulnerability.Severity).Scan(&vulnerability.ID) + if err != nil { + tx.Rollback() + return handleError("i_vulnerability", err) + } + } else { + // Update vulnerability + _, err = tx.Exec(getQuery("u_vulnerability"), existingVulnerability.ID, + vulnerability.Description, vulnerability.Link, &vulnerability.Severity) + if err != nil { + tx.Rollback() + return handleError("u_vulnerability", err) + } + + vulnerability.ID = existingVulnerability.ID + } + + // Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now. + err = pgSQL.updateVulnerabilityFeatureVersions(tx, &vulnerability, &existingVulnerability, newFixedInFeatureVersions, updatedFixedInFeatureVersions) + if err != nil { + tx.Rollback() + return err + } + + // Commit transaction. + err = tx.Commit() + if err != nil { + tx.Rollback() + return handleError("insertVulnerability.Commit()", err) + } + + return nil +} + +func diffFixedIn(vulnerability, existingVulnerability database.Vulnerability) (newFixedIn, updatedFixedIn []database.FeatureVersion) { + // Build FeatureVersion.Feature.Namespace.Name:FeatureVersion.Feature.Name (NaN) structures. + vulnerabilityFixedInNameMap, vulnerabilityFixedInNameSlice := createFeatureVersionNameMap(vulnerability.FixedIn) + existingFixedInMapNameMap, existingFixedInNameSlice := createFeatureVersionNameMap(existingVulnerability.FixedIn) + + // Calculate the new FixedIn FeatureVersion NaN and updated ones. + newFixedInName := utils.CompareStringLists(vulnerabilityFixedInNameSlice, + existingFixedInNameSlice) + updatedFixedInName := utils.CompareStringListsInBoth(vulnerabilityFixedInNameSlice, + existingFixedInNameSlice) + + for _, nan := range newFixedInName { + newFixedIn = append(newFixedIn, vulnerabilityFixedInNameMap[nan]) + } + for _, nan := range updatedFixedInName { + fv := existingFixedInMapNameMap[nan] + fv.Version = vulnerabilityFixedInNameMap[nan].Version + updatedFixedIn = append(updatedFixedIn, fv) + } + + return +} + +func createFeatureVersionNameMap(features []database.FeatureVersion) (map[string]database.FeatureVersion, []string) { + m := make(map[string]database.FeatureVersion, 0) + s := make([]string, 0, len(features)) + + for i := 0; i < len(features); i++ { + featureVersion := features[i] + m[featureVersion.Feature.Name] = featureVersion + s = append(s, featureVersion.Feature.Name) + } + + return m, s +} + +// TODO(Quentin-M): Add support for removing Vulnerability_FixedIn_Feature when Version = MinVersion. +// We should then update the vulnerability fetcher to do it. +// Also maybe we would delete a Vulnerability if it hasn't any FixedIn. +// --> And affects +func (pgSQL *pgSQL) updateVulnerabilityFeatureVersions(tx *sql.Tx, vulnerability, existingVulnerability *database.Vulnerability, newFixedInFeatureVersions, updatedFixedInFeatureVersions []database.FeatureVersion) error { + var fixedInID int + + for _, fv := range newFixedInFeatureVersions { + // Insert Vulnerability_FixedIn_Feature. + err := tx.QueryRow(getQuery("i_vulnerability_fixedin_feature"), vulnerability.ID, fv.ID, + &fv.Version).Scan(&fixedInID) + if err != nil { + return handleError("i_vulnerability_fixedin_feature", err) + } + + // Insert Vulnerability_Affects_FeatureVersion. + err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerability.ID, fv.ID, fv.Version) + if err != nil { + return err + } + } + + for _, fv := range updatedFixedInFeatureVersions { + // Update Vulnerability_FixedIn_Feature. + err := tx.QueryRow(getQuery("u_vulnerability_fixedin_feature"), vulnerability.ID, fv.ID, + &fv.Version).Scan(&fixedInID) + if err != nil { + return handleError("u_vulnerability_fixedin_feature", err) + } + + // Drop all old Vulnerability_Affects_FeatureVersion. + _, err = tx.Exec(getQuery("r_vulnerability_affects_featureversion"), fixedInID) + if err != nil { + return handleError("r_vulnerability_affects_featureversion", err) + } + + // Insert Vulnerability_Affects_FeatureVersion. + err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerability.ID, fv.ID, fv.Version) + if err != nil { + return err + } + } + + return nil +} + +func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, fixedInVersion types.Version) error { + // Find every FeatureVersions of the Feature we want to affect. + // TODO(Quentin-M): LIMIT + rows, err := tx.Query(getQuery("f_featureversion_by_feature"), featureID) + if err == sql.ErrNoRows { + return nil + } + if err != nil { + return handleError("f_featureversion_by_feature", err) + } + defer rows.Close() + + var featureVersionID int + var featureVersionVersion types.Version + for rows.Next() { + err := rows.Scan(&featureVersionID, &featureVersionVersion) + if err != nil { + return handleError("f_featureversion_by_feature.Scan()", err) + } + + if featureVersionVersion.Compare(fixedInVersion) < 0 { + _, err := tx.Exec("i_vulnerability_affects_featureversion", vulnerabilityID, featureVersionID, + fixedInID) + if err != nil { + return handleError("i_vulnerability_affects_featureversion", err) + } + } + } + if err = rows.Err(); err != nil { + return handleError("f_featureversion_by_feature.Rows()", err) + } + + return nil +} diff --git a/database/pgsql/vulnerability_test.go b/database/pgsql/vulnerability_test.go new file mode 100644 index 00000000..690a9670 --- /dev/null +++ b/database/pgsql/vulnerability_test.go @@ -0,0 +1,387 @@ +package pgsql + +import ( + "testing" + + "github.com/coreos/clair/database" + cerrors "github.com/coreos/clair/utils/errors" + "github.com/coreos/clair/utils/types" + "github.com/stretchr/testify/assert" +) + +func TestFindVulnerability(t *testing.T) { + datastore, err := OpenForTest("FindVulnerability", true) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + // Find a vulnerability that does not exist. + _, err = datastore.FindVulnerability("", "") + assert.Equal(t, cerrors.ErrNotFound, err) + + // Find a normal vulnerability. + v1 := database.Vulnerability{ + Name: "CVE-OPENSSL-1-DEB7", + Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", + Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", + Severity: types.High, + FixedIn: []database.FeatureVersion{ + database.FeatureVersion{ + Feature: database.Feature{Name: "openssl"}, + Version: types.NewVersionUnsafe("2.0"), + }, + database.FeatureVersion{ + Feature: database.Feature{Name: "libssl"}, + Version: types.NewVersionUnsafe("1.9-abc"), + }, + }, + } + + v1f, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") + if assert.Nil(t, err) { + equalsVuln(t, &v1, &v1f) + } + + // Find a vulnerability that has no link, no severity and no FixedIn. + v2 := database.Vulnerability{ + Name: "CVE-OPENSSL-1-DEB7", + Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", + } + + v2f, err := datastore.FindVulnerability("debian:7", "CVE-NOPE") + if assert.Nil(t, err) { + equalsVuln(t, &v2, &v2f) + } +} + +func TestInsertVulnerability(t *testing.T) { + datastore, err := OpenForTest("InsertVulnerability", false) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + // Create some data. + n1 := database.Namespace{Name: "TestInsertVulnerabilityNamespace1"} + n2 := database.Namespace{Name: "TestInsertVulnerabilityNamespace2"} + + f1 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion1", + Namespace: n1, + }, + Version: types.NewVersionUnsafe("1.0"), + } + f2 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion1", + Namespace: n2, + }, + Version: types.NewVersionUnsafe("1.0"), + } + f3 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion2", + }, + Version: types.MaxVersion, + } + f4 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion2", + }, + Version: types.NewVersionUnsafe("1.4"), + } + f5 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion3", + }, + Version: types.NewVersionUnsafe("1.5"), + } + f6 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion4", + }, + Version: types.NewVersionUnsafe("0.1"), + } + + // Insert invalid vulnerabilities. + for _, vulnerability := range []database.Vulnerability{ + database.Vulnerability{ + Name: "", + Namespace: n1, + FixedIn: []database.FeatureVersion{f1}, + Severity: types.Unknown, + }, + database.Vulnerability{ + Name: "TestInsertVulnerability0", + Namespace: database.Namespace{}, + FixedIn: []database.FeatureVersion{f1}, + Severity: types.Unknown, + }, + database.Vulnerability{ + Name: "TestInsertVulnerability0-", + Namespace: database.Namespace{}, + FixedIn: []database.FeatureVersion{f1}, + }, + database.Vulnerability{ + Name: "TestInsertVulnerability0", + Namespace: n1, + FixedIn: []database.FeatureVersion{f1}, + Severity: types.Priority(""), + }, + database.Vulnerability{ + Name: "TestInsertVulnerability0", + Namespace: n1, + FixedIn: []database.FeatureVersion{f2}, + Severity: types.Unknown, + }, + } { + err := datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}) + assert.Error(t, err) + } + + // Insert a simple vulnerability and find it. + v1 := database.Vulnerability{ + Name: "TestInsertVulnerability1", + Namespace: n1, + FixedIn: []database.FeatureVersion{f1, f3, f6}, + Severity: types.Low, + Description: "TestInsertVulnerabilityDescription1", + Link: "TestInsertVulnerabilityLink1", + } + err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}) + if assert.Nil(t, err) { + v1f, err := datastore.FindVulnerability(n1.Name, v1.Name) + if assert.Nil(t, err) { + equalsVuln(t, &v1, &v1f) + } + } + + // Update vulnerability. + v1.Description = "TestInsertVulnerabilityLink2" + v1.Link = "TestInsertVulnerabilityLink2" + v1.Severity = types.High + // Update f3 by f4, add fixed by f5, add fixed by f6 which already exists. + // TODO(Quentin-M): Remove FixedIn. + v1.FixedIn = []database.FeatureVersion{f4, f5, f6} + + err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}) + if assert.Nil(t, err) { + v1f, err := datastore.FindVulnerability(n1.Name, v1.Name) + if assert.Nil(t, err) { + // We already had f1 before the update. + // Add it to the struct for comparison. + v1.FixedIn = append(v1.FixedIn, f1) + equalsVuln(t, &v1, &v1f) + } + } +} + +func equalsVuln(t *testing.T, expected, actual *database.Vulnerability) { + assert.Equal(t, expected.Name, actual.Name) + assert.Equal(t, expected.Namespace.Name, actual.Namespace.Name) + assert.Equal(t, expected.Description, actual.Description) + assert.Equal(t, expected.Link, actual.Link) + assert.Equal(t, expected.Severity, actual.Severity) + if assert.Len(t, actual.FixedIn, len(expected.FixedIn)) { + for _, actualFeatureVersion := range actual.FixedIn { + found := false + for _, expectedFeatureVersion := range expected.FixedIn { + if expectedFeatureVersion.Feature.Name == actualFeatureVersion.Feature.Name { + found = true + + assert.Equal(t, expected.Namespace.Name, actualFeatureVersion.Feature.Namespace.Name) + assert.Equal(t, expectedFeatureVersion.Version, actualFeatureVersion.Version) + } + } + if !found { + t.Errorf("unexpected package %s in %s", actualFeatureVersion.Feature.Name, expected.Name) + } + } + } +} + +// TODO Test Affects in Feature_Version and here. + +// +// // Some data +// vuln1 := &database.Vulnerability{ID: "test1", Link: "link1", Priority: types.Medium, Description: "testDescription1", FixedInNodes: []string{"pkg1"}} +// vuln2 := &database.Vulnerability{ID: "test2", Link: "link2", Priority: types.High, Description: "testDescription2", FixedInNodes: []string{"pkg1", "pkg2"}} +// vuln3 := &database.Vulnerability{ID: "test3", Link: "link3", Priority: types.High, FixedInNodes: []string{"pkg3"}} // Empty description +// +// // Insert some vulnerabilities +// _, err := InsertVulnerabilities([]*database.Vulnerability{vuln1, vuln2, vuln3}) +// if assert.Nil(t, err) { +// // Find one of the vulnerabilities we just inserted and verify its content +// v1, err := FindOnedatabase.Vulnerability(vuln1.ID, Fielddatabase.VulnerabilityAll) +// if assert.Nil(t, err) && assert.NotNil(t, v1) { +// assert.Equal(t, vuln1.ID, v1.ID) +// assert.Equal(t, vuln1.Link, v1.Link) +// assert.Equal(t, vuln1.Priority, v1.Priority) +// assert.Equal(t, vuln1.Description, v1.Description) +// if assert.Len(t, v1.FixedInNodes, 1) { +// assert.Equal(t, vuln1.FixedInNodes[0], v1.FixedInNodes[0]) +// } +// } +// } +// +// // Update a database.Vulnerability and verify its new content +// pkg1 := &Package{OS: "testOS", Name: "testpkg1", Version: types.NewVersionUnsafe("1.0")} +// InsertPackages([]*Package{pkg1}) +// vuln5 := &database.Vulnerability{ID: "test5", Link: "link5", Priority: types.Medium, Description: "testDescription5", FixedInNodes: []string{pkg1.Node}} +// +// _, err = InsertVulnerabilities([]*database.Vulnerability{vuln5}) +// if assert.Nil(t, err) { +// // Partial updates +// // # Just a field update +// vuln5b := &database.Vulnerability{ID: "test5", Priority: types.High} +// _, err := InsertVulnerabilities([]*database.Vulnerability{vuln5b}) +// if assert.Nil(t, err) { +// v5b, err := FindOnedatabase.Vulnerability(vuln5b.ID, Fielddatabase.VulnerabilityAll) +// if assert.Nil(t, err) && assert.NotNil(t, v5b) { +// assert.Equal(t, vuln5b.ID, v5b.ID) +// assert.Equal(t, vuln5b.Priority, v5b.Priority) +// +// if assert.Len(t, v5b.FixedInNodes, 1) { +// assert.Contains(t, v5b.FixedInNodes, pkg1.Node) +// } +// } +// } +// +// // # Just a field update, twice in the same transaction +// vuln5b1 := &database.Vulnerability{ID: "test5", Link: "http://foo.bar"} +// vuln5b2 := &database.Vulnerability{ID: "test5", Link: "http://bar.foo"} +// _, err = InsertVulnerabilities([]*database.Vulnerability{vuln5b1, vuln5b2}) +// if assert.Nil(t, err) { +// v5b2, err := FindOnedatabase.Vulnerability(vuln5b2.ID, Fielddatabase.VulnerabilityAll) +// if assert.Nil(t, err) && assert.NotNil(t, v5b2) { +// assert.Equal(t, vuln5b2.Link, v5b2.Link) +// } +// } +// +// // # All fields except fixedIn update +// vuln5c := &database.Vulnerability{ID: "test5", Link: "link5c", Priority: types.Critical, Description: "testDescription5c"} +// _, err = InsertVulnerabilities([]*database.Vulnerability{vuln5c}) +// if assert.Nil(t, err) { +// v5c, err := FindOnedatabase.Vulnerability(vuln5c.ID, Fielddatabase.VulnerabilityAll) +// if assert.Nil(t, err) && assert.NotNil(t, v5c) { +// assert.Equal(t, vuln5c.ID, v5c.ID) +// assert.Equal(t, vuln5c.Link, v5c.Link) +// assert.Equal(t, vuln5c.Priority, v5c.Priority) +// assert.Equal(t, vuln5c.Description, v5c.Description) +// +// if assert.Len(t, v5c.FixedInNodes, 1) { +// assert.Contains(t, v5c.FixedInNodes, pkg1.Node) +// } +// } +// } +// +// // Complete update +// pkg2 := &Package{OS: "testOS", Name: "testpkg1", Version: types.NewVersionUnsafe("1.1")} +// pkg3 := &Package{OS: "testOS", Name: "testpkg2", Version: types.NewVersionUnsafe("1.0")} +// InsertPackages([]*Package{pkg2, pkg3}) +// vuln5d := &database.Vulnerability{ID: "test5", Link: "link5d", Priority: types.Low, Description: "testDescription5d", FixedInNodes: []string{pkg2.Node, pkg3.Node}} +// +// _, err = InsertVulnerabilities([]*database.Vulnerability{vuln5d}) +// if assert.Nil(t, err) { +// v5d, err := FindOnedatabase.Vulnerability(vuln5d.ID, Fielddatabase.VulnerabilityAll) +// if assert.Nil(t, err) && assert.NotNil(t, v5d) { +// assert.Equal(t, vuln5d.ID, v5d.ID) +// assert.Equal(t, vuln5d.Link, v5d.Link) +// assert.Equal(t, vuln5d.Priority, v5d.Priority) +// assert.Equal(t, vuln5d.Description, v5d.Description) +// +// // Here, we ensure that a database.Vulnerability can only be fixed by one package of a given branch at a given time +// // And that we can add new fixed packages as well +// if assert.Len(t, v5d.FixedInNodes, 2) { +// assert.NotContains(t, v5d.FixedInNodes, pkg1.Node) +// } +// } +// } +// } +// +// // Create and update a database.Vulnerability's packages (and from the same branch) in the same batch +// pkg1 = &Package{OS: "testOS", Name: "testpkg1", Version: types.NewVersionUnsafe("1.0")} +// pkg1b := &Package{OS: "testOS", Name: "testpkg1", Version: types.NewVersionUnsafe("1.1")} +// InsertPackages([]*Package{pkg1, pkg1b}) +// +// // # Two updates of the same database.Vulnerability in the same batch with packages of the same branch +// pkg0 := &Package{OS: "testOS", Name: "testpkg0", Version: types.NewVersionUnsafe("1.0")} +// InsertPackages([]*Package{pkg0}) +// _, err = InsertVulnerabilities([]*database.Vulnerability{&database.Vulnerability{ID: "test7", Link: "link7", Priority: types.Medium, Description: "testDescription7", FixedInNodes: []string{pkg0.Node}}}) +// if assert.Nil(t, err) { +// vuln7b := &database.Vulnerability{ID: "test7", FixedInNodes: []string{pkg1.Node}} +// vuln7c := &database.Vulnerability{ID: "test7", FixedInNodes: []string{pkg1b.Node}} +// _, err = InsertVulnerabilities([]*database.Vulnerability{vuln7b, vuln7c}) +// if assert.Nil(t, err) { +// v7, err := FindOnedatabase.Vulnerability("test7", Fielddatabase.VulnerabilityAll) +// if assert.Nil(t, err) && assert.Len(t, v7.FixedInNodes, 2) { +// assert.Contains(t, v7.FixedInNodes, pkg0.Node) +// assert.NotContains(t, v7.FixedInNodes, pkg1.Node) +// assert.Contains(t, v7.FixedInNodes, pkg1b.Node) +// } +// } +// } + +// func TestInsertVulnerabilityNotifications(t *testing.T) { +// Open(&config.DatabaseConfig{Type: "memstore"}) +// defer Close() +// +// pkg1 := &Package{OS: "testOS", Name: "testpkg1", Version: types.NewVersionUnsafe("1.0")} +// pkg1b := &Package{OS: "testOS", Name: "testpkg1", Version: types.NewVersionUnsafe("1.2")} +// pkg2 := &Package{OS: "testOS", Name: "testpkg2", Version: types.NewVersionUnsafe("1.0")} +// InsertPackages([]*Package{pkg1, pkg1b, pkg2}) +// +// // Newdatabase.VulnerabilityNotification +// vuln1 := &database.Vulnerability{ID: "test1", Link: "link1", Priority: types.Medium, Description: "testDescription1", FixedInNodes: []string{pkg1.Node}} +// vuln2 := &database.Vulnerability{ID: "test2", Link: "link2", Priority: types.High, Description: "testDescription2", FixedInNodes: []string{pkg1.Node, pkg2.Node}} +// vuln1b := &database.Vulnerability{ID: "test1", Priority: types.High, FixedInNodes: []string{"pkg3"}} +// notifications, err := InsertVulnerabilities([]*database.Vulnerability{vuln1, vuln2, vuln1b}) +// if assert.Nil(t, err) { +// // We should only have two Newdatabase.VulnerabilityNotification notifications: one for test1 and one for test2 +// // We should not have a database.VulnerabilityPriorityIncreasedNotification or a database.VulnerabilityPackageChangedNotification +// // for test1 because it is in the same batch +// if assert.Len(t, notifications, 2) { +// for _, n := range notifications { +// _, ok := n.(*Newdatabase.VulnerabilityNotification) +// assert.True(t, ok) +// } +// } +// } +// +// // database.VulnerabilityPriorityIncreasedNotification +// vuln1c := &database.Vulnerability{ID: "test1", Priority: types.Critical} +// notifications, err = InsertVulnerabilities([]*database.Vulnerability{vuln1c}) +// if assert.Nil(t, err) { +// if assert.Len(t, notifications, 1) { +// if nn, ok := notifications[0].(*database.VulnerabilityPriorityIncreasedNotification); assert.True(t, ok) { +// assert.Equal(t, vuln1b.Priority, nn.OldPriority) +// assert.Equal(t, vuln1c.Priority, nn.NewPriority) +// } +// } +// } +// +// notifications, err = InsertVulnerabilities([]*database.Vulnerability{&database.Vulnerability{ID: "test1", Priority: types.Low}}) +// assert.Nil(t, err) +// assert.Len(t, notifications, 0) +// +// // database.VulnerabilityPackageChangedNotification +// vuln1e := &database.Vulnerability{ID: "test1", FixedInNodes: []string{pkg1b.Node}} +// vuln1f := &database.Vulnerability{ID: "test1", FixedInNodes: []string{pkg2.Node}} +// notifications, err = InsertVulnerabilities([]*database.Vulnerability{vuln1e, vuln1f}) +// if assert.Nil(t, err) { +// if assert.Len(t, notifications, 1) { +// if nn, ok := notifications[0].(*database.VulnerabilityPackageChangedNotification); assert.True(t, ok) { +// // Here, we say that pkg1b fixes the database.Vulnerability, but as pkg1b is in +// // the same branch as pkg1, pkg1 should be removed and pkg1b added +// // We also add pkg2 as fixed +// assert.Contains(t, nn.AddedFixedInNodes, pkg1b.Node) +// assert.Contains(t, nn.RemovedFixedInNodes, pkg1.Node) +// +// assert.Contains(t, nn.AddedFixedInNodes, pkg2.Node) +// } +// } +// } diff --git a/utils/http/http.go b/utils/http/http.go index 81dfdc19..1aac81f3 100644 --- a/utils/http/http.go +++ b/utils/http/http.go @@ -26,7 +26,7 @@ import ( "github.com/coreos/clair/worker" ) -// MaxPostSize is the maximum number of bytes that ParseHTTPBody reads from an http.Request.Body. +// MaxBodySize is the maximum number of bytes that ParseHTTPBody reads from an http.Request.Body. const MaxBodySize int64 = 1048576 // WriteHTTP writes a JSON-encoded object to a http.ResponseWriter, as well as @@ -54,7 +54,7 @@ func WriteHTTPError(w http.ResponseWriter, httpStatus int, err error) { switch err { case cerrors.ErrNotFound: httpStatus = http.StatusNotFound - case database.ErrTransaction, database.ErrBackendException: + case database.ErrBackendException: httpStatus = http.StatusServiceUnavailable case worker.ErrParentUnknown, worker.ErrUnsupported, utils.ErrCouldNotExtract, utils.ErrExtractedFileTooBig: httpStatus = http.StatusBadRequest diff --git a/utils/string.go b/utils/string.go index 6e958874..a366c2f5 100644 --- a/utils/string.go +++ b/utils/string.go @@ -22,7 +22,7 @@ import ( var urlParametersRegexp = regexp.MustCompile(`(\?|\&)([^=]+)\=([^ &]+)`) -// Hash returns an unique hash of the given string +// Hash returns an unique hash of the given string. func Hash(str string) string { h := sha1.New() h.Write([]byte(str)) @@ -30,13 +30,13 @@ func Hash(str string) string { return hex.EncodeToString(bs) } -// CleanURL removes all parameters from an URL +// CleanURL removes all parameters from an URL. func CleanURL(str string) string { return urlParametersRegexp.ReplaceAllString(str, "") } // Contains looks for a string into an array of strings and returns whether -// the string exists +// the string exists. func Contains(needle string, haystack []string) bool { for _, h := range haystack { if h == needle { @@ -46,22 +46,41 @@ func Contains(needle string, haystack []string) bool { return false } -// CompareStringLists returns the strings which are present in X but not in Y +// CompareStringLists returns the strings that are present in X but not in Y. func CompareStringLists(X, Y []string) []string { - m := make(map[string]int) + m := make(map[string]bool) for _, y := range Y { - m[y] = 1 + m[y] = true } diff := []string{} for _, x := range X { - if m[x] > 0 { + if m[x] { continue } diff = append(diff, x) - m[x] = 1 + m[x] = true + } + + return diff +} + +// CompareStringListsInBoth returns the strings that are present in both X and Y. +func CompareStringListsInBoth(X, Y []string) []string { + m := make(map[string]struct{}) + + for _, y := range Y { + m[y] = struct{}{} + } + + diff := []string{} + for _, x := range X { + if _, e := m[x]; e { + diff = append(diff, x) + delete(m, x) + } } return diff diff --git a/utils/types/priority.go b/utils/types/priority.go index f8f10942..aac56d40 100644 --- a/utils/types/priority.go +++ b/utils/types/priority.go @@ -106,5 +106,5 @@ func (p *Priority) Scan(value interface{}) error { } func (p *Priority) Value() (driver.Value, error) { - return p, nil + return string(*p), nil } diff --git a/utils/utils_test.go b/utils/utils_test.go index aafcc2dc..8962c84f 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -29,7 +29,16 @@ const fileToDownload = "http://www.google.com/robots.txt" // TestDiff tests the diff.go source file func TestDiff(t *testing.T) { - assert.NotContains(t, CompareStringLists([]string{"a", "b", "a"}, []string{"a", "c"}), "a") + cmp := CompareStringLists([]string{"a", "b", "b", "a"}, []string{"a", "c"}) + assert.Len(t, cmp, 1) + assert.NotContains(t, cmp, "a") + assert.Contains(t, cmp, "b") + + cmp = CompareStringListsInBoth([]string{"a", "a", "b", "c"}, []string{"a", "c", "c"}) + assert.Len(t, cmp, 2) + assert.NotContains(t, cmp, "b") + assert.Contains(t, cmp, "a") + assert.Contains(t, cmp, "c") } // TestExec tests the exec.go source file