pgsql: Fix postgres queries for feature_type

This commit is contained in:
Sida Chen 2019-02-19 16:42:00 -05:00
parent 5fa1ac89b9
commit 79af05e67d
12 changed files with 200 additions and 275 deletions

View File

@ -23,10 +23,11 @@ const (
findAncestryFeatures = ` findAncestryFeatures = `
SELECT namespace.name, namespace.version_format, feature.name, SELECT namespace.name, namespace.version_format, feature.name,
feature.version, feature.version_format, ancestry_layer.ancestry_index, feature.version, feature.version_format, feature_type.name, ancestry_layer.ancestry_index,
ancestry_feature.feature_detector_id, ancestry_feature.namespace_detector_id ancestry_feature.feature_detector_id, ancestry_feature.namespace_detector_id
FROM namespace, feature, namespaced_feature, ancestry_layer, ancestry_feature FROM namespace, feature, feature_type, namespaced_feature, ancestry_layer, ancestry_feature
WHERE ancestry_layer.ancestry_id = $1 WHERE ancestry_layer.ancestry_id = $1
AND feature_type.id = feature.type
AND ancestry_feature.ancestry_layer_id = ancestry_layer.id AND ancestry_feature.ancestry_layer_id = ancestry_layer.id
AND ancestry_feature.namespaced_feature_id = namespaced_feature.id AND ancestry_feature.namespaced_feature_id = namespaced_feature.id
AND namespaced_feature.feature_id = feature.id AND namespaced_feature.feature_id = feature.id
@ -256,6 +257,7 @@ func (tx *pgSession) findAncestryFeatures(ancestryID int64, detectors detectorMa
&feature.Feature.Name, &feature.Feature.Name,
&feature.Feature.Version, &feature.Feature.Version,
&feature.Feature.VersionFormat, &feature.Feature.VersionFormat,
&feature.Feature.Type,
&index, &index,
&featureDetectorID, &featureDetectorID,
&namespaceDetectorID, &namespaceDetectorID,

View File

@ -25,6 +25,7 @@ import (
"github.com/pborman/uuid" "github.com/pborman/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/ext/versionfmt" "github.com/coreos/clair/ext/versionfmt"
@ -65,19 +66,13 @@ func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database
for i := 0; i < numFeatures; i++ { for i := 0; i < numFeatures; i++ {
version := rand.Intn(numFeatures) version := rand.Intn(numFeatures)
features[i] = database.Feature{ features[i] = *database.NewSourcePackage(featureName, strconv.Itoa(version), featureVersionFormat)
Name: featureName,
VersionFormat: featureVersionFormat,
Version: strconv.Itoa(version),
}
nsFeatures[i] = database.NamespacedFeature{ nsFeatures[i] = database.NamespacedFeature{
Namespace: namespace, Namespace: namespace,
Feature: features[i], Feature: features[i],
} }
} }
// insert features
if !assert.Nil(t, tx.PersistFeatures(features)) { if !assert.Nil(t, tx.PersistFeatures(features)) {
t.FailNow() t.FailNow()
} }
@ -98,6 +93,7 @@ func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database
{ {
Namespace: namespace, Namespace: namespace,
FeatureName: featureName, FeatureName: featureName,
FeatureType: database.SourcePackage,
AffectedVersion: strconv.Itoa(version), AffectedVersion: strconv.Itoa(version),
FixedInVersion: strconv.Itoa(version), FixedInVersion: strconv.Itoa(version),
}, },
@ -117,7 +113,6 @@ func TestConcurrency(t *testing.T) {
t.FailNow() t.FailNow()
} }
defer store.Close() defer store.Close()
start := time.Now() start := time.Now()
var wg sync.WaitGroup var wg sync.WaitGroup
wg.Add(100) wg.Add(100)
@ -137,65 +132,39 @@ func TestConcurrency(t *testing.T) {
fmt.Println("total", time.Since(start)) fmt.Println("total", time.Since(start))
} }
func genRandomNamespaces(t *testing.T, count int) []database.Namespace {
r := make([]database.Namespace, count)
for i := 0; i < count; i++ {
r[i] = database.Namespace{
Name: uuid.New(),
VersionFormat: "dpkg",
}
}
return r
}
func TestCaching(t *testing.T) { func TestCaching(t *testing.T) {
store, err := openDatabaseForTest("Caching", false) store, err := openDatabaseForTest("Caching", false)
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
t.FailNow() t.FailNow()
} }
defer store.Close() defer store.Close()
nsFeatures, vulnerabilities := testGenRandomVulnerabilityAndNamespacedFeature(t, store) nsFeatures, vulnerabilities := testGenRandomVulnerabilityAndNamespacedFeature(t, store)
fmt.Printf("%d features, %d vulnerabilities are generated", len(nsFeatures), len(vulnerabilities))
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
tx, err := store.Begin()
if !assert.Nil(t, err) {
t.FailNow()
}
assert.Nil(t, tx.PersistNamespacedFeatures(nsFeatures))
fmt.Println("finished to insert namespaced features")
tx.Commit()
}()
go func() {
defer wg.Done()
tx, err := store.Begin()
if !assert.Nil(t, err) {
t.FailNow()
}
assert.Nil(t, tx.InsertVulnerabilities(vulnerabilities))
fmt.Println("finished to insert vulnerabilities")
tx.Commit()
}()
wg.Wait()
tx, err := store.Begin() tx, err := store.Begin()
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
t.FailNow() t.FailNow()
} }
require.Nil(t, tx.PersistNamespacedFeatures(nsFeatures))
if err := tx.Commit(); err != nil {
panic(err)
}
tx, err = store.Begin()
if !assert.Nil(t, err) {
t.FailNow()
}
require.Nil(t, tx.InsertVulnerabilities(vulnerabilities))
if err := tx.Commit(); err != nil {
panic(err)
}
tx, err = store.Begin()
if !assert.Nil(t, err) {
t.FailNow()
}
defer tx.Rollback() defer tx.Rollback()
// Verify consistency now.
affected, err := tx.FindAffectedNamespacedFeatures(nsFeatures) affected, err := tx.FindAffectedNamespacedFeatures(nsFeatures)
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
t.FailNow() t.FailNow()
@ -220,7 +189,7 @@ func TestCaching(t *testing.T) {
actualAffectedNames = append(actualAffectedNames, s.Name) actualAffectedNames = append(actualAffectedNames, s.Name)
} }
assert.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0) require.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0, "\nvulns: %#v\nfeature:%#v\nexpected:%#v\nactual:%#v", vulnerabilities, ansf.NamespacedFeature, expectedAffectedNames, actualAffectedNames)
assert.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0) require.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0)
} }
} }

