From bd17dfb5e11b927e7134998286aff8511e83e954 Mon Sep 17 00:00:00 2001
From: Quentin Machu <me@quentin-machu.fr>
Date: Mon, 18 Jan 2016 18:52:16 -0500
Subject: [PATCH] database: ensure that concurrent vulnerability/feature
 versions insertions work fine

---
 database/pgsql/complex_test.go  | 146 ++++++++++++++++++++++++++++++++
 database/pgsql/feature.go       |   5 +-
 database/pgsql/queries.go       |  10 ++-
 database/pgsql/vulnerability.go |  37 ++++----
 4 files changed, 180 insertions(+), 18 deletions(-)
 create mode 100644 database/pgsql/complex_test.go

diff --git a/database/pgsql/complex_test.go b/database/pgsql/complex_test.go
new file mode 100644
index 00000000..487d8b73
--- /dev/null
+++ b/database/pgsql/complex_test.go
@@ -0,0 +1,146 @@
+package pgsql
+
+import (
+	"fmt"
+	"math/rand"
+	"runtime"
+	"strconv"
+	"sync"
+	"testing"
+	"time"
+
+	"github.com/coreos/clair/database"
+	"github.com/coreos/clair/utils"
+	"github.com/coreos/clair/utils/types"
+	"github.com/pborman/uuid"
+	"github.com/stretchr/testify/assert"
+)
+
+const (
+	numVulnerabilities = 100 // number of random vulnerabilities generated by TestRaceAffects
+	numFeatureVersions = 100 // number of FeatureVersions generated; also bounds the random versions
+)
+
+// TestRaceAffects inserts FeatureVersions and Vulnerabilities of a single Feature
+// concurrently, then verifies that every FeatureVersion is linked to exactly the
+// vulnerabilities that are fixed in a higher version.
+func TestRaceAffects(t *testing.T) {
+	datastore, err := OpenForTest("TestRaceAffects", false)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer datastore.Close()
+
+	// Insert the Feature on which we'll work.
+	feature := database.Feature{
+		Namespace: database.Namespace{Name: "TestRaceAffectsFeatureNamespace1"},
+		Name:      "TestRaceAffecturesFeature1",
+	}
+	if _, err = datastore.insertFeature(feature); err != nil {
+		t.Fatal(err)
+	}
+
+	// Initialize random generator and enforce max procs.
+	rand.Seed(time.Now().UnixNano())
+	runtime.GOMAXPROCS(runtime.NumCPU())
+
+	// Generate FeatureVersions.
+	featureVersions := make([]database.FeatureVersion, numFeatureVersions)
+	for i := 0; i < numFeatureVersions; i++ {
+		version := rand.Intn(numFeatureVersions)
+
+		featureVersions[i] = database.FeatureVersion{
+			Feature: feature,
+			Version: types.NewVersionUnsafe(strconv.Itoa(version)),
+		}
+	}
+
+	// Generate vulnerabilities.
+	// They are mapped by fixed version, which will make verification really easy afterwards.
+	vulnerabilities := make(map[int][]database.Vulnerability)
+	for i := 0; i < numVulnerabilities; i++ {
+		version := rand.Intn(numFeatureVersions) + 1
+
+		vulnerability := database.Vulnerability{
+			Name:      uuid.New(),
+			Namespace: feature.Namespace,
+			FixedIn: []database.FeatureVersion{
+				database.FeatureVersion{
+					Feature: feature,
+					Version: types.NewVersionUnsafe(strconv.Itoa(version)),
+				},
+			},
+			Severity: types.Unknown,
+		}
+
+		vulnerabilities[version] = append(vulnerabilities[version], vulnerability)
+	}
+
+	// Insert featureversions and vulnerabilities in parallel.
+	// Each goroutine declares its own err: assigning to the outer one would be a data race.
+	var wg sync.WaitGroup
+	wg.Add(2)
+
+	go func() {
+		defer wg.Done()
+		for _, vulnerabilitiesM := range vulnerabilities {
+			for _, vulnerability := range vulnerabilitiesM {
+				err := datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability})
+				assert.Nil(t, err)
+			}
+		}
+		fmt.Println("finished to insert vulnerabilities")
+	}()
+
+	go func() {
+		defer wg.Done()
+		for i := range featureVersions {
+			id, err := datastore.insertFeatureVersion(featureVersions[i])
+			assert.Nil(t, err)
+			featureVersions[i].ID = id
+		}
+		fmt.Println("finished to insert featureVersions")
+	}()
+
+	wg.Wait()
+
+	// Verify consistency now, one FeatureVersion at a time.
+	for _, featureVersion := range featureVersions {
+		featureVersionVersion, err := strconv.Atoi(featureVersion.Version.String())
+		if !assert.Nil(t, err) {
+			continue
+		}
+
+		// Get actual affects.
+		rows, err := datastore.Query(getQuery("s_complextest_featureversion_affects"),
+			featureVersion.ID)
+		if !assert.Nil(t, err) {
+			// The result set cannot be used when the query failed.
+			continue
+		}
+
+		var actualAffectedNames []string
+		for rows.Next() {
+			var vulnName string
+			if err := rows.Scan(&vulnName); assert.Nil(t, err) {
+				actualAffectedNames = append(actualAffectedNames, vulnName)
+			}
+		}
+		assert.Nil(t, rows.Err())
+		// Close now rather than with defer, which would keep every result set open until return.
+		rows.Close()
+
+		// Get expected affects: every vulnerability fixed in a version higher than this one.
+		var expectedAffectedNames []string
+		for i := numFeatureVersions; i > featureVersionVersion; i-- {
+			for _, vulnerability := range vulnerabilities[i] {
+				expectedAffectedNames = append(expectedAffectedNames, vulnerability.Name)
+			}
+		}
+
+		assert.Len(t, utils.CompareStringLists(expectedAffectedNames, actualAffectedNames), 0)
+		assert.Len(t, utils.CompareStringLists(actualAffectedNames, expectedAffectedNames), 0)
+	}
+
+	// TODO(Quentin-M): May be worth having a test for updates as well.
+}
diff --git a/database/pgsql/feature.go b/database/pgsql/feature.go
index 76c00e93..d094e41f 100644
--- a/database/pgsql/feature.go
+++ b/database/pgsql/feature.go
@@ -68,10 +68,10 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
-	// Set transaction as SERIALIZABLE.
+	// Lock Vulnerability_Affects_FeatureVersion in SHARE ROW EXCLUSIVE mode.
 	// This is how we ensure that the data in Vulnerability_Affects_FeatureVersion is always
 	// consistent.
