diff --git a/database/database.go b/database/database.go index 2001d1f0..b689dc0c 100644 --- a/database/database.go +++ b/database/database.go @@ -50,6 +50,9 @@ type Datastore interface { FindVulnerability(namespaceName, name string) (Vulnerability, error) DeleteVulnerability(namespaceName, name string) error + InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []FeatureVersion) error + DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error + // Notifications GetAvailableNotification(renotifyInterval time.Duration) (VulnerabilityNotification, error) // Does not fill old/new Vulnerabilities. GetNotification(name string, limit int, page VulnerabilityNotificationPageNumber) (VulnerabilityNotification, VulnerabilityNotificationPageNumber, error) diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index 06c93d82..19cf5ebc 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -74,6 +74,11 @@ func init() { prometheus.MustRegister(promConcurrentLockVAFV) } +type Queryer interface { + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row +} + type pgSQL struct { *sql.DB cache *lru.ARCCache diff --git a/database/pgsql/queries.go b/database/pgsql/queries.go index bd5f0db1..62d2a99c 100644 --- a/database/pgsql/queries.go +++ b/database/pgsql/queries.go @@ -145,6 +145,14 @@ func init() { // vulnerability.go queries["f_vulnerability"] = ` SELECT v.id, n.id, v.description, v.link, v.severity, v.metadata, vfif.version, f.id, f.Name + FROM Vulnerability v + JOIN Namespace n ON v.namespace_id = n.id + LEFT JOIN Vulnerability_FixedIn_Feature vfif ON v.id = vfif.vulnerability_id + LEFT JOIN Feature f ON vfif.feature_id = f.id + WHERE n.Name = $1 AND v.Name = $2` + + queries["f_vulnerability_for_update"] = ` + SELECT FOR UPDATE v.id, n.id, v.description, v.link, v.severity, v.metadata, vfif.version, f.id, f.Name FROM Vulnerability v JOIN Namespace n ON v.namespace_id = n.id LEFT JOIN Vulnerability_FixedIn_Feature vfif ON v.id = vfif.vulnerability_id diff --git a/database/pgsql/vulnerability.go b/database/pgsql/vulnerability.go index 4035e258..8525aab4 100644 --- a/database/pgsql/vulnerability.go +++ b/database/pgsql/vulnerability.go @@ -29,7 +29,11 @@ import ( ) func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vulnerability, error) { - defer observeQueryTime("FindVulnerability", "all", time.Now()) + return findVulnerability(pgSQL, namespaceName, name, false) +} + +func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bool) (database.Vulnerability, error) { + defer observeQueryTime("findVulnerability", "all", time.Now()) vulnerability := database.Vulnerability{ Name: name, @@ -39,9 +43,14 @@ func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vuln } // Find Vulnerability. - rows, err := pgSQL.Query(getQuery("f_vulnerability"), namespaceName, name) + queryName := "f_vulnerability" + if forUpdate { + queryName = "f_vulnerability_for_update" + } + + rows, err := queryer.Query(getQuery(queryName), namespaceName, name) if err != nil { - return vulnerability, handleError("f_vulnerability", err) + return vulnerability, handleError(queryName, err) } defer rows.Close() @@ -55,7 +64,7 @@ func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vuln &vulnerability.Link, &vulnerability.Severity, &vulnerability.Metadata, &featureVersionVersion, &featureVersionID, &featureVersionFeatureName) if err != nil { - return vulnerability, handleError("f_vulnerability.Scan()", err) + return vulnerability, handleError(queryName+".Scan()", err) } if !featureVersionID.IsZero() { @@ -107,8 +116,9 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability) er } for _, fixedInFeatureVersion := range vulnerability.FixedIn { - if fixedInFeatureVersion.Feature.Namespace.Name != "" && - fixedInFeatureVersion.Feature.Namespace.Name != vulnerability.Namespace.Name { + if fixedInFeatureVersion.Feature.Namespace.Name == "" { + fixedInFeatureVersion.Feature.Namespace.Name = vulnerability.Namespace.Name + } else if fixedInFeatureVersion.Feature.Namespace.Name != vulnerability.Namespace.Name { msg := "could not insert an invalid vulnerability: FixedIn FeatureVersion must be in the " + "same namespace as the Vulnerability" log.Warning(msg) @@ -116,61 +126,15 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability) er } } + // We do `defer observeQueryTime` here because we don't want to observe invalid vulnerabilities. + defer observeQueryTime("insertVulnerability", "all", tf) + // Find or insert Vulnerability's Namespace. namespaceID, err := pgSQL.insertNamespace(vulnerability.Namespace) if err != nil { return err } - // Find vulnerability and its Vulnerability_FixedIn_Features. - existingVulnerability, err := pgSQL.FindVulnerability(vulnerability.Namespace.Name, - vulnerability.Name) - if err != nil && err != cerrors.ErrNotFound { - return err - } - - // Compute new/updated FixedIn FeatureVersions. - var newFixedInFeatureVersions []database.FeatureVersion - var updatedFixedInFeatureVersions []database.FeatureVersion - if existingVulnerability.ID == 0 { - newFixedInFeatureVersions = vulnerability.FixedIn - } else { - newFixedInFeatureVersions, updatedFixedInFeatureVersions = diffFixedIn(vulnerability, - existingVulnerability) - - if vulnerability.Description == existingVulnerability.Description && - vulnerability.Link == existingVulnerability.Link && - vulnerability.Severity == existingVulnerability.Severity && - reflect.DeepEqual(castMetadata(vulnerability.Metadata), existingVulnerability.Metadata) && - len(newFixedInFeatureVersions) == 0 && - len(updatedFixedInFeatureVersions) == 0 { - - // Nothing to do. - return nil - } - } - - // We do `defer observeQueryTime` here because we don't want to observe existing & up-to-date - // vulnerabilities. - defer observeQueryTime("insertVulnerability", "all", tf) - - // Insert or find the new Features. - // We already have the Feature IDs in updatedFixedInFeatureVersions because diffFixedIn fills them - // in using the existing vulnerability's FixedIn FeatureVersions. Note that even if FixedIn - // is type FeatureVersion, the actual stored ID in these structs are the Feature IDs. - // - // Also, we enforce the namespace of the FeatureVersion in case it was empty. There is a test - // above to ensure that the passed Namespace is either the same as the vulnerability or empty. - // - // TODO(Quentin-M): Batch me. - for i := 0; i < len(newFixedInFeatureVersions); i++ { - newFixedInFeatureVersions[i].Feature.Namespace.Name = vulnerability.Namespace.Name - newFixedInFeatureVersions[i].Feature.ID, err = pgSQL.insertFeature(newFixedInFeatureVersions[i].Feature) - if err != nil { - return err - } - } - // Begin transaction. tx, err := pgSQL.Begin() if err != nil { @@ -178,21 +142,17 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability) er return handleError("insertVulnerability.Begin()", err) } - // Lock Vulnerability_Affects_FeatureVersion exclusively. - // We want to prevent InsertFeatureVersion to modify it. - promConcurrentLockVAFV.Inc() - defer promConcurrentLockVAFV.Dec() - t := time.Now() - _, err = tx.Exec(getQuery("l_vulnerability_affects_featureversion")) - observeQueryTime("insertVulnerability", "lock", t) - - if err != nil { + // Find vulnerability and its Vulnerability_FixedIn_Features. + existingVulnerability, err := findVulnerability(tx, vulnerability.Namespace.Name, + vulnerability.Name, true) + if err != nil && err != cerrors.ErrNotFound { tx.Rollback() - return handleError("insertVulnerability.l_vulnerability_affects_featureversion", err) + return err } + // Insert or update vulnerability. if existingVulnerability.ID == 0 { - // Insert new vulnerability. + // The vulnerability is a new one, insert it. err = tx.QueryRow(getQuery("i_vulnerability"), namespaceID, vulnerability.Name, vulnerability.Description, vulnerability.Link, &vulnerability.Severity, &vulnerability.Metadata).Scan(&vulnerability.ID) @@ -201,10 +161,11 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability) er return handleError("i_vulnerability", err) } } else { - // Update vulnerability + // The vulnerability exists, update it. if vulnerability.Description != existingVulnerability.Description || vulnerability.Link != existingVulnerability.Link || - vulnerability.Severity != existingVulnerability.Severity { + vulnerability.Severity != existingVulnerability.Severity || + !reflect.DeepEqual(castMetadata(vulnerability.Metadata), existingVulnerability.Metadata) { _, err = tx.Exec(getQuery("u_vulnerability"), existingVulnerability.ID, vulnerability.Description, vulnerability.Link, &vulnerability.Severity, &vulnerability.Metadata) @@ -217,12 +178,22 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability) er vulnerability.ID = existingVulnerability.ID } - // Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now. - t = time.Now() - err = pgSQL.updateVulnerabilityFeatureVersions(tx, &vulnerability, &existingVulnerability, newFixedInFeatureVersions, updatedFixedInFeatureVersions) - observeQueryTime("insertVulnerability", "updateVulnerabilityFeatureVersions", t) + // Get the new/updated/removed FeatureVersions and the resulting full list. + var newFIFV, updatedFIFV, removedFIFV []database.FeatureVersion + if existingVulnerability.ID == 0 { + // The vulnerability is a new new, the new FeatureVersions are the entire list of FixedIn. + newFIFV = vulnerability.FixedIn + } else { + // The vulnerability exists, compute the lists using diffFixedIn. + // We overwrite vulnerability.FixedIn with the entire list of FixedIn FeatureVersions, we'll + // then use the vulnerability in the notification, with that list instead of a potential diff. + newFIFV, updatedFIFV, removedFIFV, vulnerability.FixedIn = + diffFixedIn(existingVulnerability.FixedIn, vulnerability.FixedIn) + } - if err != nil { + // Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now. + if err = pgSQL.updateVulnerabilityFeatureVersions(tx, vulnerability.ID, newFIFV, updatedFIFV, + removedFIFV); err != nil { tx.Rollback() return err } @@ -260,10 +231,11 @@ func castMetadata(m database.MetadataMap) database.MetadataMap { return c } -func diffFixedIn(vulnerability, existingVulnerability database.Vulnerability) (newFixedIn, updatedFixedIn []database.FeatureVersion) { +func diffFixedIn(existingFIFVList, newFIFVList []database.FeatureVersion) (newFIFV, updatedFIFV, removedFIFV, allFIFV []database.FeatureVersion) { // Build FeatureVersion.Feature.Namespace.Name:FeatureVersion.Feature.Name (NaN) structures. - vulnerabilityFixedInNameMap, vulnerabilityFixedInNameSlice := createFeatureVersionNameMap(vulnerability.FixedIn) - existingFixedInMapNameMap, existingFixedInNameSlice := createFeatureVersionNameMap(existingVulnerability.FixedIn) + allFIFVMap, _ := createFeatureVersionNameMap(existingFIFVList) + vulnerabilityFixedInNameMap, vulnerabilityFixedInNameSlice := createFeatureVersionNameMap(newFIFVList) + existingFixedInMapNameMap, existingFixedInNameSlice := createFeatureVersionNameMap(existingFIFVList) // Calculate the new FixedIn FeatureVersion NaN and updated ones. newFixedInName := utils.CompareStringLists(vulnerabilityFixedInNameSlice, @@ -280,18 +252,31 @@ func diffFixedIn(vulnerability, existingVulnerability database.Vulnerability) (n continue } - newFixedIn = append(newFixedIn, fv) + newFIFV = append(newFIFV, fv) + allFIFVMap[fv.Feature.Namespace.Name+":"+fv.Feature.Name] = fv } + for _, nan := range updatedFixedInName { fv := existingFixedInMapNameMap[nan] fv.Version = vulnerabilityFixedInNameMap[nan].Version + if existingFixedInMapNameMap[nan].Version == fv.Version { // Versions are actually the same! // Even though they appear in both lists, it's not an update. continue } - updatedFixedIn = append(updatedFixedIn, fv) + if fv.Version != types.MinVersion { + updatedFIFV = append(updatedFIFV, fv) + allFIFVMap[fv.Feature.Namespace.Name+":"+fv.Feature.Name] = fv + } else { + removedFIFV = append(removedFIFV, fv) + delete(allFIFVMap, fv.Feature.Namespace.Name+":"+fv.Feature.Name) + } + } + + for _, fv := range allFIFVMap { + allFIFV = append(allFIFV, fv) } return @@ -310,56 +295,192 @@ func createFeatureVersionNameMap(features []database.FeatureVersion) (map[string return m, s } -func (pgSQL *pgSQL) updateVulnerabilityFeatureVersions(tx *sql.Tx, vulnerability, existingVulnerability *database.Vulnerability, newFixedInFeatureVersions, updatedFixedInFeatureVersions []database.FeatureVersion) error { +func (pgSQL *pgSQL) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []database.FeatureVersion) error { + // Verify parameters + for _, fifv := range fixes { + if fifv.Feature.Namespace.Name == "" { + fifv.Feature.Namespace.Name = vulnerabilityNamespace + } else if fifv.Feature.Namespace.Name != vulnerabilityNamespace { + msg := "could not add/update a FixedIn FeatureVersion: FixedIn FeatureVersion must be in the " + + "same namespace as the Vulnerability" + log.Warning(msg) + return cerrors.NewBadRequestError(msg) + } + } + + f := func(vulnerability database.Vulnerability) (newFIFV, updatedFIFV, removedFIFV, allFIFV []database.FeatureVersion, err error) { + newFIFV, updatedFIFV, _, allFIFV = diffFixedIn(vulnerability.FixedIn, fixes) + return + } + + return pgSQL.doVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName, f) +} + +func (pgSQL *pgSQL) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error { + f := func(vulnerability database.Vulnerability) (newFIFV, updatedFIFV, removedFIFV, allFIFV []database.FeatureVersion, err error) { + // Search the specified featureName. + for i, vulnerabilityFV := range vulnerability.FixedIn { + if vulnerabilityFV.Feature.Name == featureName { + removedFIFV = append(removedFIFV, vulnerabilityFV) + allFIFV = append(vulnerability.FixedIn[:i], vulnerability.FixedIn[i+1:]...) + return + } + } + + err = cerrors.ErrNotFound + return + } + + return pgSQL.doVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName, f) +} + +// doVulnerabilityFixes is used by InsertVulnerabilityFixes and DeleteVulnerabilityFix. It +// adds/updates/removes FeatureVersions on the specified vulnerability using +// updateVulnerabilityFeatureVersions and creates a database.VulnerabilityNotification. +func (pgSQL *pgSQL) doVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, f func(vulnerability database.Vulnerability) (newFIFV, updatedFIFV, removedFIFV, allFIFV []database.FeatureVersion, err error)) error { + // Begin transaction. + tx, err := pgSQL.Begin() + if err != nil { + tx.Rollback() + return handleError("doVulnerabilityFixes.Begin()", err) + } + + // Select for update the vulnerability in order to prevent everyone else from executing updates + // on the vulnerability (and consequently on Vulnerability_FixedIn_Feature for that particular + // vulnerability) + vulnerability, err := findVulnerability(tx, vulnerabilityNamespace, vulnerabilityName, true) + if err != nil { + tx.Rollback() + return err + } + + // Get the new/updated/removed FeatureVersions and the resulting full list, using the given fct. + newFIFV, updatedFIFV, removedFIFV, allFIFV, err := f(vulnerability) + if err != nil { + tx.Rollback() + return err + } + if len(newFIFV) == 0 && len(updatedFIFV) == 0 && len(removedFIFV) == 0 { + // Nothing to do. + tx.Commit() + return nil + } + + // Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now. + err = pgSQL.updateVulnerabilityFeatureVersions(tx, vulnerability.ID, newFIFV, updatedFIFV, + removedFIFV) + if err != nil { + tx.Rollback() + return err + } + + // Create notification. + newVulnerability := vulnerability + newVulnerability.FixedIn = allFIFV + + notification := database.VulnerabilityNotification{ + NewVulnerability: newVulnerability, + OldVulnerability: &vulnerability, + } + + if err := pgSQL.insertNotification(tx, notification); err != nil { + return err + } + + // Commit transaction. + err = tx.Commit() + if err != nil { + tx.Rollback() + return handleError("insertVulnerability.Commit()", err) + } + + return nil +} + +func (pgSQL *pgSQL) updateVulnerabilityFeatureVersions(tx *sql.Tx, vulnerabilityID int, newFIFV, updatedFIFV, removedFIFV []database.FeatureVersion) error { + defer observeQueryTime("updateVulnerabilityFeatureVersions", "all", time.Now()) + + // Insert or find the Features. + // TODO(Quentin-M): Batch me. + var err error + var features []*database.Feature + for _, fv := range newFIFV { + features = append(features, &fv.Feature) + } + for _, fv := range updatedFIFV { + features = append(features, &fv.Feature) + } + for _, fv := range removedFIFV { + features = append(features, &fv.Feature) + } + for _, feature := range features { + if feature.ID == 0 { + if feature.ID, err = pgSQL.insertFeature(*feature); err != nil { + return err + } + } + } + + // Lock Vulnerability_Affects_FeatureVersion exclusively. + // We want to prevent InsertFeatureVersion to modify it. + promConcurrentLockVAFV.Inc() + defer promConcurrentLockVAFV.Dec() + t := time.Now() + _, err = tx.Exec(getQuery("l_vulnerability_affects_featureversion")) + observeQueryTime("insertVulnerability", "lock", t) + + if err != nil { + tx.Rollback() + return handleError("insertVulnerability.l_vulnerability_affects_featureversion", err) + } + var fixedInID int - for _, fv := range newFixedInFeatureVersions { + for _, fv := range newFIFV { // Insert Vulnerability_FixedIn_Feature. - err := tx.QueryRow(getQuery("i_vulnerability_fixedin_feature"), vulnerability.ID, fv.Feature.ID, + err = tx.QueryRow(getQuery("i_vulnerability_fixedin_feature"), vulnerabilityID, fv.Feature.ID, &fv.Version).Scan(&fixedInID) if err != nil { return handleError("i_vulnerability_fixedin_feature", err) } // Insert Vulnerability_Affects_FeatureVersion. - err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerability.ID, fv.Feature.ID, + err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerabilityID, fv.Feature.ID, fv.Version) if err != nil { return err } } - for _, fv := range updatedFixedInFeatureVersions { - if fv.Version != types.MinVersion { - // Update Vulnerability_FixedIn_Feature. - err := tx.QueryRow(getQuery("u_vulnerability_fixedin_feature"), vulnerability.ID, - fv.Feature.ID, &fv.Version).Scan(&fixedInID) - if err != nil { - return handleError("u_vulnerability_fixedin_feature", err) - } + for _, fv := range updatedFIFV { + // Update Vulnerability_FixedIn_Feature. + err = tx.QueryRow(getQuery("u_vulnerability_fixedin_feature"), vulnerabilityID, + fv.Feature.ID, &fv.Version).Scan(&fixedInID) + if err != nil { + return handleError("u_vulnerability_fixedin_feature", err) + } - // Drop all old Vulnerability_Affects_FeatureVersion. - _, err = tx.Exec(getQuery("r_vulnerability_affects_featureversion"), fixedInID) - if err != nil { - return handleError("r_vulnerability_affects_featureversion", err) - } + // Drop all old Vulnerability_Affects_FeatureVersion. + _, err = tx.Exec(getQuery("r_vulnerability_affects_featureversion"), fixedInID) + if err != nil { + return handleError("r_vulnerability_affects_featureversion", err) + } - // Insert Vulnerability_Affects_FeatureVersion. - err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerability.ID, fv.Feature.ID, - fv.Version) - if err != nil { - return err - } - } else { - // Updating FixedIn by saying that the fixed version is the lowest possible version, it - // basically means that the vulnerability doesn't affect the feature (anymore). - // Drop it from Vulnerability_FixedIn_Feature and let it cascade to - // Vulnerability_Affects_FeatureVersion. - err := tx.QueryRow(getQuery("r_vulnerability_fixedin_feature"), vulnerability.ID, - fv.Feature.ID).Scan(&fixedInID) - if err != nil && err != sql.ErrNoRows { - return handleError("r_vulnerability_fixedin_feature", err) - } + // Insert Vulnerability_Affects_FeatureVersion. + err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerabilityID, fv.Feature.ID, + fv.Version) + if err != nil { + return err + } + } + + for _, fv := range removedFIFV { + // Drop it from Vulnerability_FixedIn_Feature and let it cascade to + // Vulnerability_Affects_FeatureVersion. + err = tx.QueryRow(getQuery("r_vulnerability_fixedin_feature"), vulnerabilityID, + fv.Feature.ID).Scan(&fixedInID) + if err != nil && err != sql.ErrNoRows { + return handleError("r_vulnerability_fixedin_feature", err) } }