View File

@ -46,6 +46,7 @@ const (
AND nf.feature_id = f.id AND nf.feature_id = f.id
AND nf.namespace_id = v.namespace_id AND nf.namespace_id = v.namespace_id
AND vaf.feature_name = f.name AND vaf.feature_name = f.name
AND vaf.feature_type = f.type
AND vaf.vulnerability_id = v.id AND vaf.vulnerability_id = v.id
AND v.deleted_at IS NULL` AND v.deleted_at IS NULL`
@ -68,6 +69,11 @@ func (tx *pgSession) PersistFeatures(features []database.Feature) error {
return nil return nil
} }
types, err := tx.getFeatureTypeMap()
if err != nil {
return err
}
// Sorting is needed before inserting into database to prevent deadlock. // Sorting is needed before inserting into database to prevent deadlock.
sort.Slice(features, func(i, j int) bool { sort.Slice(features, func(i, j int) bool {
return features[i].Name < features[j].Name || return features[i].Name < features[j].Name ||
@ -78,13 +84,13 @@ func (tx *pgSession) PersistFeatures(features []database.Feature) error {
// TODO(Sida): A better interface for bulk insertion is needed. // TODO(Sida): A better interface for bulk insertion is needed.
keys := make([]interface{}, 0, len(features)*3) keys := make([]interface{}, 0, len(features)*3)
for _, f := range features { for _, f := range features {
keys = append(keys, f.Name, f.Version, f.VersionFormat) keys = append(keys, f.Name, f.Version, f.VersionFormat, types.byName[f.Type])
if f.Name == "" || f.Version == "" || f.VersionFormat == "" { if f.Name == "" || f.Version == "" || f.VersionFormat == "" {
return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed") return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed")
} }
} }
_, err := tx.Exec(queryPersistFeature(len(features)), keys...) _, err = tx.Exec(queryPersistFeature(len(features)), keys...)
return handleError("queryPersistFeature", err) return handleError("queryPersistFeature", err)
} }
@ -240,55 +246,31 @@ func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFea
return nil return nil
} }
// FindAffectedNamespacedFeatures looks up cache table and retrieves all // FindAffectedNamespacedFeatures retrieves vulnerabilities associated with the
// vulnerabilities associated with the features. // feature.
func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) { func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) {
if len(features) == 0 { if len(features) == 0 {
return nil, nil return nil, nil
} }
returnFeatures := make([]database.NullableAffectedNamespacedFeature, len(features)) vulnerableFeatures := make([]database.NullableAffectedNamespacedFeature, len(features))
featureIDs, err := tx.findNamespacedFeatureIDs(features)
// featureMap is used to keep track of duplicated features.
featureMap := map[database.NamespacedFeature][]*database.NullableAffectedNamespacedFeature{}
// initialize return value and generate unique feature request queries.
for i, f := range features {
returnFeatures[i] = database.NullableAffectedNamespacedFeature{
AffectedNamespacedFeature: database.AffectedNamespacedFeature{
NamespacedFeature: f,
},
}
featureMap[f] = append(featureMap[f], &returnFeatures[i])
}
// query unique namespaced features
distinctFeatures := []database.NamespacedFeature{}
for f := range featureMap {
distinctFeatures = append(distinctFeatures, f)
}
nsFeatureIDs, err := tx.findNamespacedFeatureIDs(distinctFeatures)
if err != nil { if err != nil {
return nil, err return nil, err
} }
toQuery := []int64{} for i, id := range featureIDs {
featureIDMap := map[int64][]*database.NullableAffectedNamespacedFeature{}
for i, id := range nsFeatureIDs {
if id.Valid { if id.Valid {
toQuery = append(toQuery, id.Int64) vulnerableFeatures[i].Valid = true
for _, f := range featureMap[distinctFeatures[i]] { vulnerableFeatures[i].NamespacedFeature = features[i]
f.Valid = id.Valid
featureIDMap[id.Int64] = append(featureIDMap[id.Int64], f)
}
} }
} }
rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(toQuery)) rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(featureIDs))
if err != nil { if err != nil {
return nil, handleError("searchNamespacedFeaturesVulnerabilities", err) return nil, handleError("searchNamespacedFeaturesVulnerabilities", err)
} }
defer rows.Close() defer rows.Close()
for rows.Next() { for rows.Next() {
@ -296,6 +278,7 @@ func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.Namespac
featureID int64 featureID int64
vuln database.VulnerabilityWithFixedIn vuln database.VulnerabilityWithFixedIn
) )
err := rows.Scan(&featureID, err := rows.Scan(&featureID,
&vuln.Name, &vuln.Name,
&vuln.Description, &vuln.Description,
@ -306,16 +289,19 @@ func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.Namespac
&vuln.Namespace.Name, &vuln.Namespace.Name,
&vuln.Namespace.VersionFormat, &vuln.Namespace.VersionFormat,
) )
if err != nil { if err != nil {
return nil, handleError("searchNamespacedFeaturesVulnerabilities", err) return nil, handleError("searchNamespacedFeaturesVulnerabilities", err)
} }
for _, f := range featureIDMap[featureID] { for i, id := range featureIDs {
f.AffectedBy = append(f.AffectedBy, vuln) if id.Valid && id.Int64 == featureID {
vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy = append(vulnerableFeatures[i].AffectedNamespacedFeature.AffectedBy, vuln)
}
} }
} }
return returnFeatures, nil return vulnerableFeatures, nil
} }
func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) ([]sql.NullInt64, error) { func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) ([]sql.NullInt64, error) {
@ -323,11 +309,10 @@ func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature)
return nil, nil return nil, nil
} }
nfsMap := map[database.NamespacedFeature]sql.NullInt64{} nfsMap := map[database.NamespacedFeature]int64{}
keys := make([]interface{}, 0, len(nfs)*4) keys := make([]interface{}, 0, len(nfs)*5)
for _, nf := range nfs { for _, nf := range nfs {
keys = append(keys, nf.Name, nf.Version, nf.VersionFormat, nf.Namespace.Name) keys = append(keys, nf.Name, nf.Version, nf.VersionFormat, nf.Type, nf.Namespace.Name)
nfsMap[nf] = sql.NullInt64{}
} }
rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...) rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...)
@ -337,12 +322,12 @@ func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature)
defer rows.Close() defer rows.Close()
var ( var (
id sql.NullInt64 id int64
nf database.NamespacedFeature nf database.NamespacedFeature
) )
for rows.Next() { for rows.Next() {
err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Namespace.Name) err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Type, &nf.Namespace.Name)
nf.Namespace.VersionFormat = nf.VersionFormat nf.Namespace.VersionFormat = nf.VersionFormat
if err != nil { if err != nil {
return nil, handleError("searchNamespacedFeature", err) return nil, handleError("searchNamespacedFeature", err)
@ -352,7 +337,11 @@ func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature)
ids := make([]sql.NullInt64, len(nfs)) ids := make([]sql.NullInt64, len(nfs))
for i, nf := range nfs { for i, nf := range nfs {
ids[i] = nfsMap[nf] if id, ok := nfsMap[nf]; ok {
ids[i] = sql.NullInt64{id, true}
} else {
ids[i] = sql.NullInt64{}
}
} }
return ids, nil return ids, nil
@ -363,11 +352,17 @@ func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, err
return nil, nil return nil, nil
} }
types, err := tx.getFeatureTypeMap()
if err != nil {
return nil, err
}
fMap := map[database.Feature]sql.NullInt64{} fMap := map[database.Feature]sql.NullInt64{}
keys := make([]interface{}, 0, len(fs)*3) keys := make([]interface{}, 0, len(fs)*4)
for _, f := range fs { for _, f := range fs {
keys = append(keys, f.Name, f.Version, f.VersionFormat) typeID := types.byName[f.Type]
keys = append(keys, f.Name, f.Version, f.VersionFormat, typeID)
fMap[f] = sql.NullInt64{} fMap[f] = sql.NullInt64{}
} }
@ -382,10 +377,13 @@ func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, err
f database.Feature f database.Feature
) )
for rows.Next() { for rows.Next() {
err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat) var typeID int
err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat, &typeID)
if err != nil { if err != nil {
return nil, handleError("querySearchFeatureID", err) return nil, handleError("querySearchFeatureID", err)
} }
f.Type = types.byID[typeID]
fMap[f] = id fMap[f] = id
} }