-	_, err = tx.Exec(getQuery("set_tx_serializable"))
+	_, err = tx.Exec(getQuery("l_vulnerability_affects_featureversion"))
 	if err != nil {
 		tx.Rollback()
-		return 0, handleError("insertFeatureVersion.set_tx_serializable", err)
+		return 0, handleError("insertFeatureVersion.l_vulnerability_affects_featureversion", err)
 	}
 
 	// Find or create FeatureVersion.
@@ -162,6 +162,7 @@ func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.Fea
 
 	// Insert into Vulnerability_Affects_FeatureVersion.
 	for _, affect := range affects {
+		// TODO(Quentin-M): Batch me.
 		_, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), affect.vulnerabilityID,
 			featureVersion.ID, affect.fixedInID)
 		if err != nil {
diff --git a/database/pgsql/queries.go b/database/pgsql/queries.go
index af77993e..dbb19230 100644
--- a/database/pgsql/queries.go
+++ b/database/pgsql/queries.go
@@ -10,7 +10,7 @@ var queries map[string]string
 func init() {
 	queries = make(map[string]string)
 
-	queries["set_tx_serializable"] = `SET TRANSACTION ISOLATION LEVEL SERIALIZABLE`
+	queries["l_vulnerability_affects_featureversion"] = `LOCK Vulnerability_Affects_FeatureVersion IN SHARE ROW EXCLUSIVE MODE`
 
 	// keyvalue.go
 	queries["u_keyvalue"] = `UPDATE KeyValue SET value = $1 WHERE key = $2`
@@ -180,6 +180,14 @@ func init() {
 
 	queries["f_featureversion_by_feature"] = `
     SELECT id, version FROM FeatureVersion WHERE feature_id = $1`
+
+	// complex_test.go
+	queries["s_complextest_featureversion_affects"] = `
+    SELECT v.name
+    FROM FeatureVersion fv
+      LEFT JOIN Vulnerability_Affects_FeatureVersion vaf ON fv.id = vaf.featureversion_id
+      JOIN Vulnerability v ON vaf.vulnerability_id = v.id
+    WHERE featureversion_id = $1`
 }
 
 func getQuery(name string) string {
diff --git a/database/pgsql/vulnerability.go b/database/pgsql/vulnerability.go
index 659ca0bc..5556872d 100644
--- a/database/pgsql/vulnerability.go
+++ b/database/pgsql/vulnerability.go
@@ -155,10 +155,10 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability) er
-	// Set transaction as SERIALIZABLE.
+	// Lock Vulnerability_Affects_FeatureVersion in SHARE ROW EXCLUSIVE mode.
 	// This is how we ensure that the data in Vulnerability_Affects_FeatureVersion is always
 	// consistent.
-	_, err = tx.Exec(getQuery("set_tx_serializable"))
+	_, err = tx.Exec(getQuery("l_vulnerability_affects_featureversion"))
 	if err != nil {
 		tx.Rollback()
-		return handleError("insertFeatureVersion.set_tx_serializable", err)
+		return handleError("insertVulnerability.l_vulnerability_affects_featureversion", err)
 	}
 
 	if existingVulnerability.ID == 0 {
@@ -315,36 +315,43 @@ func (pgSQL *pgSQL) updateVulnerabilityFeatureVersions(tx *sql.Tx, vulnerability
 }
 
 func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, fixedInVersion types.Version) error {
-	// Find every FeatureVersions of the Feature we want to affect.
+	// Fetch every FeatureVersion of the Feature; the affected ones are selected in the loop below.
 	// TODO(Quentin-M): LIMIT
 	rows, err := tx.Query(getQuery("f_featureversion_by_feature"), featureID)
-	if err == sql.ErrNoRows {
-		return nil
-	}
 	if err != nil {
 		return handleError("f_featureversion_by_feature", err)
 	}
 	defer rows.Close()
 
-	var featureVersionID int
-	var featureVersionVersion types.Version
+	var affecteds []database.FeatureVersion
 	for rows.Next() {
-		err := rows.Scan(&featureVersionID, &featureVersionVersion)
+		var affected database.FeatureVersion
+
+		err := rows.Scan(&affected.ID, &affected.Version)
 		if err != nil {
 			return handleError("f_featureversion_by_feature.Scan()", err)
 		}
 
-		if featureVersionVersion.Compare(fixedInVersion) < 0 {
-			_, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), vulnerabilityID, featureVersionID,
-				fixedInID)
-			if err != nil {
-				return handleError("i_vulnerability_affects_featureversion", err)
-			}
+		if affected.Version.Compare(fixedInVersion) < 0 {
+			// The version of the FeatureVersion is lower than the fixed version of this vulnerability,
+			// thus, this FeatureVersion is affected by it.
+			affecteds = append(affecteds, affected)
 		}
 	}
 	if err = rows.Err(); err != nil {
 		return handleError("f_featureversion_by_feature.Rows()", err)
 	}
+	rows.Close() // Close explicitly (the defer covers the early returns) before reusing tx below.
+
+	// Insert one Vulnerability_Affects_FeatureVersion row per affected FeatureVersion collected above.
+	for _, affected := range affecteds {
+		// TODO(Quentin-M): Batch me.
+		_, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), vulnerabilityID,
+			affected.ID, fixedInID)
+		if err != nil {
+			return handleError("i_vulnerability_affects_featureversion", err)
+		}
+	}
 
 	return nil
 }