diff --git a/database/pgsql/complex_test.go b/database/pgsql/complex_test.go new file mode 100644 index 00000000..487d8b73 --- /dev/null +++ b/database/pgsql/complex_test.go @@ -0,0 +1,146 @@ +package pgsql + +import ( + "fmt" + "math/rand" + "runtime" + "strconv" + "sync" + "testing" + "time" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/utils" + "github.com/coreos/clair/utils/types" + "github.com/pborman/uuid" + "github.com/stretchr/testify/assert" +) + +const ( + numVulnerabilities = 100 + numFeatureVersions = 100 +) + +func TestRaceAffects(t *testing.T) { + datastore, err := OpenForTest("TestRaceAffects", false) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + // Insert the Feature on which we'll work. + feature := database.Feature{ + Namespace: database.Namespace{Name: "TestRaceAffectsFeatureNamespace1"}, + Name: "TestRaceAffecturesFeature1", + } + _, err = datastore.insertFeature(feature) + if err != nil { + t.Error(err) + return + } + + // Initialize random generator and enforce max procs. + rand.Seed(time.Now().UnixNano()) + runtime.GOMAXPROCS(runtime.NumCPU()) + + // Generate FeatureVersions. + featureVersions := make([]database.FeatureVersion, numFeatureVersions) + for i := 0; i < numFeatureVersions; i++ { + version := rand.Intn(numFeatureVersions) + + featureVersions[i] = database.FeatureVersion{ + Feature: feature, + Version: types.NewVersionUnsafe(strconv.Itoa(version)), + } + } + + // Generate vulnerabilities. + // They are mapped by fixed version, which will make verification really easy afterwards. + vulnerabilities := make(map[int][]database.Vulnerability) + for i := 0; i < numVulnerabilities; i++ { + version := rand.Intn(numFeatureVersions) + 1 + + // if _, ok := vulnerabilities[version]; !ok { + // vulnerabilities[version] = make([]database.Vulnerability) + // } + + vulnerability := database.Vulnerability{ + Name: uuid.New(), + Namespace: feature.Namespace, + FixedIn: []database.FeatureVersion{ + database.FeatureVersion{ + Feature: feature, + Version: types.NewVersionUnsafe(strconv.Itoa(version)), + }, + }, + Severity: types.Unknown, + } + + vulnerabilities[version] = append(vulnerabilities[version], vulnerability) + } + + // Insert featureversions and vulnerabilities in parallel. + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + for _, vulnerabilitiesM := range vulnerabilities { + for _, vulnerability := range vulnerabilitiesM { + err = datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}) + assert.Nil(t, err) + } + } + fmt.Println("finished to insert vulnerabilities") + }() + + go func() { + defer wg.Done() + for i := 0; i < len(featureVersions); i++ { + featureVersions[i].ID, err = datastore.insertFeatureVersion(featureVersions[i]) + assert.Nil(t, err) + } + fmt.Println("finished to insert featureVersions") + }() + + wg.Wait() + + // Verify consistency now. + var actualAffectedNames []string + var expectedAffectedNames []string + + for _, featureVersion := range featureVersions { + featureVersionVersion, _ := strconv.Atoi(featureVersion.Version.String()) + + // Get actual affects. + rows, err := datastore.Query(getQuery("s_complextest_featureversion_affects"), + featureVersion.ID) + assert.Nil(t, err) + defer rows.Close() + + var vulnName string + for rows.Next() { + err = rows.Scan(&vulnName) + if !assert.Nil(t, err) { + continue + } + actualAffectedNames = append(actualAffectedNames, vulnName) + } + if assert.Nil(t, rows.Err()) { + rows.Close() + } + + // Get expected affects. + for i := numVulnerabilities; i > featureVersionVersion; i-- { + for _, vulnerability := range vulnerabilities[i] { + expectedAffectedNames = append(expectedAffectedNames, vulnerability.Name) + } + } + + assert.Len(t, utils.CompareStringLists(expectedAffectedNames, actualAffectedNames), 0) + assert.Len(t, utils.CompareStringLists(actualAffectedNames, expectedAffectedNames), 0) + } + + // TODO(Quentin-M): May be worth having a test for updates as well. +} diff --git a/database/pgsql/feature.go b/database/pgsql/feature.go index 76c00e93..d094e41f 100644 --- a/database/pgsql/feature.go +++ b/database/pgsql/feature.go @@ -68,10 +68,10 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) // Set transaction as SERIALIZABLE. // This is how we ensure that the data in Vulnerability_Affects_FeatureVersion is always // consistent. - _, err = tx.Exec(getQuery("set_tx_serializable")) + _, err = tx.Exec(getQuery("l_vulnerability_affects_featureversion")) if err != nil { tx.Rollback() - return 0, handleError("insertFeatureVersion.set_tx_serializable", err) + return 0, handleError("insertFeatureVersion.l_vulnerability_affects_featureversion", err) } // Find or create FeatureVersion. @@ -162,6 +162,7 @@ func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.Fea // Insert into Vulnerability_Affects_FeatureVersion. for _, affect := range affects { + // TODO(Quentin-M): Batch me. _, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), affect.vulnerabilityID, featureVersion.ID, affect.fixedInID) if err != nil { diff --git a/database/pgsql/queries.go b/database/pgsql/queries.go index af77993e..dbb19230 100644 --- a/database/pgsql/queries.go +++ b/database/pgsql/queries.go @@ -10,7 +10,7 @@ var queries map[string]string func init() { queries = make(map[string]string) - queries["set_tx_serializable"] = `SET TRANSACTION ISOLATION LEVEL SERIALIZABLE` + queries["l_vulnerability_affects_featureversion"] = `LOCK Vulnerability_Affects_FeatureVersion IN SHARE ROW EXCLUSIVE MODE` // keyvalue.go queries["u_keyvalue"] = `UPDATE KeyValue SET value = $1 WHERE key = $2` @@ -180,6 +180,14 @@ func init() { queries["f_featureversion_by_feature"] = ` SELECT id, version FROM FeatureVersion WHERE feature_id = $1` + + // complex_test.go + queries["s_complextest_featureversion_affects"] = ` + SELECT v.name + FROM FeatureVersion fv + LEFT JOIN Vulnerability_Affects_FeatureVersion vaf ON fv.id = vaf.featureversion_id + JOIN Vulnerability v ON vaf.vulnerability_id = v.id + WHERE featureversion_id = $1` } func getQuery(name string) string { diff --git a/database/pgsql/vulnerability.go b/database/pgsql/vulnerability.go index 659ca0bc..5556872d 100644 --- a/database/pgsql/vulnerability.go +++ b/database/pgsql/vulnerability.go @@ -155,10 +155,10 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability) er // Set transaction as SERIALIZABLE. // This is how we ensure that the data in Vulnerability_Affects_FeatureVersion is always // consistent. - _, err = tx.Exec(getQuery("set_tx_serializable")) + _, err = tx.Exec(getQuery("l_vulnerability_affects_featureversion")) if err != nil { tx.Rollback() - return handleError("insertFeatureVersion.set_tx_serializable", err) + return handleError("insertVulnerability.l_vulnerability_affects_featureversion", err) } if existingVulnerability.ID == 0 { @@ -315,36 +315,43 @@ func (pgSQL *pgSQL) updateVulnerabilityFeatureVersions(tx *sql.Tx, vulnerability } func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, fixedInVersion types.Version) error { - // Find every FeatureVersions of the Feature we want to affect. + // Find every FeatureVersions of the Feature that the vulnerability affects. // TODO(Quentin-M): LIMIT rows, err := tx.Query(getQuery("f_featureversion_by_feature"), featureID) - if err == sql.ErrNoRows { - return nil - } if err != nil { return handleError("f_featureversion_by_feature", err) } defer rows.Close() - var featureVersionID int - var featureVersionVersion types.Version + var affecteds []database.FeatureVersion for rows.Next() { - err := rows.Scan(&featureVersionID, &featureVersionVersion) + var affected database.FeatureVersion + + err := rows.Scan(&affected.ID, &affected.Version) if err != nil { return handleError("f_featureversion_by_feature.Scan()", err) } - if featureVersionVersion.Compare(fixedInVersion) < 0 { - _, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), vulnerabilityID, featureVersionID, - fixedInID) - if err != nil { - return handleError("i_vulnerability_affects_featureversion", err) - } + if affected.Version.Compare(fixedInVersion) < 0 { + // The version of the FeatureVersion is lower than the fixed version of this vulnerability, + // thus, this FeatureVersion is affected by it. + affecteds = append(affecteds, affected) } } if err = rows.Err(); err != nil { return handleError("f_featureversion_by_feature.Rows()", err) } + rows.Close() + + // Insert into Vulnerability_Affects_FeatureVersion. + for _, affected := range affecteds { + // TODO(Quentin-M): Batch me. + _, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), vulnerabilityID, + affected.ID, fixedInID) + if err != nil { + return handleError("i_vulnerability_affects_featureversion", err) + } + } return nil }