View File

@ -18,134 +18,53 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
// register dpkg feature lister for testing
_ "github.com/coreos/clair/ext/featurefmt/dpkg"
) )
func TestPersistFeatures(t *testing.T) { func TestPersistFeatures(t *testing.T) {
datastore, tx := openSessionForTest(t, "PersistFeatures", false) tx, cleanup := createTestPgSession(t, "TestPersistFeatures")
defer closeTest(t, datastore, tx) defer cleanup()
f1 := database.Feature{} invalid := database.Feature{}
f2 := database.Feature{Name: "n", Version: "v", VersionFormat: "vf"} valid := *database.NewBinaryPackage("mount", "2.31.1-0.4ubuntu3.1", "dpkg")
// empty
assert.Nil(t, tx.PersistFeatures([]database.Feature{}))
// invalid // invalid
assert.NotNil(t, tx.PersistFeatures([]database.Feature{f1})) require.NotNil(t, tx.PersistFeatures([]database.Feature{invalid}))
// duplicated
assert.Nil(t, tx.PersistFeatures([]database.Feature{f2, f2}))
// existing // existing
assert.Nil(t, tx.PersistFeatures([]database.Feature{f2})) require.Nil(t, tx.PersistFeatures([]database.Feature{valid}))
require.Nil(t, tx.PersistFeatures([]database.Feature{valid}))
fs := listFeatures(t, tx) features := selectAllFeatures(t, tx)
assert.Len(t, fs, 1) assert.Equal(t, []database.Feature{valid}, features)
assert.Equal(t, f2, fs[0])
} }
func TestPersistNamespacedFeatures(t *testing.T) { func TestPersistNamespacedFeatures(t *testing.T) {
datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true) tx, cleanup := createTestPgSessionWithFixtures(t, "TestPersistNamespacedFeatures")
defer closeTest(t, datastore, tx) defer cleanup()
// existing features // existing features
f1 := database.Feature{ f1 := database.NewSourcePackage("ourchat", "0.5", "dpkg")
Name: "ourchat",
Version: "0.5",
VersionFormat: "dpkg",
}
// non-existing features // non-existing features
f2 := database.Feature{ f2 := database.NewSourcePackage("fake!", "", "")
Name: "fake!",
}
f3 := database.Feature{
Name: "openssl",
Version: "2.0",
VersionFormat: "dpkg",
}
// exising namespace // exising namespace
n1 := database.Namespace{ n1 := database.NewNamespace("debian:7", "dpkg")
Name: "debian:7",
VersionFormat: "dpkg",
}
n3 := database.Namespace{
Name: "debian:8",
VersionFormat: "dpkg",
}
// non-existing namespace // non-existing namespace
n2 := database.Namespace{ n2 := database.NewNamespace("debian:non", "dpkg")
Name: "debian:non",
VersionFormat: "dpkg",
}
// existing namespaced feature // existing namespaced feature
nf1 := database.NamespacedFeature{ nf1 := database.NewNamespacedFeature(n1, f1)
Namespace: n1,
Feature: f1,
}
// invalid namespaced feature // invalid namespaced feature
nf2 := database.NamespacedFeature{ nf2 := database.NewNamespacedFeature(n2, f2)
Namespace: n2,
Feature: f2,
}
// new namespaced feature affected by vulnerability
nf3 := database.NamespacedFeature{
Namespace: n3,
Feature: f3,
}
// namespaced features with namespaces or features not in the database will // namespaced features with namespaces or features not in the database will
// generate error. // generate error.
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{})) assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{}))
assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{*nf1, *nf2}))
assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf2}))
// valid case: insert nf3 // valid case: insert nf3
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf3})) assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{*nf1}))
all := listNamespacedFeatures(t, tx) all := listNamespacedFeatures(t, tx)
assert.Contains(t, all, nf1) assert.Contains(t, all, *nf1)
assert.Contains(t, all, nf3)
}
func TestVulnerableFeature(t *testing.T) {
datastore, tx := openSessionForTest(t, "VulnerableFeature", true)
defer closeTest(t, datastore, tx)
f1 := database.Feature{
Name: "openssl",
Version: "1.3",
VersionFormat: "dpkg",
}
n1 := database.Namespace{
Name: "debian:7",
VersionFormat: "dpkg",
}
nf1 := database.NamespacedFeature{
Namespace: n1,
Feature: f1,
}
assert.Nil(t, tx.PersistFeatures([]database.Feature{f1}))
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1}))
assert.Nil(t, tx.CacheAffectedNamespacedFeatures([]database.NamespacedFeature{nf1}))
// ensure the namespaced feature is affected correctly
anf, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{nf1})
if assert.Nil(t, err) &&
assert.Len(t, anf, 1) &&
assert.True(t, anf[0].Valid) &&
assert.Len(t, anf[0].AffectedBy, 1) {
assert.Equal(t, "CVE-OPENSSL-1-DEB7", anf[0].AffectedBy[0].Name)
}
} }
func TestFindAffectedNamespacedFeatures(t *testing.T) { func TestFindAffectedNamespacedFeatures(t *testing.T) {
@ -156,6 +75,7 @@ func TestFindAffectedNamespacedFeatures(t *testing.T) {
Name: "openssl", Name: "openssl",
Version: "1.0", Version: "1.0",
VersionFormat: "dpkg", VersionFormat: "dpkg",
Type: database.SourcePackage,
}, },
Namespace: database.Namespace{ Namespace: database.Namespace{
Name: "debian:7", Name: "debian:7",
@ -173,30 +93,41 @@ func TestFindAffectedNamespacedFeatures(t *testing.T) {
} }
func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFeature { func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFeature {
rows, err := tx.Query(`SELECT f.name, f.version, f.version_format, n.name, n.version_format types, err := tx.getFeatureTypeMap()
if err != nil {
panic(err)
}
rows, err := tx.Query(`SELECT f.name, f.version, f.version_format, f.type, n.name, n.version_format
FROM feature AS f, namespace AS n, namespaced_feature AS nf FROM feature AS f, namespace AS n, namespaced_feature AS nf
WHERE nf.feature_id = f.id AND nf.namespace_id = n.id`) WHERE nf.feature_id = f.id AND nf.namespace_id = n.id`)
if err != nil { if err != nil {
t.Error(err) panic(err)
t.FailNow()
} }
nf := []database.NamespacedFeature{} nf := []database.NamespacedFeature{}
for rows.Next() { for rows.Next() {
f := database.NamespacedFeature{} f := database.NamespacedFeature{}
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &f.Namespace.Name, &f.Namespace.VersionFormat) var typeID int
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &typeID, &f.Namespace.Name, &f.Namespace.VersionFormat)
if err != nil { if err != nil {
t.Error(err) panic(err)
t.FailNow()
} }
f.Type = types.byID[typeID]
nf = append(nf, f) nf = append(nf, f)
} }
return nf return nf
} }
func listFeatures(t *testing.T, tx *pgSession) []database.Feature { func selectAllFeatures(t *testing.T, tx *pgSession) []database.Feature {
rows, err := tx.Query("SELECT name, version, version_format FROM feature") types, err := tx.getFeatureTypeMap()
if err != nil {
panic(err)
}
rows, err := tx.Query("SELECT name, version, version_format, type FROM feature")
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
@ -204,7 +135,9 @@ func listFeatures(t *testing.T, tx *pgSession) []database.Feature {
fs := []database.Feature{} fs := []database.Feature{}
for rows.Next() { for rows.Next() {
f := database.Feature{} f := database.Feature{}
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat) var typeID int
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &typeID)
f.Type = types.byID[typeID]
if err != nil { if err != nil {
t.FailNow() t.FailNow()
} }
@ -233,3 +166,33 @@ func assertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFe
} }
return false return false
} }
func TestFindNamespacedFeatureIDs(t *testing.T) {
tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNamespacedFeatureIDs")
defer cleanup()
features := []database.NamespacedFeature{}
expectedIDs := []int{}
for id, feature := range realNamespacedFeatures {
features = append(features, feature)
expectedIDs = append(expectedIDs, id)
}
features = append(features, realNamespacedFeatures[1]) // test duplicated
expectedIDs = append(expectedIDs, 1)
namespace := realNamespaces[1]
features = append(features, *database.NewNamespacedFeature(&namespace, database.NewBinaryPackage("not-found", "1.0", "dpkg"))) // test not found feature
ids, err := tx.findNamespacedFeatureIDs(features)
require.Nil(t, err)
require.Len(t, ids, len(expectedIDs)+1)
for i, id := range ids {
if i == len(ids)-1 {
require.False(t, id.Valid)
} else {
require.True(t, id.Valid)
require.Equal(t, expectedIDs[i], int(id.Int64))
}
}
}

