pgsql: Fix postgres queries for feature_type

This commit is contained in:
Sida Chen 2019-02-19 16:42:00 -05:00
parent 5fa1ac89b9
commit 79af05e67d
12 changed files with 200 additions and 275 deletions

View File

@ -23,10 +23,11 @@ const (
findAncestryFeatures = `
SELECT namespace.name, namespace.version_format, feature.name,
feature.version, feature.version_format, ancestry_layer.ancestry_index,
feature.version, feature.version_format, feature_type.name, ancestry_layer.ancestry_index,
ancestry_feature.feature_detector_id, ancestry_feature.namespace_detector_id
FROM namespace, feature, namespaced_feature, ancestry_layer, ancestry_feature
FROM namespace, feature, feature_type, namespaced_feature, ancestry_layer, ancestry_feature
WHERE ancestry_layer.ancestry_id = $1
AND feature_type.id = feature.type
AND ancestry_feature.ancestry_layer_id = ancestry_layer.id
AND ancestry_feature.namespaced_feature_id = namespaced_feature.id
AND namespaced_feature.feature_id = feature.id
@ -256,6 +257,7 @@ func (tx *pgSession) findAncestryFeatures(ancestryID int64, detectors detectorMa
&feature.Feature.Name,
&feature.Feature.Version,
&feature.Feature.VersionFormat,
&feature.Feature.Type,
&index,
&featureDetectorID,
&namespaceDetectorID,

View File

@ -25,6 +25,7 @@ import (
"github.com/pborman/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coreos/clair/database"
"github.com/coreos/clair/ext/versionfmt"
@ -65,19 +66,13 @@ func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database
for i := 0; i < numFeatures; i++ {
version := rand.Intn(numFeatures)
features[i] = database.Feature{
Name: featureName,
VersionFormat: featureVersionFormat,
Version: strconv.Itoa(version),
}
features[i] = *database.NewSourcePackage(featureName, strconv.Itoa(version), featureVersionFormat)
nsFeatures[i] = database.NamespacedFeature{
Namespace: namespace,
Feature: features[i],
}
}
// insert features
if !assert.Nil(t, tx.PersistFeatures(features)) {
t.FailNow()
}
@ -98,6 +93,7 @@ func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database
{
Namespace: namespace,
FeatureName: featureName,
FeatureType: database.SourcePackage,
AffectedVersion: strconv.Itoa(version),
FixedInVersion: strconv.Itoa(version),
},
@ -117,7 +113,6 @@ func TestConcurrency(t *testing.T) {
t.FailNow()
}
defer store.Close()
start := time.Now()
var wg sync.WaitGroup
wg.Add(100)
@ -137,65 +132,39 @@ func TestConcurrency(t *testing.T) {
fmt.Println("total", time.Since(start))
}
func genRandomNamespaces(t *testing.T, count int) []database.Namespace {
r := make([]database.Namespace, count)
for i := 0; i < count; i++ {
r[i] = database.Namespace{
Name: uuid.New(),
VersionFormat: "dpkg",
}
}
return r
}
func TestCaching(t *testing.T) {
store, err := openDatabaseForTest("Caching", false)
if !assert.Nil(t, err) {
t.FailNow()
}
defer store.Close()
nsFeatures, vulnerabilities := testGenRandomVulnerabilityAndNamespacedFeature(t, store)
fmt.Printf("%d features, %d vulnerabilities are generated", len(nsFeatures), len(vulnerabilities))
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
tx, err := store.Begin()
if !assert.Nil(t, err) {
t.FailNow()
}
assert.Nil(t, tx.PersistNamespacedFeatures(nsFeatures))
fmt.Println("finished to insert namespaced features")
require.Nil(t, tx.PersistNamespacedFeatures(nsFeatures))
if err := tx.Commit(); err != nil {
panic(err)
}
tx.Commit()
}()
go func() {
defer wg.Done()
tx, err := store.Begin()
tx, err = store.Begin()
if !assert.Nil(t, err) {
t.FailNow()
}
assert.Nil(t, tx.InsertVulnerabilities(vulnerabilities))
fmt.Println("finished to insert vulnerabilities")
tx.Commit()
require.Nil(t, tx.InsertVulnerabilities(vulnerabilities))
if err := tx.Commit(); err != nil {
panic(err)
}
}()
wg.Wait()
tx, err := store.Begin()
tx, err = store.Begin()
if !assert.Nil(t, err) {
t.FailNow()
}
defer tx.Rollback()
// Verify consistency now.
affected, err := tx.FindAffectedNamespacedFeatures(nsFeatures)
if !assert.Nil(t, err) {
t.FailNow()
@ -220,7 +189,7 @@ func TestCaching(t *testing.T) {
actualAffectedNames = append(actualAffectedNames, s.Name)
}
assert.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0)
assert.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0)
require.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0, "\nvulns: %#v\nfeature:%#v\nexpected:%#v\nactual:%#v", vulnerabilities, ansf.NamespacedFeature, expectedAffectedNames, actualAffectedNames)
require.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0)
}
}

View File

@ -46,6 +46,7 @@ const (
AND nf.feature_id = f.id
AND nf.namespace_id = v.namespace_id
AND vaf.feature_name = f.name
AND vaf.feature_type = f.type
AND vaf.vulnerability_id = v.id
AND v.deleted_at IS NULL`
@ -68,6 +69,11 @@ func (tx *pgSession) PersistFeatures(features []database.Feature) error {
return nil
}
types, err := tx.getFeatureTypeMap()
if err != nil {
return err
}
// Sorting is needed before inserting into database to prevent deadlock.
sort.Slice(features, func(i, j int) bool {
return features[i].Name < features[j].Name ||
@ -78,13 +84,13 @@ func (tx *pgSession) PersistFeatures(features []database.Feature) error {
// TODO(Sida): A better interface for bulk insertion is needed.
keys := make([]interface{}, 0, len(features)*3)
for _, f := range features {
keys = append(keys, f.Name, f.Version, f.VersionFormat)
keys = append(keys, f.Name, f.Version, f.VersionFormat, types.byName[f.Type])
if f.Name == "" || f.Version == "" || f.VersionFormat == "" {
return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed")
}
}
_, err := tx.Exec(queryPersistFeature(len(features)), keys...)
_, err = tx.Exec(queryPersistFeature(len(features)), keys...)
return handleError("queryPersistFeature", err)
}
@ -240,55 +246,31 @@ func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFea
return nil
}
// FindAffectedNamespacedFeatures looks up cache table and retrieves all
// vulnerabilities associated with the features.
// FindAffectedNamespacedFeatures retrieves vulnerabilities associated with the
// feature.
func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) {
if len(features) == 0 {
return nil, nil
}
returnFeatures := make([]database.NullableAffectedNamespacedFeature, len(features))
// featureMap is used to keep track of duplicated features.
featureMap := map[database.NamespacedFeature][]*database.NullableAffectedNamespacedFeature{}
// initialize return value and generate unique feature request queries.
for i, f := range features {
returnFeatures[i] = database.NullableAffectedNamespacedFeature{
AffectedNamespacedFeature: database.AffectedNamespacedFeature{
NamespacedFeature: f,
},
}
featureMap[f] = append(featureMap[f], &returnFeatures[i])
}
// query unique namespaced features
distinctFeatures := []database.NamespacedFeature{}
for f := range featureMap {
distinctFeatures = append(distinctFeatures, f)
}
nsFeatureIDs, err := tx.findNamespacedFeatureIDs(distinctFeatures)
vulnerableFeatures := make([]database.NullableAffectedNamespacedFeature, len(features))
featureIDs, err := tx.findNamespacedFeatureIDs(features)
if err != nil {
return nil, err
}
toQuery := []int64{}
featureIDMap := map[int64][]*database.NullableAffectedNamespacedFeature{}
for i, id := range nsFeatureIDs {
for i, id := range featureIDs {
if id.Valid {
toQuery = append(toQuery, id.Int64)
for _, f := range featureMap[distinctFeatures[i]] {
f.Valid = id.Valid
featureIDMap[id.Int64] = append(featureIDMap[id.Int64], f)
}
vulnerableFeatures[i].Valid = true
vulnerableFeatures[i].NamespacedFeature = features[i]
}
}
rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(toQuery))
rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(featureIDs))
if err != nil {
return nil, handleError("searchNamespacedFeaturesVulnerabilities", err)
}
defer rows.Close()
for rows.Next() {
@ -296,6 +278,7 @@ func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.Namespac
featureID int64
vuln database.VulnerabilityWithFixedIn
)
err := rows.Scan(&featureID,
&vuln.Name,
&vuln.Description,
@ -306,16 +289,19 @@ func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.Namespac
&vuln.Namespace.Name,
&vuln.Namespace.VersionFormat,
)
if err != nil {
return nil, handleError("searchNamespacedFeaturesVulnerabilities", err)
}
for _, f := range featureIDMap[featureID] {
f.AffectedBy = append(f.AffectedBy, vuln)
for i, id := range featureIDs {
if id.Valid && id.Int64 == featureID {
vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy = append(vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy, vuln)
}
}
}
return returnFeatures, nil
return vulnerableFeatures, nil
}
func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) ([]sql.NullInt64, error) {
@ -323,11 +309,10 @@ func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature)
return nil, nil
}
nfsMap := map[database.NamespacedFeature]sql.NullInt64{}
keys := make([]interface{}, 0, len(nfs)*4)
nfsMap := map[database.NamespacedFeature]int64{}
keys := make([]interface{}, 0, len(nfs)*5)
for _, nf := range nfs {
keys = append(keys, nf.Name, nf.Version, nf.VersionFormat, nf.Namespace.Name)
nfsMap[nf] = sql.NullInt64{}
keys = append(keys, nf.Name, nf.Version, nf.VersionFormat, nf.Type, nf.Namespace.Name)
}
rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...)
@ -337,12 +322,12 @@ func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature)
defer rows.Close()
var (
id sql.NullInt64
id int64
nf database.NamespacedFeature
)
for rows.Next() {
err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Namespace.Name)
err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Type, &nf.Namespace.Name)
nf.Namespace.VersionFormat = nf.VersionFormat
if err != nil {
return nil, handleError("searchNamespacedFeature", err)
@ -352,7 +337,11 @@ func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature)
ids := make([]sql.NullInt64, len(nfs))
for i, nf := range nfs {
ids[i] = nfsMap[nf]
if id, ok := nfsMap[nf]; ok {
ids[i] = sql.NullInt64{id, true}
} else {
ids[i] = sql.NullInt64{}
}
}
return ids, nil
@ -363,11 +352,17 @@ func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, err
return nil, nil
}
types, err := tx.getFeatureTypeMap()
if err != nil {
return nil, err
}
fMap := map[database.Feature]sql.NullInt64{}
keys := make([]interface{}, 0, len(fs)*3)
keys := make([]interface{}, 0, len(fs)*4)
for _, f := range fs {
keys = append(keys, f.Name, f.Version, f.VersionFormat)
typeID := types.byName[f.Type]
keys = append(keys, f.Name, f.Version, f.VersionFormat, typeID)
fMap[f] = sql.NullInt64{}
}
@ -382,10 +377,13 @@ func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, err
f database.Feature
)
for rows.Next() {
err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat)
var typeID int
err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat, &typeID)
if err != nil {
return nil, handleError("querySearchFeatureID", err)
}
f.Type = types.byID[typeID]
fMap[f] = id
}

