database: ensure that concurrent vulnerability/feature versions insertions work fine

This commit is contained in:
Quentin Machu 2016-01-18 18:52:16 -05:00 committed by Jimmy Zelinskie
parent 74fc5b3e66
commit bd17dfb5e1
4 changed files with 180 additions and 18 deletions

View File

@ -0,0 +1,146 @@
package pgsql
import (
"fmt"
"math/rand"
"runtime"
"strconv"
"sync"
"testing"
"time"
"github.com/coreos/clair/database"
"github.com/coreos/clair/utils"
"github.com/coreos/clair/utils/types"
"github.com/pborman/uuid"
"github.com/stretchr/testify/assert"
)
const (
numVulnerabilities = 100
numFeatureVersions = 100
)
func TestRaceAffects(t *testing.T) {
datastore, err := OpenForTest("TestRaceAffects", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Insert the Feature on which we'll work.
feature := database.Feature{
Namespace: database.Namespace{Name: "TestRaceAffectsFeatureNamespace1"},
Name: "TestRaceAffecturesFeature1",
}
_, err = datastore.insertFeature(feature)
if err != nil {
t.Error(err)
return
}
// Initialize random generator and enforce max procs.
rand.Seed(time.Now().UnixNano())
runtime.GOMAXPROCS(runtime.NumCPU())
// Generate FeatureVersions.
featureVersions := make([]database.FeatureVersion, numFeatureVersions)
for i := 0; i < numFeatureVersions; i++ {
version := rand.Intn(numFeatureVersions)
featureVersions[i] = database.FeatureVersion{
Feature: feature,
Version: types.NewVersionUnsafe(strconv.Itoa(version)),
}
}
// Generate vulnerabilities.
// They are mapped by fixed version, which will make verification really easy afterwards.
vulnerabilities := make(map[int][]database.Vulnerability)
for i := 0; i < numVulnerabilities; i++ {
version := rand.Intn(numFeatureVersions) + 1
// if _, ok := vulnerabilities[version]; !ok {
// vulnerabilities[version] = make([]database.Vulnerability)
// }
vulnerability := database.Vulnerability{
Name: uuid.New(),
Namespace: feature.Namespace,
FixedIn: []database.FeatureVersion{
database.FeatureVersion{
Feature: feature,
Version: types.NewVersionUnsafe(strconv.Itoa(version)),
},
},
Severity: types.Unknown,
}
vulnerabilities[version] = append(vulnerabilities[version], vulnerability)
}
// Insert featureversions and vulnerabilities in parallel.
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
for _, vulnerabilitiesM := range vulnerabilities {
for _, vulnerability := range vulnerabilitiesM {
err = datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability})
assert.Nil(t, err)
}
}
fmt.Println("finished to insert vulnerabilities")
}()
go func() {
defer wg.Done()
for i := 0; i < len(featureVersions); i++ {
featureVersions[i].ID, err = datastore.insertFeatureVersion(featureVersions[i])
assert.Nil(t, err)
}
fmt.Println("finished to insert featureVersions")
}()
wg.Wait()
// Verify consistency now.
var actualAffectedNames []string
var expectedAffectedNames []string
for _, featureVersion := range featureVersions {
featureVersionVersion, _ := strconv.Atoi(featureVersion.Version.String())
// Get actual affects.
rows, err := datastore.Query(getQuery("s_complextest_featureversion_affects"),
featureVersion.ID)
assert.Nil(t, err)
defer rows.Close()
var vulnName string
for rows.Next() {
err = rows.Scan(&vulnName)
if !assert.Nil(t, err) {
continue
}
actualAffectedNames = append(actualAffectedNames, vulnName)
}
if assert.Nil(t, rows.Err()) {
rows.Close()
}
// Get expected affects.
for i := numVulnerabilities; i > featureVersionVersion; i-- {
for _, vulnerability := range vulnerabilities[i] {
expectedAffectedNames = append(expectedAffectedNames, vulnerability.Name)
}
}
assert.Len(t, utils.CompareStringLists(expectedAffectedNames, actualAffectedNames), 0)
assert.Len(t, utils.CompareStringLists(actualAffectedNames, expectedAffectedNames), 0)
}
// TODO(Quentin-M): May be worth having a test for updates as well.
}

View File