View File

@ -37,9 +37,10 @@ const (
SELECT id FROM layer WHERE hash = $1` SELECT id FROM layer WHERE hash = $1`
findLayerFeatures = ` findLayerFeatures = `
SELECT f.name, f.version, f.version_format, lf.detector_id SELECT f.name, f.version, f.version_format, t.name, lf.detector_id
FROM layer_feature AS lf, feature AS f FROM layer_feature AS lf, feature AS f, feature_type AS t
WHERE lf.feature_id = f.id WHERE lf.feature_id = f.id
AND t.id = f.type
AND lf.layer_id = $1` AND lf.layer_id = $1`
findLayerNamespaces = ` findLayerNamespaces = `
@ -307,7 +308,7 @@ func (tx *pgSession) findLayerFeatures(layerID int64, detectors detectorMap) ([]
detectorID int64 detectorID int64
feature database.LayerFeature feature database.LayerFeature
) )
if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &detectorID); err != nil { if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &feature.Type, &detectorID); err != nil {
return nil, handleError("findLayerFeatures", err) return nil, handleError("findLayerFeatures", err)
} }

View File

@ -43,12 +43,12 @@ var persistLayerTests = []struct {
features: []database.LayerFeature{ features: []database.LayerFeature{
{realFeatures[1], realDetectors[1]}, {realFeatures[1], realDetectors[1]},
}, },
err: "database: parameters are not valid", err: "parameters are not valid",
}, },
{ {
title: "layer with non-existing feature", title: "layer with non-existing feature",
name: "random-forest", name: "random-forest",
err: "database: associated immutable entities are missing in the database", err: "associated immutable entities are missing in the database",
by: []database.Detector{realDetectors[2]}, by: []database.Detector{realDetectors[2]},
features: []database.LayerFeature{ features: []database.LayerFeature{
{fakeFeatures[1], realDetectors[2]}, {fakeFeatures[1], realDetectors[2]},
@ -57,7 +57,7 @@ var persistLayerTests = []struct {
{ {
title: "layer with non-existing namespace", title: "layer with non-existing namespace",
name: "random-forest2", name: "random-forest2",
err: "database: associated immutable entities are missing in the database", err: "associated immutable entities are missing in the database",
by: []database.Detector{realDetectors[1]}, by: []database.Detector{realDetectors[1]},
namespaces: []database.LayerNamespace{ namespaces: []database.LayerNamespace{
{fakeNamespaces[1], realDetectors[1]}, {fakeNamespaces[1], realDetectors[1]},
@ -66,7 +66,7 @@ var persistLayerTests = []struct {
{ {
title: "layer with non-existing detector", title: "layer with non-existing detector",
name: "random-forest3", name: "random-forest3",
err: "database: associated immutable entities are missing in the database", err: "associated immutable entities are missing in the database",
by: []database.Detector{fakeDetector[1]}, by: []database.Detector{fakeDetector[1]},
}, },
{ {

View File

@ -211,8 +211,8 @@ func TestInsertVulnerabilityNotifications(t *testing.T) {
} }
func TestFindNewNotification(t *testing.T) { func TestFindNewNotification(t *testing.T) {
datastore, tx := openSessionForTest(t, "FindNewNotification", true) tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNewNotification")
defer closeTest(t, datastore, tx) defer cleanup()
noti, ok, err := tx.FindNewNotification(time.Now()) noti, ok, err := tx.FindNewNotification(time.Now())
if assert.Nil(t, err) && assert.True(t, ok) { if assert.Nil(t, err) && assert.True(t, ok) {
@ -229,7 +229,7 @@ func TestFindNewNotification(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
assert.False(t, ok) assert.False(t, ok)
// can find the notified after a period of time // can find the notified after a period of time
noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000))) noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(10 * time.Second)))
if assert.Nil(t, err) && assert.True(t, ok) { if assert.Nil(t, err) && assert.True(t, ok) {
assert.Equal(t, "test", noti.Name) assert.Equal(t, "test", noti.Name)
assert.NotEqual(t, time.Time{}, noti.Notified) assert.NotEqual(t, time.Time{}, noti.Notified)

View File

@ -109,6 +109,7 @@ func dropTemplateDatabase(url string, name string) {
} }
} }
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
fURL, fName := genTemplateDatabase("fixture", true) fURL, fName := genTemplateDatabase("fixture", true)
nfURL, nfName := genTemplateDatabase("nonfixture", false) nfURL, nfName := genTemplateDatabase("nonfixture", false)

View File

@ -52,21 +52,22 @@ func querySearchNotDeletedVulnerabilityID(count int) string {
func querySearchFeatureID(featureCount int) string { func querySearchFeatureID(featureCount int) string {
return fmt.Sprintf(` return fmt.Sprintf(`
SELECT id, name, version, version_format SELECT id, name, version, version_format, type
FROM Feature WHERE (name, version, version_format) IN (%s)`, FROM Feature WHERE (name, version, version_format, type) IN (%s)`,
queryString(3, featureCount), queryString(4, featureCount),
) )
} }
func querySearchNamespacedFeature(nsfCount int) string { func querySearchNamespacedFeature(nsfCount int) string {
return fmt.Sprintf(` return fmt.Sprintf(`
SELECT nf.id, f.name, f.version, f.version_format, n.name SELECT nf.id, f.name, f.version, f.version_format, t.name, n.name
FROM namespaced_feature AS nf, feature AS f, namespace AS n FROM namespaced_feature AS nf, feature AS f, namespace AS n, feature_type AS t
WHERE nf.feature_id = f.id WHERE nf.feature_id = f.id
AND nf.namespace_id = n.id AND nf.namespace_id = n.id
AND n.version_format = f.version_format AND n.version_format = f.version_format
AND (f.name, f.version, f.version_format, n.name) IN (%s)`, AND f.type = t.id
queryString(4, nsfCount), AND (f.name, f.version, f.version_format, t.name, n.name) IN (%s)`,
queryString(5, nsfCount),
) )
} }
@ -110,10 +111,11 @@ func queryInsertNotifications(count int) string {
func queryPersistFeature(count int) string { func queryPersistFeature(count int) string {
return queryPersist(count, return queryPersist(count,
"feature", "feature",
"feature_name_version_version_format_key", "feature_name_version_version_format_type_key",
"name", "name",
"version", "version",
"version_format") "version_format",
"type")
} }
func queryPersistLayerFeature(count int) string { func queryPersistLayerFeature(count int) string {

View File

@ -4,11 +4,12 @@ INSERT INTO namespace (id, name, version_format) VALUES
(2, 'debian:8', 'dpkg'), (2, 'debian:8', 'dpkg'),
(3, 'fake:1.0', 'rpm'); (3, 'fake:1.0', 'rpm');
INSERT INTO feature (id, name, version, version_format) VALUES INSERT INTO feature (id, name, version, version_format, type) VALUES
(1, 'ourchat', '0.5', 'dpkg'), (1, 'ourchat', '0.5', 'dpkg', 1),
(2, 'openssl', '1.0', 'dpkg'), (2, 'openssl', '1.0', 'dpkg', 1),
(3, 'openssl', '2.0', 'dpkg'), (3, 'openssl', '2.0', 'dpkg', 1),
(4, 'fake', '2.0', 'rpm'); (4, 'fake', '2.0', 'rpm', 1),
(5, 'mount', '2.31.1-0.4ubuntu3.1', 'dpkg', 2);
INSERT INTO namespaced_feature(id, feature_id, namespace_id) VALUES INSERT INTO namespaced_feature(id, feature_id, namespace_id) VALUES
(1, 1, 1), -- ourchat 0.5, debian:7 (1, 1, 1), -- ourchat 0.5, debian:7
@ -112,9 +113,9 @@ INSERT INTO vulnerability (id, namespace_id, name, description, link, severity)
INSERT INTO vulnerability (id, namespace_id, name, description, link, severity, deleted_at) VALUES INSERT INTO vulnerability (id, namespace_id, name, description, link, severity, deleted_at) VALUES
(3, 1, 'CVE-DELETED', '', '', 'Unknown', '2017-08-08 17:49:31.668483'); (3, 1, 'CVE-DELETED', '', '', 'Unknown', '2017-08-08 17:49:31.668483');
INSERT INTO vulnerability_affected_feature(id, vulnerability_id, feature_name, affected_version, fixedin) VALUES INSERT INTO vulnerability_affected_feature(id, vulnerability_id, feature_name, affected_version, fixedin, feature_type) VALUES
(1, 1, 'openssl', '2.0', '2.0'), (1, 1, 'openssl', '2.0', '2.0', 1),
(2, 1, 'libssl', '1.9-abc', '1.9-abc'); (2, 1, 'libssl', '1.9-abc', '1.9-abc', 1);
INSERT INTO vulnerability_affected_namespaced_feature(id, vulnerability_id, namespaced_feature_id, added_by) VALUES INSERT INTO vulnerability_affected_namespaced_feature(id, vulnerability_id, namespaced_feature_id, added_by) VALUES
(1, 1, 2, 1); (1, 1, 2, 1);

View File

@ -39,15 +39,15 @@ const (
` `
insertVulnerabilityAffected = ` insertVulnerabilityAffected = `
INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, fixedin) INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, feature_type, fixedin)
VALUES ($1, $2, $3, $4) VALUES ($1, $2, $3, $4, $5)
RETURNING ID RETURNING ID
` `
searchVulnerabilityAffected = ` searchVulnerabilityAffected = `
SELECT vulnerability_id, feature_name, affected_version, fixedin SELECT vulnerability_id, feature_name, affected_version, t.name, fixedin
FROM vulnerability_affected_feature FROM vulnerability_affected_feature AS vaf, feature_type AS t
WHERE vulnerability_id = ANY($1) WHERE t.id = vaf.feature_type AND vulnerability_id = ANY($1)
` `
searchVulnerabilityByID = ` searchVulnerabilityByID = `
@ -58,7 +58,7 @@ const (
searchVulnerabilityPotentialAffected = ` searchVulnerabilityPotentialAffected = `
WITH req AS ( WITH req AS (
SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, v.id AS vulnerability_id SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, vaf.feature_type AS type, v.id AS vulnerability_id
FROM vulnerability_affected_feature AS vaf, FROM vulnerability_affected_feature AS vaf,
vulnerability AS v, vulnerability AS v,
namespace AS n namespace AS n
@ -69,6 +69,7 @@ const (
SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by
FROM feature AS f, namespaced_feature AS nf, req FROM feature AS f, namespaced_feature AS nf, req
WHERE f.name = req.name WHERE f.name = req.name
AND f.type = req.type
AND nf.namespace_id = req.n_id AND nf.namespace_id = req.n_id
AND nf.feature_id = f.id` AND nf.feature_id = f.id`
@ -180,7 +181,7 @@ func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.Vulnerabilit
f database.AffectedFeature f database.AffectedFeature
) )
err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FixedInVersion) err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FeatureType, &f.FixedInVersion)
if err != nil { if err != nil {
return nil, handleError("searchVulnerabilityAffected", err) return nil, handleError("searchVulnerabilityAffected", err)
} }
@ -220,6 +221,11 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne
affectedID int64 affectedID int64
) )
types, err := tx.getFeatureTypeMap()
if err != nil {
return nil, err
}
//TODO(Sida): Change to bulk insert. //TODO(Sida): Change to bulk insert.
stmt, err := tx.Prepare(insertVulnerabilityAffected) stmt, err := tx.Prepare(insertVulnerabilityAffected)
if err != nil { if err != nil {
@ -231,7 +237,7 @@ func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulne
// affected feature row ID -> affected feature // affected feature row ID -> affected feature
affectedFeatures := map[int64]database.AffectedFeature{} affectedFeatures := map[int64]database.AffectedFeature{}
for _, f := range vuln.Affected { for _, f := range vuln.Affected {
err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, f.FixedInVersion).Scan(&affectedID) err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, types.byName[f.FeatureType], f.FixedInVersion).Scan(&affectedID)
if err != nil { if err != nil {
return nil, handleError("insertVulnerabilityAffected", err) return nil, handleError("insertVulnerabilityAffected", err)
} }

View File

@ -106,6 +106,7 @@ func TestCachingVulnerable(t *testing.T) {
Name: "openssl", Name: "openssl",
Version: "1.0", Version: "1.0",
VersionFormat: dpkg.ParserName, VersionFormat: dpkg.ParserName,
Type: database.SourcePackage,
}, },
Namespace: ns, Namespace: ns,
} }
@ -120,6 +121,7 @@ func TestCachingVulnerable(t *testing.T) {
{ {
Namespace: ns, Namespace: ns,
FeatureName: "openssl", FeatureName: "openssl",
FeatureType: database.SourcePackage,
AffectedVersion: "2.0", AffectedVersion: "2.0",
FixedInVersion: "2.1", FixedInVersion: "2.1",
}, },
@ -136,6 +138,7 @@ func TestCachingVulnerable(t *testing.T) {
{ {
Namespace: ns, Namespace: ns,
FeatureName: "openssl", FeatureName: "openssl",
FeatureType: database.SourcePackage,
AffectedVersion: "2.1", AffectedVersion: "2.1",
FixedInVersion: "2.2", FixedInVersion: "2.2",
}, },
@ -209,12 +212,14 @@ func TestFindVulnerabilities(t *testing.T) {
Affected: []database.AffectedFeature{ Affected: []database.AffectedFeature{
{ {
FeatureName: "openssl", FeatureName: "openssl",
FeatureType: database.SourcePackage,
AffectedVersion: "2.0", AffectedVersion: "2.0",
FixedInVersion: "2.0", FixedInVersion: "2.0",
Namespace: ns, Namespace: ns,
}, },
{ {
FeatureName: "libssl", FeatureName: "libssl",
FeatureType: database.SourcePackage,
AffectedVersion: "1.9-abc", AffectedVersion: "1.9-abc",
FixedInVersion: "1.9-abc", FixedInVersion: "1.9-abc",
Namespace: ns, Namespace: ns,
@ -318,26 +323,3 @@ func TestFindVulnerabilityIDs(t *testing.T) {
} }
} }
} }
func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool {
return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected)
}
func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool {
if assert.Len(t, actual, len(expected)) {
has := map[database.AffectedFeature]bool{}
for _, i := range expected {
has[i] = false
}
for _, i := range actual {
if visited, ok := has[i]; !ok {
return false
} else if visited {
return false
}
has[i] = true
}
return true
}
return false
}