View File

@ -18,134 +18,53 @@ import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coreos/clair/database"
// register dpkg feature lister for testing
_ "github.com/coreos/clair/ext/featurefmt/dpkg"
)
func TestPersistFeatures(t *testing.T) {
datastore, tx := openSessionForTest(t, "PersistFeatures", false)
defer closeTest(t, datastore, tx)
tx, cleanup := createTestPgSession(t, "TestPersistFeatures")
defer cleanup()
f1 := database.Feature{}
f2 := database.Feature{Name: "n", Version: "v", VersionFormat: "vf"}
invalid := database.Feature{}
valid := *database.NewBinaryPackage("mount", "2.31.1-0.4ubuntu3.1", "dpkg")
// empty
assert.Nil(t, tx.PersistFeatures([]database.Feature{}))
// invalid
assert.NotNil(t, tx.PersistFeatures([]database.Feature{f1}))
// duplicated
assert.Nil(t, tx.PersistFeatures([]database.Feature{f2, f2}))
require.NotNil(t, tx.PersistFeatures([]database.Feature{invalid}))
// existing
assert.Nil(t, tx.PersistFeatures([]database.Feature{f2}))
require.Nil(t, tx.PersistFeatures([]database.Feature{valid}))
require.Nil(t, tx.PersistFeatures([]database.Feature{valid}))
fs := listFeatures(t, tx)
assert.Len(t, fs, 1)
assert.Equal(t, f2, fs[0])
features := selectAllFeatures(t, tx)
assert.Equal(t, []database.Feature{valid}, features)
}
func TestPersistNamespacedFeatures(t *testing.T) {
datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true)
defer closeTest(t, datastore, tx)
tx, cleanup := createTestPgSessionWithFixtures(t, "TestPersistNamespacedFeatures")
defer cleanup()
// existing features
f1 := database.Feature{
Name: "ourchat",
Version: "0.5",
VersionFormat: "dpkg",
}
f1 := database.NewSourcePackage("ourchat", "0.5", "dpkg")
// non-existing features
f2 := database.Feature{
Name: "fake!",
}
f3 := database.Feature{
Name: "openssl",
Version: "2.0",
VersionFormat: "dpkg",
}
f2 := database.NewSourcePackage("fake!", "", "")
// exising namespace
n1 := database.Namespace{
Name: "debian:7",
VersionFormat: "dpkg",
}
n3 := database.Namespace{
Name: "debian:8",
VersionFormat: "dpkg",
}
n1 := database.NewNamespace("debian:7", "dpkg")
// non-existing namespace
n2 := database.Namespace{
Name: "debian:non",
VersionFormat: "dpkg",
}
n2 := database.NewNamespace("debian:non", "dpkg")
// existing namespaced feature
nf1 := database.NamespacedFeature{
Namespace: n1,
Feature: f1,
}
nf1 := database.NewNamespacedFeature(n1, f1)
// invalid namespaced feature
nf2 := database.NamespacedFeature{
Namespace: n2,
Feature: f2,
}
// new namespaced feature affected by vulnerability
nf3 := database.NamespacedFeature{
Namespace: n3,
Feature: f3,
}
nf2 := database.NewNamespacedFeature(n2, f2)
// namespaced features with namespaces or features not in the database will
// generate error.
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{}))
assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf2}))
assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{*nf1, *nf2}))
// valid case: insert nf3
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf3}))
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{*nf1}))
all := listNamespacedFeatures(t, tx)
assert.Contains(t, all, nf1)
assert.Contains(t, all, nf3)
}
func TestVulnerableFeature(t *testing.T) {
datastore, tx := openSessionForTest(t, "VulnerableFeature", true)
defer closeTest(t, datastore, tx)
f1 := database.Feature{
Name: "openssl",
Version: "1.3",
VersionFormat: "dpkg",
}
n1 := database.Namespace{
Name: "debian:7",
VersionFormat: "dpkg",
}
nf1 := database.NamespacedFeature{
Namespace: n1,
Feature: f1,
}
assert.Nil(t, tx.PersistFeatures([]database.Feature{f1}))
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1}))
assert.Nil(t, tx.CacheAffectedNamespacedFeatures([]database.NamespacedFeature{nf1}))
// ensure the namespaced feature is affected correctly
anf, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{nf1})
if assert.Nil(t, err) &&
assert.Len(t, anf, 1) &&
assert.True(t, anf[0].Valid) &&
assert.Len(t, anf[0].AffectedBy, 1) {
assert.Equal(t, "CVE-OPENSSL-1-DEB7", anf[0].AffectedBy[0].Name)
}
assert.Contains(t, all, *nf1)
}
func TestFindAffectedNamespacedFeatures(t *testing.T) {
@ -156,6 +75,7 @@ func TestFindAffectedNamespacedFeatures(t *testing.T) {
Name: "openssl",
Version: "1.0",
VersionFormat: "dpkg",
Type: database.SourcePackage,
},
Namespace: database.Namespace{
Name: "debian:7",
@ -173,30 +93,41 @@ func TestFindAffectedNamespacedFeatures(t *testing.T) {
}
func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFeature {
rows, err := tx.Query(`SELECT f.name, f.version, f.version_format, n.name, n.version_format
types, err := tx.getFeatureTypeMap()
if err != nil {
panic(err)
}
rows, err := tx.Query(`SELECT f.name, f.version, f.version_format, f.type, n.name, n.version_format
FROM feature AS f, namespace AS n, namespaced_feature AS nf
WHERE nf.feature_id = f.id AND nf.namespace_id = n.id`)
if err != nil {
t.Error(err)
t.FailNow()
panic(err)
}
nf := []database.NamespacedFeature{}
for rows.Next() {
f := database.NamespacedFeature{}
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &f.Namespace.Name, &f.Namespace.VersionFormat)
var typeID int
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &typeID, &f.Namespace.Name, &f.Namespace.VersionFormat)
if err != nil {
t.Error(err)
t.FailNow()
panic(err)
}
f.Type = types.byID[typeID]
nf = append(nf, f)
}
return nf
}
func listFeatures(t *testing.T, tx *pgSession) []database.Feature {
rows, err := tx.Query("SELECT name, version, version_format FROM feature")
func selectAllFeatures(t *testing.T, tx *pgSession) []database.Feature {
types, err := tx.getFeatureTypeMap()
if err != nil {
panic(err)
}
rows, err := tx.Query("SELECT name, version, version_format, type FROM feature")
if err != nil {
t.FailNow()
}
@ -204,7 +135,9 @@ func listFeatures(t *testing.T, tx *pgSession) []database.Feature {
fs := []database.Feature{}
for rows.Next() {
f := database.Feature{}
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat)
var typeID int
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &typeID)
f.Type = types.byID[typeID]
if err != nil {
t.FailNow()
}
@ -233,3 +166,33 @@ func assertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFe
}
return false
}
func TestFindNamespacedFeatureIDs(t *testing.T) {
tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNamespacedFeatureIDs")
defer cleanup()
features := []database.NamespacedFeature{}
expectedIDs := []int{}
for id, feature := range realNamespacedFeatures {
features = append(features, feature)
expectedIDs = append(expectedIDs, id)
}
features = append(features, realNamespacedFeatures[1]) // test duplicated
expectedIDs = append(expectedIDs, 1)
namespace := realNamespaces[1]
features = append(features, *database.NewNamespacedFeature(&namespace, database.NewBinaryPackage("not-found", "1.0", "dpkg"))) // test not found feature
ids, err := tx.findNamespacedFeatureIDs(features)
require.Nil(t, err)
require.Len(t, ids, len(expectedIDs)+1)
for i, id := range ids {
if i == len(ids)-1 {
require.False(t, id.Valid)
} else {
require.True(t, id.Valid)
require.Equal(t, expectedIDs[i], int(id.Int64))
}
}
}

View File

@ -37,9 +37,10 @@ const (
SELECT id FROM layer WHERE hash = $1`
findLayerFeatures = `
SELECT f.name, f.version, f.version_format, lf.detector_id
FROM layer_feature AS lf, feature AS f
SELECT f.name, f.version, f.version_format, t.name, lf.detector_id
FROM layer_feature AS lf, feature AS f, feature_type AS t
WHERE lf.feature_id = f.id
AND t.id = f.type
AND lf.layer_id = $1`
findLayerNamespaces = `
@ -307,7 +308,7 @@ func (tx *pgSession) findLayerFeatures(layerID int64, detectors detectorMap) ([]
detectorID int64
feature database.LayerFeature
)
if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &detectorID); err != nil {
if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &feature.Type, &detectorID); err != nil {
return nil, handleError("findLayerFeatures", err)
}

View File

@ -43,12 +43,12 @@ var persistLayerTests = []struct {
features: []database.LayerFeature{
{realFeatures[1], realDetectors[1]},
},
err: "database: parameters are not valid",
err: "parameters are not valid",
},
{
title: "layer with non-existing feature",
name: "random-forest",
err: "database: associated immutable entities are missing in the database",
err: "associated immutable entities are missing in the database",
by: []database.Detector{realDetectors[2]},
features: []database.LayerFeature{
{fakeFeatures[1], realDetectors[2]},
@ -57,7 +57,7 @@ var persistLayerTests = []struct {
{
title: "layer with non-existing namespace",
name: "random-forest2",
err: "database: associated immutable entities are missing in the database",
err: "associated immutable entities are missing in the database",
by: []database.Detector{realDetectors[1]},
namespaces: []database.LayerNamespace{
{fakeNamespaces[1], realDetectors[1]},
@ -66,7 +66,7 @@ var persistLayerTests = []struct {
{
title: "layer with non-existing detector",
name: "random-forest3",
err: "database: associated immutable entities are missing in the database",
err: "associated immutable entities are missing in the database",
by: []database.Detector{fakeDetector[1]},
},
{

View File

@ -211,8 +211,8 @@ func TestInsertVulnerabilityNotifications(t *testing.T) {
}
func TestFindNewNotification(t *testing.T) {
datastore, tx := openSessionForTest(t, "FindNewNotification", true)
defer closeTest(t, datastore, tx)
tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNewNotification")
defer cleanup()
noti, ok, err := tx.FindNewNotification(time.Now())
if assert.Nil(t, err) && assert.True(t, ok) {
@ -229,7 +229,7 @@ func TestFindNewNotification(t *testing.T) {
assert.Nil(t, err)
assert.False(t, ok)
// can find the notified after a period of time
noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000)))
noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(10 * time.Second)))
if assert.Nil(t, err) && assert.True(t, ok) {
assert.Equal(t, "test", noti.Name)
assert.NotEqual(t, time.Time{}, noti.Notified)

View File

@ -109,6 +109,7 @@ func dropTemplateDatabase(url string, name string) {
}
}
func TestMain(m *testing.M) {
fURL, fName := genTemplateDatabase("fixture", true)
nfURL, nfName := genTemplateDatabase("nonfixture", false)

View File

@ -52,21 +52,22 @@ func querySearchNotDeletedVulnerabilityID(count int) string {
func querySearchFeatureID(featureCount int) string {
return fmt.Sprintf(`
SELECT id, name, version, version_format
FROM Feature WHERE (name, version, version_format) IN (%s)`,
queryString(3, featureCount),
SELECT id, name, version, version_format, type
FROM Feature WHERE (name, version, version_format, type) IN (%s)`,
queryString(4, featureCount),
)
}
func querySearchNamespacedFeature(nsfCount int) string {
return fmt.Sprintf(`
SELECT nf.id, f.name, f.version, f.version_format, n.name
FROM namespaced_feature AS nf, feature AS f, namespace AS n
SELECT nf.id, f.name, f.version, f.version_format, t.name, n.name
FROM namespaced_feature AS nf, feature AS f, namespace AS n, feature_type AS t
WHERE nf.feature_id = f.id
AND nf.namespace_id = n.id
AND n.version_format = f.version_format
AND (f.name, f.version, f.version_format, n.name) IN (%s)`,
queryString(4, nsfCount),
AND f.type = t.id
AND (f.name, f.version, f.version_format, t.name, n.name) IN (%s)`,
queryString(5, nsfCount),
)
}
@ -110,10 +111,11 @@ func queryInsertNotifications(count int) string {
func queryPersistFeature(count int) string {
return queryPersist(count,
"feature",
"feature_name_version_version_format_key",
"feature_name_version_version_format_type_key",
"name",
"version",
"version_format")
"version_format",
"type")
}
func queryPersistLayerFeature(count int) string {

View File

@ -4,11 +4,12 @@ INSERT INTO namespace (id, name, version_format) VALUES
(2, 'debian:8', 'dpkg'),
(3, 'fake:1.0', 'rpm');
INSERT INTO feature (id, name, version, version_format) VALUES
(1, 'ourchat', '0.5', 'dpkg'),
(2, 'openssl', '1.0', 'dpkg'),
(3, 'openssl', '2.0', 'dpkg'),
(4, 'fake', '2.0', 'rpm');
INSERT INTO feature (id, name, version, version_format, type) VALUES
(1, 'ourchat', '0.5', 'dpkg', 1),
(2, 'openssl', '1.0', 'dpkg', 1),
(3, 'openssl', '2.0', 'dpkg', 1),
(4, 'fake', '2.0', 'rpm', 1),
(5, 'mount', '2.31.1-0.4ubuntu3.1', 'dpkg', 2);
INSERT INTO namespaced_feature(id, feature_id, namespace_id) VALUES
(1, 1, 1), -- ourchat 0.5, debian:7
@ -112,9 +113,9 @@ INSERT INTO vulnerability (id, namespace_id, name, description, link, severity)
INSERT INTO vulnerability (id, namespace_id, name, description, link, severity, deleted_at) VALUES
(3, 1, 'CVE-DELETED', '', '', 'Unknown', '2017-08-08 17:49:31.668483');
INSERT INTO vulnerability_affected_feature(id, vulnerability_id, feature_name, affected_version, fixedin) VALUES
(1, 1, 'openssl', '2.0', '2.0'),
(2, 1, 'libssl', '1.9-abc', '1.9-abc');
INSERT INTO vulnerability_affected_feature(id, vulnerability_id, feature_name, affected_version, fixedin, feature_type) VALUES
(1, 1, 'openssl', '2.0', '2.0', 1),
(2, 1, 'libssl', '1.9-abc', '1.9-abc', 1);
INSERT INTO vulnerability_affected_namespaced_feature(id, vulnerability_id, namespaced_feature_id, added_by) VALUES
(1, 1, 2, 1);

View File

@ -39,15 +39,15 @@ const (
`
insertVulnerabilityAffected = `
INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, fixedin)
VALUES ($1, $2, $3, $4)
INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, feature_type, fixedin)
VALUES ($1, $2, $3, $4, $5)
RETURNING ID
`
searchVulnerabilityAffected = `
SELECT vulnerability_id, feature_name, affected_version, fixedin
FROM vulnerability_affected_feature
WHERE vulnerability_id = ANY($1)
SELECT vulnerability_id, feature_name, affected_version, t.name, fixedin
FROM vulnerability_affected_feature AS vaf, feature_type AS t
WHERE t.id = vaf.feature_type AND vulnerability_id = ANY($1)
`
searchVulnerabilityByID = `
@ -58,7 +58,7 @@ const (
searchVulnerabilityPotentialAffected = `
WITH req AS (
SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, v.id AS vulnerability_id
SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, vaf.feature_type AS type, v.id AS vulnerability_id
FROM vulnerability_affected_feature AS vaf,
vulnerability AS v,
namespace AS n
@ -69,6 +69,7 @@ const (
SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by
FROM feature AS f, namespaced_feature AS nf, req
WHERE f.name = req.name
AND f.type = req.type
AND nf.namespace_id = req.n_id
AND nf.feature_id = f.id`
@ -180,7 +181,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit
f database.AffectedFeature
)
err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FixedInVersion)
err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FeatureType, &f.FixedInVersion)
if err != nil {
return nil, handleError("searchVulnerabilityAffected", err)
}
@ -220,6 +221,11 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne
affectedID int64
)
types, err := tx.getFeatureTypeMap()
if err != nil {
return nil, err
}
//TODO(Sida): Change to bulk insert.
stmt, err := tx.Prepare(insertVulnerabilityAffected)
if err != nil {
@ -231,7 +237,7 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne
// affected feature row ID -> affected feature
affectedFeatures := map[int64]database.AffectedFeature{}
for _, f := range vuln.Affected {
err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, f.FixedInVersion).Scan(&affectedID)
err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, types.byName[f.FeatureType], f.FixedInVersion).Scan(&affectedID)
if err != nil {
return nil, handleError("insertVulnerabilityAffected", err)
}

View File

@ -106,6 +106,7 @@ func TestCachingVulnerable(t *testing.T) {
Name: "openssl",
Version: "1.0",
VersionFormat: dpkg.ParserName,
Type: database.SourcePackage,
},
Namespace: ns,
}
@ -120,6 +121,7 @@ func TestCachingVulnerable(t *testing.T) {
{
Namespace: ns,
FeatureName: "openssl",
FeatureType: database.SourcePackage,
AffectedVersion: "2.0",
FixedInVersion: "2.1",
},
@ -136,6 +138,7 @@ func TestCachingVulnerable(t *testing.T) {
{
Namespace: ns,
FeatureName: "openssl",
FeatureType: database.SourcePackage,
AffectedVersion: "2.1",
FixedInVersion: "2.2",
},
@ -209,12 +212,14 @@ func TestFindVulnerabilities(t *testing.T) {
Affected: []database.AffectedFeature{
{
FeatureName: "openssl",
FeatureType: database.SourcePackage,
AffectedVersion: "2.0",
FixedInVersion: "2.0",
Namespace: ns,
},
{
FeatureName: "libssl",
FeatureType: database.SourcePackage,
AffectedVersion: "1.9-abc",
FixedInVersion: "1.9-abc",
Namespace: ns,
@ -318,26 +323,3 @@ func TestFindVulnerabilityIDs(t *testing.T) {
}
}
}
func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool {
return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected)
}
func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool {
if assert.Len(t, actual, len(expected)) {
has := map[database.AffectedFeature]bool{}
for _, i := range expected {
has[i] = false
}
for _, i := range actual {
if visited, ok := has[i]; !ok {
return false
} else if visited {
return false
}
has[i] = true
}
return true
}
return false
}