@ -68,10 +68,10 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
// Set transaction as SERIALIZABLE.
// This is how we ensure that the data in Vulnerability_Affects_FeatureVersion is always
// consistent.
_, err = tx.Exec(getQuery("set_tx_serializable"))
_, err = tx.Exec(getQuery("l_vulnerability_affects_featureversion"))
if err != nil {
tx.Rollback()
return 0, handleError("insertFeatureVersion.set_tx_serializable", err)
return 0, handleError("insertFeatureVersion.l_vulnerability_affects_featureversion", err)
}
// Find or create FeatureVersion.
@ -162,6 +162,7 @@ func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.Fea
// Insert into Vulnerability_Affects_FeatureVersion.
for _, affect := range affects {
// TODO(Quentin-M): Batch me.
_, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), affect.vulnerabilityID,
featureVersion.ID, affect.fixedInID)
if err != nil {

View File

@ -10,7 +10,7 @@ var queries map[string]string
func init() {
queries = make(map[string]string)
queries["set_tx_serializable"] = `SET TRANSACTION ISOLATION LEVEL SERIALIZABLE`
queries["l_vulnerability_affects_featureversion"] = `LOCK Vulnerability_Affects_FeatureVersion IN SHARE ROW EXCLUSIVE MODE`
// keyvalue.go
queries["u_keyvalue"] = `UPDATE KeyValue SET value = $1 WHERE key = $2`
@ -180,6 +180,14 @@ func init() {
queries["f_featureversion_by_feature"] = `
SELECT id, version FROM FeatureVersion WHERE feature_id = $1`
// complex_test.go
queries["s_complextest_featureversion_affects"] = `
SELECT v.name
FROM FeatureVersion fv
LEFT JOIN Vulnerability_Affects_FeatureVersion vaf ON fv.id = vaf.featureversion_id
JOIN Vulnerability v ON vaf.vulnerability_id = v.id
WHERE featureversion_id = $1`
}
func getQuery(name string) string {

View File

@ -155,10 +155,10 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability) er
// Set transaction as SERIALIZABLE.
// This is how we ensure that the data in Vulnerability_Affects_FeatureVersion is always
// consistent.
_, err = tx.Exec(getQuery("set_tx_serializable"))
_, err = tx.Exec(getQuery("l_vulnerability_affects_featureversion"))
if err != nil {
tx.Rollback()
return handleError("insertFeatureVersion.set_tx_serializable", err)
return handleError("insertVulnerability.l_vulnerability_affects_featureversion", err)
}
if existingVulnerability.ID == 0 {
@ -315,36 +315,43 @@ func (pgSQL *pgSQL) updateVulnerabilityFeatureVersions(tx *sql.Tx, vulnerability
}
func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, fixedInVersion types.Version) error {
// Find every FeatureVersions of the Feature we want to affect.
// Find every FeatureVersions of the Feature that the vulnerability affects.
// TODO(Quentin-M): LIMIT
rows, err := tx.Query(getQuery("f_featureversion_by_feature"), featureID)
if err == sql.ErrNoRows {
return nil
}
if err != nil {
return handleError("f_featureversion_by_feature", err)
}
defer rows.Close()
var featureVersionID int
var featureVersionVersion types.Version
var affecteds []database.FeatureVersion
for rows.Next() {
err := rows.Scan(&featureVersionID, &featureVersionVersion)
var affected database.FeatureVersion
err := rows.Scan(&affected.ID, &affected.Version)
if err != nil {
return handleError("f_featureversion_by_feature.Scan()", err)
}
if featureVersionVersion.Compare(fixedInVersion) < 0 {
_, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), vulnerabilityID, featureVersionID,
fixedInID)
if err != nil {
return handleError("i_vulnerability_affects_featureversion", err)
}
if affected.Version.Compare(fixedInVersion) < 0 {
// The version of the FeatureVersion is lower than the fixed version of this vulnerability,
// thus, this FeatureVersion is affected by it.
affecteds = append(affecteds, affected)
}
}
if err = rows.Err(); err != nil {
return handleError("f_featureversion_by_feature.Rows()", err)
}
rows.Close()
// Insert into Vulnerability_Affects_FeatureVersion.
for _, affected := range affecteds {
// TODO(Quentin-M): Batch me.
_, err := tx.Exec(getQuery("i_vulnerability_affects_featureversion"), vulnerabilityID,
affected.ID, fixedInID)
if err != nil {
return handleError("i_vulnerability_affects_featureversion", err)
}
}
return nil
}