diff --git a/database/pgsql/complex_test.go b/database/pgsql/complex_test.go index 9896861a..12f0b72a 100644 --- a/database/pgsql/complex_test.go +++ b/database/pgsql/complex_test.go @@ -128,7 +128,7 @@ func TestRaceAffects(t *testing.T) { featureVersionVersion, _ := strconv.Atoi(featureVersion.Version.String()) // Get actual affects. - rows, err := datastore.Query(getQuery("s_complextest_featureversion_affects"), + rows, err := datastore.Query(searchComplexTestFeatureVersionAffects, featureVersion.ID) assert.Nil(t, err) defer rows.Close() diff --git a/database/pgsql/feature.go b/database/pgsql/feature.go index d09093b2..247d5701 100644 --- a/database/pgsql/feature.go +++ b/database/pgsql/feature.go @@ -49,9 +49,9 @@ func (pgSQL *pgSQL) insertFeature(feature database.Feature) (int, error) { // Find or create Feature. var id int - err = pgSQL.QueryRow(getQuery("soi_feature"), feature.Name, namespaceID).Scan(&id) + err = pgSQL.QueryRow(soiFeature, feature.Name, namespaceID).Scan(&id) if err != nil { - return 0, handleError("soi_feature", err) + return 0, handleError("soiFeature", err) } if pgSQL.cache != nil { @@ -103,25 +103,25 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) promConcurrentLockVAFV.Inc() defer promConcurrentLockVAFV.Dec() t = time.Now() - _, err = tx.Exec(getQuery("l_vulnerability_affects_featureversion")) + _, err = tx.Exec(lockVulnerabilityAffects) observeQueryTime("insertFeatureVersion", "lock", t) if err != nil { tx.Rollback() - return 0, handleError("insertFeatureVersion.l_vulnerability_affects_featureversion", err) + return 0, handleError("insertFeatureVersion.lockVulnerabilityAffects", err) } // Find or create FeatureVersion. var newOrExisting string t = time.Now() - err = tx.QueryRow(getQuery("soi_featureversion"), featureID, &featureVersion.Version). + err = tx.QueryRow(soiFeatureVersion, featureID, &featureVersion.Version). Scan(&newOrExisting, &featureVersion.ID) - observeQueryTime("insertFeatureVersion", "soi_featureversion", t) + observeQueryTime("insertFeatureVersion", "soiFeatureVersion", t) if err != nil { tx.Rollback() - return 0, handleError("soi_featureversion", err) + return 0, handleError("soiFeatureVersion", err) } if newOrExisting == "exi" { @@ -183,9 +183,9 @@ type vulnerabilityAffectsFeatureVersion struct { func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.FeatureVersion) error { // Select every vulnerability and the fixed version that affect this Feature. // TODO(Quentin-M): LIMIT - rows, err := tx.Query(getQuery("s_vulnerability_fixedin_feature"), featureVersion.Feature.ID) + rows, err := tx.Query(searchVulnerabilityFixedInFeature, featureVersion.Feature.ID) if err != nil { - return handleError("s_vulnerability_fixedin_feature", err) + return handleError("searchVulnerabilityFixedInFeature", err) } defer rows.Close() @@ -195,7 +195,7 @@ func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.Fea err := rows.Scan(&affect.fixedInID, &affect.vulnerabilityID, &affect.fixedInVersion) if err != nil { - return handleError("s_vulnerability_fixedin_feature.Scan()", err) + return handleError("searchVulnerabilityFixedInFeature.Scan()", err) } if featureVersion.Version.Compare(affect.fixedInVersion) < 0 { @@ -205,17 +205,17 @@ func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.Fea } } if err = rows.Err(); err != nil { - return handleError("s_vulnerability_fixedin_feature.Rows()", err) + return handleError("searchVulnerabilityFixedInFeature.Rows()", err) } rows.Close() // Insert into Vulnerability_Affects_FeatureVersion. for _, affect := range affects { // TODO(Quentin-M): Batch me. - _, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), affect.vulnerabilityID, + _, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, affect.vulnerabilityID, featureVersion.ID, affect.fixedInID) if err != nil { - return handleError("i_vulnerability_affects_featureversion", err) + return handleError("insertVulnerabilityAffectsFeatureVersion", err) } } diff --git a/database/pgsql/keyvalue.go b/database/pgsql/keyvalue.go index 68995ed1..264774c7 100644 --- a/database/pgsql/keyvalue.go +++ b/database/pgsql/keyvalue.go @@ -41,9 +41,9 @@ func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) { for { // First, try to update. - r, err := pgSQL.Exec(getQuery("u_keyvalue"), value, key) + r, err := pgSQL.Exec(updateKeyValue, value, key) if err != nil { - return handleError("u_keyvalue", err) + return handleError("updateKeyValue", err) } if n, _ := r.RowsAffected(); n > 0 { // Updated successfully. @@ -52,13 +52,13 @@ func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) { // Try to insert the key. // If someone else inserts the same key concurrently, we could get a unique-key violation error. - _, err = pgSQL.Exec(getQuery("i_keyvalue"), key, value) + _, err = pgSQL.Exec(insertKeyValue, key, value) if err != nil { if isErrUniqueViolation(err) { // Got unique constraint violation, retry. continue } - return handleError("i_keyvalue", err) + return handleError("insertKeyValue", err) } return nil @@ -70,13 +70,13 @@ func (pgSQL *pgSQL) GetKeyValue(key string) (string, error) { defer observeQueryTime("GetKeyValue", "all", time.Now()) var value string - err := pgSQL.QueryRow(getQuery("s_keyvalue"), key).Scan(&value) + err := pgSQL.QueryRow(searchKeyValue, key).Scan(&value) if err == sql.ErrNoRows { return "", nil } if err != nil { - return "", handleError("s_keyvalue", err) + return "", handleError("searchKeyValue", err) } return value, nil diff --git a/database/pgsql/layer.go b/database/pgsql/layer.go index 2581d14d..297a86ec 100644 --- a/database/pgsql/layer.go +++ b/database/pgsql/layer.go @@ -41,13 +41,13 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo var namespaceName sql.NullString t := time.Now() - err := pgSQL.QueryRow(getQuery("s_layer"), name). + err := pgSQL.QueryRow(searchLayer, name). Scan(&layer.ID, &layer.Name, &layer.EngineVersion, &parentID, &parentName, &namespaceID, &namespaceName) - observeQueryTime("FindLayer", "s_layer", t) + observeQueryTime("FindLayer", "searchLayer", t) if err != nil { - return layer, handleError("s_layer", err) + return layer, handleError("searchLayer", err) } if !parentID.IsZero() { @@ -78,11 +78,11 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo } defer tx.Commit() - _, err = tx.Exec(getQuery("disable_hashjoin")) + _, err = tx.Exec(disableHashJoin) if err != nil { log.Warningf("FindLayer: could not disable hash join: %s", err) } - _, err = tx.Exec(getQuery("disable_mergejoin")) + _, err = tx.Exec(disableMergeJoin) if err != nil { log.Warningf("FindLayer: could not disable merge join: %s", err) } @@ -117,9 +117,9 @@ func getLayerFeatureVersions(tx *sql.Tx, layerID int) ([]database.FeatureVersion var featureVersions []database.FeatureVersion // Query. - rows, err := tx.Query(getQuery("s_layer_featureversion"), layerID) + rows, err := tx.Query(searchLayerFeatureVersion, layerID) if err != nil { - return featureVersions, handleError("s_layer_featureversion", err) + return featureVersions, handleError("searchLayerFeatureVersion", err) } defer rows.Close() @@ -134,7 +134,7 @@ func getLayerFeatureVersions(tx *sql.Tx, layerID int) ([]database.FeatureVersion &featureVersion.Feature.Name, &featureVersion.ID, &featureVersion.Version, &featureVersion.AddedBy.ID, &featureVersion.AddedBy.Name) if err != nil { - return featureVersions, handleError("s_layer_featureversion.Scan()", err) + return featureVersions, handleError("searchLayerFeatureVersion.Scan()", err) } // Do transitive closure. @@ -149,7 +149,7 @@ func getLayerFeatureVersions(tx *sql.Tx, layerID int) ([]database.FeatureVersion } } if err = rows.Err(); err != nil { - return featureVersions, handleError("s_layer_featureversion.Rows()", err) + return featureVersions, handleError("searchLayerFeatureVersion.Rows()", err) } // Build result by converting our map to a slice. @@ -173,10 +173,10 @@ func loadAffectedBy(tx *sql.Tx, featureVersions []database.FeatureVersion) error featureVersionIDs = append(featureVersionIDs, featureVersions[i].ID) } - rows, err := tx.Query(getQuery("s_featureversions_vulnerabilities"), + rows, err := tx.Query(searchFeatureVersionVulnerability, buildInputArray(featureVersionIDs)) if err != nil && err != sql.ErrNoRows { - return handleError("s_featureversions_vulnerabilities", err) + return handleError("searchFeatureVersionVulnerability", err) } defer rows.Close() @@ -188,12 +188,12 @@ func loadAffectedBy(tx *sql.Tx, featureVersions []database.FeatureVersion) error &vulnerability.Description, &vulnerability.Link, &vulnerability.Severity, &vulnerability.Metadata, &vulnerability.Namespace.Name, &vulnerability.FixedBy) if err != nil { - return handleError("s_featureversions_vulnerabilities.Scan()", err) + return handleError("searchFeatureVersionVulnerability.Scan()", err) } vulnerabilities[featureversionID] = append(vulnerabilities[featureversionID], vulnerability) } if err = rows.Err(); err != nil { - return handleError("s_featureversions_vulnerabilities.Rows()", err) + return handleError("searchFeatureVersionVulnerability.Rows()", err) } // Assign vulnerabilities to every FeatureVersions @@ -271,7 +271,7 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { if layer.ID == 0 { // Insert a new layer. - err = tx.QueryRow(getQuery("i_layer"), layer.Name, layer.EngineVersion, parentID, namespaceID). + err = tx.QueryRow(insertLayer, layer.Name, layer.EngineVersion, parentID, namespaceID). Scan(&layer.ID) if err != nil { tx.Rollback() @@ -280,21 +280,21 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { // Ignore this error, another process collided. return nil } - return handleError("i_layer", err) + return handleError("insertLayer", err) } } else { // Update an existing layer. - _, err = tx.Exec(getQuery("u_layer"), layer.ID, layer.EngineVersion, namespaceID) + _, err = tx.Exec(updateLayer, layer.ID, layer.EngineVersion, namespaceID) if err != nil { tx.Rollback() - return handleError("u_layer", err) + return handleError("updateLayer", err) } // Remove all existing Layer_diff_FeatureVersion. - _, err = tx.Exec(getQuery("r_layer_diff_featureversion"), layer.ID) + _, err = tx.Exec(removeLayerDiffFeatureVersion, layer.ID) if err != nil { tx.Rollback() - return handleError("r_layer_diff_featureversion", err) + return handleError("removeLayerDiffFeatureVersion", err) } } @@ -355,15 +355,15 @@ func (pgSQL *pgSQL) updateDiffFeatureVersions(tx *sql.Tx, layer, existingLayer * // Insert diff in the database. if len(addIDs) > 0 { - _, err = tx.Exec(getQuery("i_layer_diff_featureversion"), layer.ID, "add", buildInputArray(addIDs)) + _, err = tx.Exec(insertLayerDiffFeatureVersion, layer.ID, "add", buildInputArray(addIDs)) if err != nil { - return handleError("i_layer_diff_featureversion.Add", err) + return handleError("insertLayerDiffFeatureVersion.Add", err) } } if len(delIDs) > 0 { - _, err = tx.Exec(getQuery("i_layer_diff_featureversion"), layer.ID, "del", buildInputArray(delIDs)) + _, err = tx.Exec(insertLayerDiffFeatureVersion, layer.ID, "del", buildInputArray(delIDs)) if err != nil { - return handleError("i_layer_diff_featureversion.Del", err) + return handleError("insertLayerDiffFeatureVersion.Del", err) } } @@ -387,14 +387,14 @@ func createNV(features []database.FeatureVersion) (map[string]*database.FeatureV func (pgSQL *pgSQL) DeleteLayer(name string) error { defer observeQueryTime("DeleteLayer", "all", time.Now()) - result, err := pgSQL.Exec(getQuery("r_layer"), name) + result, err := pgSQL.Exec(removeLayer, name) if err != nil { - return handleError("r_layer", err) + return handleError("removeLayer", err) } affected, err := result.RowsAffected() if err != nil { - return handleError("r_layer.RowsAffected()", err) + return handleError("removeLayer.RowsAffected()", err) } if affected <= 0 { diff --git a/database/pgsql/lock.go b/database/pgsql/lock.go index cbf34e1a..2f491caa 100644 --- a/database/pgsql/lock.go +++ b/database/pgsql/lock.go @@ -37,9 +37,9 @@ func (pgSQL *pgSQL) Lock(name string, owner string, duration time.Duration, rene if renew { // Renew lock. - r, err := pgSQL.Exec(getQuery("u_lock"), name, owner, until) + r, err := pgSQL.Exec(updateLock, name, owner, until) if err != nil { - handleError("u_lock", err) + handleError("updateLock", err) return false, until } if n, _ := r.RowsAffected(); n > 0 { @@ -52,10 +52,10 @@ func (pgSQL *pgSQL) Lock(name string, owner string, duration time.Duration, rene } // Lock. - _, err := pgSQL.Exec(getQuery("i_lock"), name, owner, until) + _, err := pgSQL.Exec(insertLock, name, owner, until) if err != nil { if !isErrUniqueViolation(err) { - handleError("i_lock", err) + handleError("insertLock", err) } return false, until } @@ -72,7 +72,7 @@ func (pgSQL *pgSQL) Unlock(name, owner string) { defer observeQueryTime("Unlock", "all", time.Now()) - pgSQL.Exec(getQuery("r_lock"), name, owner) + pgSQL.Exec(removeLock, name, owner) } // FindLock returns the owner of a lock specified by its name and its @@ -87,9 +87,9 @@ func (pgSQL *pgSQL) FindLock(name string) (string, time.Time, error) { var owner string var until time.Time - err := pgSQL.QueryRow(getQuery("f_lock"), name).Scan(&owner, &until) + err := pgSQL.QueryRow(searchLock, name).Scan(&owner, &until) if err != nil { - return owner, until, handleError("f_lock", err) + return owner, until, handleError("searchLock", err) } return owner, until, nil @@ -99,7 +99,7 @@ func (pgSQL *pgSQL) FindLock(name string) (string, time.Time, error) { func (pgSQL *pgSQL) pruneLocks() { defer observeQueryTime("pruneLocks", "all", time.Now()) - if _, err := pgSQL.Exec(getQuery("r_lock_expired")); err != nil { - handleError("r_lock_expired", err) + if _, err := pgSQL.Exec(removeLockExpired); err != nil { + handleError("removeLockExpired", err) } } diff --git a/database/pgsql/namespace.go b/database/pgsql/namespace.go index f755d3aa..3c85c784 100644 --- a/database/pgsql/namespace.go +++ b/database/pgsql/namespace.go @@ -38,9 +38,9 @@ func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (int, error) { defer observeQueryTime("insertNamespace", "all", time.Now()) var id int - err := pgSQL.QueryRow(getQuery("soi_namespace"), namespace.Name).Scan(&id) + err := pgSQL.QueryRow(soiNamespace, namespace.Name).Scan(&id) if err != nil { - return 0, handleError("soi_namespace", err) + return 0, handleError("soiNamespace", err) } if pgSQL.cache != nil { @@ -51,9 +51,9 @@ func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (int, error) { } func (pgSQL *pgSQL) ListNamespaces() (namespaces []database.Namespace, err error) { - rows, err := pgSQL.Query(getQuery("l_namespace")) + rows, err := pgSQL.Query(listNamespace) if err != nil { - return namespaces, handleError("l_namespace", err) + return namespaces, handleError("listNamespace", err) } defer rows.Close() @@ -62,13 +62,13 @@ func (pgSQL *pgSQL) ListNamespaces() (namespaces []database.Namespace, err error err = rows.Scan(&namespace.ID, &namespace.Name) if err != nil { - return namespaces, handleError("l_namespace.Scan()", err) + return namespaces, handleError("listNamespace.Scan()", err) } namespaces = append(namespaces, namespace) } if err = rows.Err(); err != nil { - return namespaces, handleError("l_namespace.Rows()", err) + return namespaces, handleError("listNamespace.Rows()", err) } return namespaces, err diff --git a/database/pgsql/notification.go b/database/pgsql/notification.go index 3989f360..b7b11aec 100644 --- a/database/pgsql/notification.go +++ b/database/pgsql/notification.go @@ -18,10 +18,10 @@ func createNotification(tx *sql.Tx, oldVulnerabilityID, newVulnerabilityID int) // Insert Notification. oldVulnerabilityNullableID := sql.NullInt64{Int64: int64(oldVulnerabilityID), Valid: oldVulnerabilityID != 0} newVulnerabilityNullableID := sql.NullInt64{Int64: int64(newVulnerabilityID), Valid: newVulnerabilityID != 0} - _, err := tx.Exec(getQuery("i_notification"), uuid.New(), oldVulnerabilityNullableID, newVulnerabilityNullableID) + _, err := tx.Exec(insertNotification, uuid.New(), oldVulnerabilityNullableID, newVulnerabilityNullableID) if err != nil { tx.Rollback() - return handleError("i_notification", err) + return handleError("insertNotification", err) } return nil @@ -33,19 +33,19 @@ func (pgSQL *pgSQL) GetAvailableNotification(renotifyInterval time.Duration) (da defer observeQueryTime("GetAvailableNotification", "all", time.Now()) before := time.Now().Add(-renotifyInterval) - row := pgSQL.QueryRow(getQuery("s_notification_available"), before) + row := pgSQL.QueryRow(searchNotificationAvailable, before) notification, err := pgSQL.scanNotification(row, false) - return notification, handleError("s_notification_available", err) + return notification, handleError("searchNotificationAvailable", err) } func (pgSQL *pgSQL) GetNotification(name string, limit int, page database.VulnerabilityNotificationPageNumber) (database.VulnerabilityNotification, database.VulnerabilityNotificationPageNumber, error) { defer observeQueryTime("GetNotification", "all", time.Now()) // Get Notification. - notification, err := pgSQL.scanNotification(pgSQL.QueryRow(getQuery("s_notification"), name), true) + notification, err := pgSQL.scanNotification(pgSQL.QueryRow(searchNotification, name), true) if err != nil { - return notification, page, handleError("s_notification", err) + return notification, page, handleError("searchNotification", err) } // Load vulnerabilities' LayersIntroducingVulnerability. @@ -149,10 +149,10 @@ func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vu defer observeQueryTime("loadLayerIntroducingVulnerability", "all", tf) // Query with limit + 1, the last item will be used to know the next starting ID. - rows, err := pgSQL.Query(getQuery("s_notification_layer_introducing_vulnerability"), + rows, err := pgSQL.Query(searchNotificationLayerIntroducingVulnerability, vulnerability.ID, startID, limit+1) if err != nil { - return 0, handleError("s_vulnerability_fixedin_feature", err) + return 0, handleError("searchVulnerabilityFixedInFeature", err) } defer rows.Close() @@ -161,13 +161,13 @@ func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vu var layer database.Layer if err := rows.Scan(&layer.ID, &layer.Name); err != nil { - return -1, handleError("s_notification_layer_introducing_vulnerability.Scan()", err) + return -1, handleError("searchNotificationLayerIntroducingVulnerability.Scan()", err) } layers = append(layers, layer) } if err = rows.Err(); err != nil { - return -1, handleError("s_notification_layer_introducing_vulnerability.Rows()", err) + return -1, handleError("searchNotificationLayerIntroducingVulnerability.Rows()", err) } size := limit @@ -187,8 +187,8 @@ func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vu func (pgSQL *pgSQL) SetNotificationNotified(name string) error { defer observeQueryTime("SetNotificationNotified", "all", time.Now()) - if _, err := pgSQL.Exec(getQuery("u_notification_notified"), name); err != nil { - return handleError("u_notification_notified", err) + if _, err := pgSQL.Exec(updatedNotificationNotified, name); err != nil { + return handleError("updatedNotificationNotified", err) } return nil } @@ -196,14 +196,14 @@ func (pgSQL *pgSQL) SetNotificationNotified(name string) error { func (pgSQL *pgSQL) DeleteNotification(name string) error { defer observeQueryTime("DeleteNotification", "all", time.Now()) - result, err := pgSQL.Exec(getQuery("r_notification"), name) + result, err := pgSQL.Exec(removeNotification, name) if err != nil { - return handleError("r_notification", err) + return handleError("removeNotification", err) } affected, err := result.RowsAffected() if err != nil { - return handleError("r_notification.RowsAffected()", err) + return handleError("removeNotification.RowsAffected()", err) } if affected <= 0 { diff --git a/database/pgsql/queries.go b/database/pgsql/queries.go index ae84bb1d..80c61784 100644 --- a/database/pgsql/queries.go +++ b/database/pgsql/queries.go @@ -14,80 +14,72 @@ package pgsql -import ( - "fmt" - "strconv" -) +import "strconv" -var queries map[string]string - -func init() { - queries = make(map[string]string) - - queries["l_vulnerability_affects_featureversion"] = `LOCK Vulnerability_Affects_FeatureVersion IN SHARE ROW EXCLUSIVE MODE` - queries["disable_hashjoin"] = `SET LOCAL enable_hashjoin = off` - queries["disable_mergejoin"] = `SET LOCAL enable_mergejoin = off` +const ( + lockVulnerabilityAffects = `LOCK Vulnerability_Affects_FeatureVersion IN SHARE ROW EXCLUSIVE MODE` + disableHashJoin = `SET LOCAL enable_hashjoin = off` + disableMergeJoin = `SET LOCAL enable_mergejoin = off` // keyvalue.go - queries["u_keyvalue"] = `UPDATE KeyValue SET value = $1 WHERE key = $2` - queries["i_keyvalue"] = `INSERT INTO KeyValue(key, value) VALUES($1, $2)` - queries["s_keyvalue"] = `SELECT value FROM KeyValue WHERE key = $1` + updateKeyValue = `UPDATE KeyValue SET value = $1 WHERE key = $2` + insertKeyValue = `INSERT INTO KeyValue(key, value) VALUES($1, $2)` + searchKeyValue = `SELECT value FROM KeyValue WHERE key = $1` // namespace.go - queries["soi_namespace"] = ` - WITH new_namespace AS ( - INSERT INTO Namespace(name) - SELECT CAST($1 AS VARCHAR) - WHERE NOT EXISTS (SELECT name FROM Namespace WHERE name = $1) - RETURNING id - ) - SELECT id FROM Namespace WHERE name = $1 - UNION - SELECT id FROM new_namespace` + soiNamespace = ` + WITH new_namespace AS ( + INSERT INTO Namespace(name) + SELECT CAST($1 AS VARCHAR) + WHERE NOT EXISTS (SELECT name FROM Namespace WHERE name = $1) + RETURNING id + ) + SELECT id FROM Namespace WHERE name = $1 + UNION + SELECT id FROM new_namespace` - queries["l_namespace"] = `SELECT id, name FROM Namespace` + listNamespace = `SELECT id, name FROM Namespace` // feature.go - queries["soi_feature"] = ` - WITH new_feature AS ( - INSERT INTO Feature(name, namespace_id) - SELECT CAST($1 AS VARCHAR), CAST($2 AS INTEGER) - WHERE NOT EXISTS (SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2) - RETURNING id - ) - SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2 - UNION - SELECT id FROM new_feature` + soiFeature = ` + WITH new_feature AS ( + INSERT INTO Feature(name, namespace_id) + SELECT CAST($1 AS VARCHAR), CAST($2 AS INTEGER) + WHERE NOT EXISTS (SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2) + RETURNING id + ) + SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2 + UNION + SELECT id FROM new_feature` - queries["soi_featureversion"] = ` - WITH new_featureversion AS ( - INSERT INTO FeatureVersion(feature_id, version) - SELECT CAST($1 AS INTEGER), CAST($2 AS VARCHAR) - WHERE NOT EXISTS (SELECT id FROM FeatureVersion WHERE feature_id = $1 AND version = $2) - RETURNING id - ) - SELECT 'exi', id FROM FeatureVersion WHERE feature_id = $1 AND version = $2 - UNION - SELECT 'new', id FROM new_featureversion - ` + soiFeatureVersion = ` + WITH new_featureversion AS ( + INSERT INTO FeatureVersion(feature_id, version) + SELECT CAST($1 AS INTEGER), CAST($2 AS VARCHAR) + WHERE NOT EXISTS (SELECT id FROM FeatureVersion WHERE feature_id = $1 AND version = $2) + RETURNING id + ) + SELECT 'exi', id FROM FeatureVersion WHERE feature_id = $1 AND version = $2 + UNION + SELECT 'new', id FROM new_featureversion` - queries["s_vulnerability_fixedin_feature"] = ` - SELECT id, vulnerability_id, version FROM Vulnerability_FixedIn_Feature + searchVulnerabilityFixedInFeature = ` + SELECT id, vulnerability_id, version FROM Vulnerability_FixedIn_Feature WHERE feature_id = $1` - queries["i_vulnerability_affects_featureversion"] = ` - INSERT INTO Vulnerability_Affects_FeatureVersion(vulnerability_id, + insertVulnerabilityAffectsFeatureVersion = ` + INSERT INTO Vulnerability_Affects_FeatureVersion(vulnerability_id, featureversion_id, fixedin_id) VALUES($1, $2, $3)` // layer.go - queries["s_layer"] = ` - SELECT l.id, l.name, l.engineversion, p.id, p.name, n.id, n.name - FROM Layer l - LEFT JOIN Layer p ON l.parent_id = p.id - LEFT JOIN Namespace n ON l.namespace_id = n.id - WHERE l.name = $1;` + searchLayer = ` + SELECT l.id, l.name, l.engineversion, p.id, p.name, n.id, n.name + FROM Layer l + LEFT JOIN Layer p ON l.parent_id = p.id + LEFT JOIN Namespace n ON l.namespace_id = n.id + WHERE l.name = $1;` - queries["s_layer_featureversion"] = ` + searchLayerFeatureVersion = ` WITH RECURSIVE layer_tree(id, name, parent_id, depth, path, cycle) AS( SELECT l.id, l.name, l.parent_id, 1, ARRAY[l.id], false FROM Layer l @@ -105,76 +97,70 @@ func init() { WHERE ldf.featureversion_id = fv.id AND fv.feature_id = f.id AND f.namespace_id = fn.id ORDER BY ltree.ordering` - queries["s_featureversions_vulnerabilities"] = ` - SELECT vafv.featureversion_id, v.id, v.name, v.description, v.link, v.severity, v.metadata, - vn.name, vfif.version - FROM Vulnerability_Affects_FeatureVersion vafv, Vulnerability v, - Namespace vn, Vulnerability_FixedIn_Feature vfif - WHERE vafv.featureversion_id = ANY($1::integer[]) - AND vfif.vulnerability_id = v.id - AND vafv.fixedin_id = vfif.id - AND v.namespace_id = vn.id - AND v.deleted_at IS NULL` + searchFeatureVersionVulnerability = ` + SELECT vafv.featureversion_id, v.id, v.name, v.description, v.link, v.severity, v.metadata, + vn.name, vfif.version + FROM Vulnerability_Affects_FeatureVersion vafv, Vulnerability v, + Namespace vn, Vulnerability_FixedIn_Feature vfif + WHERE vafv.featureversion_id = ANY($1::integer[]) + AND vfif.vulnerability_id = v.id + AND vafv.fixedin_id = vfif.id + AND v.namespace_id = vn.id + AND v.deleted_at IS NULL` - queries["i_layer"] = ` - INSERT INTO Layer(name, engineversion, parent_id, namespace_id, created_at) + insertLayer = ` + INSERT INTO Layer(name, engineversion, parent_id, namespace_id, created_at) VALUES($1, $2, $3, $4, CURRENT_TIMESTAMP) RETURNING id` - queries["u_layer"] = `UPDATE LAYER SET engineversion = $2, namespace_id = $3 WHERE id = $1` + updateLayer = `UPDATE LAYER SET engineversion = $2, namespace_id = $3 WHERE id = $1` - queries["r_layer_diff_featureversion"] = ` - DELETE FROM Layer_diff_FeatureVersion - WHERE layer_id = $1` + removeLayerDiffFeatureVersion = ` + DELETE FROM Layer_diff_FeatureVersion + WHERE layer_id = $1` - queries["i_layer_diff_featureversion"] = ` - INSERT INTO Layer_diff_FeatureVersion(layer_id, featureversion_id, modification) - SELECT $1, fv.id, $2 - FROM FeatureVersion fv - WHERE fv.id = ANY($3::integer[])` + insertLayerDiffFeatureVersion = ` + INSERT INTO Layer_diff_FeatureVersion(layer_id, featureversion_id, modification) + SELECT $1, fv.id, $2 + FROM FeatureVersion fv + WHERE fv.id = ANY($3::integer[])` - queries["r_layer"] = `DELETE FROM Layer WHERE name = $1` + removeLayer = `DELETE FROM Layer WHERE name = $1` // lock.go - queries["i_lock"] = `INSERT INTO Lock(name, owner, until) VALUES($1, $2, $3)` - - queries["f_lock"] = `SELECT owner, until FROM Lock WHERE name = $1` - - queries["u_lock"] = `UPDATE Lock SET until = $3 WHERE name = $1 AND owner = $2` - - queries["r_lock"] = `DELETE FROM Lock WHERE name = $1 AND owner = $2` - - queries["r_lock_expired"] = `DELETE FROM LOCK WHERE until < CURRENT_TIMESTAMP` + insertLock = `INSERT INTO Lock(name, owner, until) VALUES($1, $2, $3)` + searchLock = `SELECT owner, until FROM Lock WHERE name = $1` + updateLock = `UPDATE Lock SET until = $3 WHERE name = $1 AND owner = $2` + removeLock = `DELETE FROM Lock WHERE name = $1 AND owner = $2` + removeLockExpired = `DELETE FROM LOCK WHERE until < CURRENT_TIMESTAMP` // vulnerability.go - queries["f_vulnerability_base"] = ` - SELECT v.id, v.name, n.id, n.name, v.description, v.link, v.severity, v.metadata - FROM Vulnerability v JOIN Namespace n ON v.namespace_id = n.id` + searchVulnerabilityBase = ` + SELECT v.id, v.name, n.id, n.name, v.description, v.link, v.severity, v.metadata + FROM Vulnerability v JOIN Namespace n ON v.namespace_id = n.id` + searchVulnerabilityForUpdate = ` FOR UPDATE OF v` + searchVulnerabilityByNamespaceAndName = ` WHERE n.name = $1 AND v.name = $2 AND v.deleted_at IS NULL` + searchVulnerabilityByID = ` WHERE v.id = $1` - queries["f_vulnerability_for_update"] = ` FOR UPDATE OF v` - queries["f_vulnerability_+by_name_namespace"] = ` WHERE n.name = $1 AND v.name = $2 AND v.deleted_at IS NULL` - queries["f_vulnerability_+by_id"] = ` WHERE v.id = $1` + searchVulnerabilityFixedIn = ` + SELECT vfif.version, f.id, f.Name + FROM Vulnerability_FixedIn_Feature vfif JOIN Feature f ON vfif.feature_id = f.id + WHERE vfif.vulnerability_id = $1` - queries["f_vulnerability_fixedin"] = ` - SELECT vfif.version, f.id, f.Name - FROM Vulnerability_FixedIn_Feature vfif JOIN Feature f ON vfif.feature_id = f.id - WHERE vfif.vulnerability_id = $1` + insertVulnerability = ` + INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at) + VALUES($1, $2, $3, $4, $5, $6, CURRENT_TIMESTAMP) + RETURNING id` - queries["i_vulnerability"] = ` - INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at) - VALUES($1, $2, $3, $4, $5, $6, CURRENT_TIMESTAMP) - RETURNING id` + insertVulnerabilityFixedInFeature = ` + INSERT INTO Vulnerability_FixedIn_Feature(vulnerability_id, feature_id, version) + VALUES($1, $2, $3) + RETURNING id` - queries["i_vulnerability_fixedin_feature"] = ` - INSERT INTO Vulnerability_FixedIn_Feature(vulnerability_id, feature_id, version) - VALUES($1, $2, $3) - RETURNING id` + searchFeatureVersionByFeature = `SELECT id, version FROM FeatureVersion WHERE feature_id = $1` - queries["f_featureversion_by_feature"] = ` - SELECT id, version FROM FeatureVersion WHERE feature_id = $1` - - queries["r_vulnerability"] = ` - UPDATE Vulnerability + removeVulnerability = ` + UPDATE Vulnerability SET deleted_at = CURRENT_TIMESTAMP WHERE namespace_id = (SELECT id FROM Namespace WHERE name = $1) AND name = $2 @@ -182,62 +168,55 @@ func init() { RETURNING id` // notification.go - queries["i_notification"] = ` - INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id) + insertNotification = ` + INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id) VALUES($1, CURRENT_TIMESTAMP, $2, $3)` - queries["u_notification_notified"] = ` - UPDATE Vulnerability_Notification - SET notified_at = CURRENT_TIMESTAMP - WHERE name = $1` + updatedNotificationNotified = ` + UPDATE Vulnerability_Notification + SET notified_at = CURRENT_TIMESTAMP + WHERE name = $1` - queries["r_notification"] = ` - UPDATE Vulnerability_Notification - SET deleted_at = CURRENT_TIMESTAMP - WHERE name = $1` + removeNotification = ` + UPDATE Vulnerability_Notification + SET deleted_at = CURRENT_TIMESTAMP + WHERE name = $1` - queries["s_notification_available"] = ` - SELECT id, name, created_at, notified_at, deleted_at - FROM Vulnerability_Notification - WHERE (notified_at IS NULL OR notified_at < $1) - AND deleted_at IS NULL - AND name NOT IN (SELECT name FROM Lock) - ORDER BY Random() - LIMIT 1` + searchNotificationAvailable = ` + SELECT id, name, created_at, notified_at, deleted_at + FROM Vulnerability_Notification + WHERE (notified_at IS NULL OR notified_at < $1) + AND deleted_at IS NULL + AND name NOT IN (SELECT name FROM Lock) + ORDER BY Random() + LIMIT 1` - queries["s_notification"] = ` - SELECT id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id - FROM Vulnerability_Notification - WHERE name = $1` + searchNotification = ` + SELECT id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id + FROM Vulnerability_Notification + WHERE name = $1` - queries["s_notification_layer_introducing_vulnerability"] = ` - SELECT l.ID, l.name - FROM Vulnerability v, Vulnerability_Affects_FeatureVersion vafv, FeatureVersion fv, Layer_diff_FeatureVersion ldfv, Layer l - WHERE v.id = $1 - AND v.id = vafv.vulnerability_id - AND vafv.featureversion_id = fv.id - AND fv.id = ldfv.featureversion_id - AND ldfv.modification = 'add' - AND ldfv.layer_id = l.id - AND l.id >= $2 - ORDER BY l.ID - LIMIT $3` + searchNotificationLayerIntroducingVulnerability = ` + SELECT l.ID, l.name + FROM Vulnerability v, Vulnerability_Affects_FeatureVersion vafv, FeatureVersion fv, Layer_diff_FeatureVersion ldfv, Layer l + WHERE v.id = $1 + AND v.id = vafv.vulnerability_id + AND vafv.featureversion_id = fv.id + AND fv.id = ldfv.featureversion_id + AND ldfv.modification = 'add' + AND ldfv.layer_id = l.id + AND l.id >= $2 + ORDER BY l.ID + LIMIT $3` // complex_test.go - queries["s_complextest_featureversion_affects"] = ` - SELECT v.name + searchComplexTestFeatureVersionAffects = ` + SELECT v.name FROM FeatureVersion fv LEFT JOIN Vulnerability_Affects_FeatureVersion vaf ON fv.id = vaf.featureversion_id JOIN Vulnerability v ON vaf.vulnerability_id = v.id WHERE featureversion_id = $1` -} - -func getQuery(name string) string { - if query, ok := queries[name]; ok { - return query - } - panic(fmt.Sprintf("pgsql: unknown query %v", name)) -} +) // buildInputArray constructs a PostgreSQL input array from the specified integers. // Useful to use the `= ANY($1::integer[])` syntax that let us use a IN clause while using diff --git a/database/pgsql/vulnerability.go b/database/pgsql/vulnerability.go index 7e4ff20e..106f6935 100644 --- a/database/pgsql/vulnerability.go +++ b/database/pgsql/vulnerability.go @@ -35,11 +35,11 @@ func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vuln func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bool) (database.Vulnerability, error) { defer observeQueryTime("findVulnerability", "all", time.Now()) - queryName := "f_vulnerability" - query := getQuery("f_vulnerability_base") + getQuery("f_vulnerability_+by_name_namespace") + queryName := "searchVulnerabilityBase+searchVulnerabilityByNamespaceAndName" + query := searchVulnerabilityBase + searchVulnerabilityByNamespaceAndName if forUpdate { - queryName = queryName + "+for_update" - query = query + getQuery("f_vulnerability_for_update") + queryName = queryName + "+searchVulnerabilityForUpdate" + query = query + searchVulnerabilityForUpdate } return scanVulnerability(queryer, queryName, queryer.QueryRow(query, namespaceName, name)) @@ -48,8 +48,8 @@ func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bo func (pgSQL *pgSQL) findVulnerabilityByIDWithDeleted(id int) (database.Vulnerability, error) { defer observeQueryTime("findVulnerabilityByIDWithDeleted", "all", time.Now()) - queryName := "f_vulnerability" - query := getQuery("f_vulnerability_base") + getQuery("f_vulnerability_+by_id") + queryName := "searchVulnerabilityBase+searchVulnerabilityByID" + query := searchVulnerabilityBase + searchVulnerabilityByID return scanVulnerability(pgSQL, queryName, pgSQL.QueryRow(query, id)) } @@ -77,9 +77,9 @@ func scanVulnerability(queryer Queryer, queryName string, vulnerabilityRow *sql. } // Query the FixedIn FeatureVersion now. - rows, err := queryer.Query(getQuery("f_vulnerability_fixedin"), vulnerability.ID) + rows, err := queryer.Query(searchVulnerabilityFixedIn, vulnerability.ID) if err != nil { - return vulnerability, handleError("f_vulnerability_fixedin.Scan()", err) + return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err) } defer rows.Close() @@ -95,7 +95,7 @@ func scanVulnerability(queryer Queryer, queryName string, vulnerabilityRow *sql. ) if err != nil { - return vulnerability, handleError("f_vulnerability_fixedin.Scan()", err) + return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err) } if !featureVersionID.IsZero() { @@ -115,7 +115,7 @@ func scanVulnerability(queryer Queryer, queryName string, vulnerabilityRow *sql. } if err := rows.Err(); err != nil { - return vulnerability, handleError("f_vulnerability_fixedin.Rows()", err) + return vulnerability, handleError("searchVulnerabilityFixedIn.Rows()", err) } return vulnerability, nil @@ -209,10 +209,10 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability, on } // Mark the old vulnerability as non latest. - _, err = tx.Exec(getQuery("r_vulnerability"), vulnerability.Namespace.Name, vulnerability.Name) + _, err = tx.Exec(removeVulnerability, vulnerability.Namespace.Name, vulnerability.Name) if err != nil { tx.Rollback() - return handleError("r_vulnerability", err) + return handleError("removeVulnerability", err) } } else { // The vulnerability is new, we don't want to have any types.MinVersion as they are only used @@ -234,7 +234,7 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability, on // Insert vulnerability. err = tx.QueryRow( - getQuery("i_vulnerability"), + insertVulnerability, namespaceID, vulnerability.Name, vulnerability.Description, @@ -245,7 +245,7 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability, on if err != nil { tx.Rollback() - return handleError("i_vulnerability", err) + return handleError("insertVulnerability", err) } // Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now. @@ -368,12 +368,12 @@ func (pgSQL *pgSQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulner promConcurrentLockVAFV.Inc() defer promConcurrentLockVAFV.Dec() t := time.Now() - _, err = tx.Exec(getQuery("l_vulnerability_affects_featureversion")) + _, err = tx.Exec(lockVulnerabilityAffects) observeQueryTime("insertVulnerability", "lock", t) if err != nil { tx.Rollback() - return handleError("insertVulnerability.l_vulnerability_affects_featureversion", err) + return handleError("insertVulnerability.lockVulnerabilityAffects", err) } for _, fv := range fixedIn { @@ -381,13 +381,13 @@ func (pgSQL *pgSQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulner // Insert Vulnerability_FixedIn_Feature. err = tx.QueryRow( - getQuery("i_vulnerability_fixedin_feature"), + insertVulnerabilityFixedInFeature, vulnerabilityID, fv.Feature.ID, &fv.Version, ).Scan(&fixedInID) if err != nil { - return handleError("i_vulnerability_fixedin_feature", err) + return handleError("insertVulnerabilityFixedInFeature", err) } // Insert Vulnerability_Affects_FeatureVersion. @@ -403,9 +403,9 @@ func (pgSQL *pgSQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulner func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, fixedInVersion types.Version) error { // Find every FeatureVersions of the Feature that the vulnerability affects. // TODO(Quentin-M): LIMIT - rows, err := tx.Query(getQuery("f_featureversion_by_feature"), featureID) + rows, err := tx.Query(searchFeatureVersionByFeature, featureID) if err != nil { - return handleError("f_featureversion_by_feature", err) + return handleError("searchFeatureVersionByFeature", err) } defer rows.Close() @@ -415,7 +415,7 @@ func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, err := rows.Scan(&affected.ID, &affected.Version) if err != nil { - return handleError("f_featureversion_by_feature.Scan()", err) + return handleError("searchFeatureVersionByFeature.Scan()", err) } if affected.Version.Compare(fixedInVersion) < 0 { @@ -425,17 +425,17 @@ func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, } } if err = rows.Err(); err != nil { - return handleError("f_featureversion_by_feature.Rows()", err) + return handleError("searchFeatureVersionByFeature.Rows()", err) } rows.Close() // Insert into Vulnerability_Affects_FeatureVersion. for _, affected := range affecteds { // TODO(Quentin-M): Batch me. - _, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), vulnerabilityID, + _, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, vulnerabilityID, affected.ID, fixedInID) if err != nil { - return handleError("i_vulnerability_affects_featureversion", err) + return handleError("insertVulnerabilityAffectsFeatureVersion", err) } } @@ -491,10 +491,10 @@ func (pgSQL *pgSQL) DeleteVulnerability(namespaceName, name string) error { } var vulnerabilityID int - err = tx.QueryRow(getQuery("r_vulnerability"), namespaceName, name).Scan(&vulnerabilityID) + err = tx.QueryRow(removeVulnerability, namespaceName, name).Scan(&vulnerabilityID) if err != nil { tx.Rollback() - return handleError("r_vulnerability", err) + return handleError("removeVulnerability", err) } // Create a notification.