pgsql: Implement database queries for detector relationship

* Refactor layer and ancestry
* Add tests
* Fix bugs introduced when the queries were moved
Sida Chen 2018-10-08 11:11:30 -04:00
parent 028324014b
commit 0c1b80b2ed
21 changed files with 1806 additions and 907 deletions

database/pgsql/ancestry.go

@@ -14,15 +14,17 @@ const (
 	insertAncestry = `
 		INSERT INTO ancestry (name) VALUES ($1) RETURNING id`

-	searchAncestryLayer = `
-		SELECT layer.hash, layer.id, ancestry_layer.ancestry_index
+	findAncestryLayerHashes = `
+		SELECT layer.hash, ancestry_layer.ancestry_index
 		FROM layer, ancestry_layer
 		WHERE ancestry_layer.ancestry_id = $1
 			AND ancestry_layer.layer_id = layer.id
 		ORDER BY ancestry_layer.ancestry_index ASC`

-	searchAncestryFeatures = `
-		SELECT namespace.name, namespace.version_format, feature.name, feature.version, feature.version_format, ancestry_layer.ancestry_index
+	findAncestryFeatures = `
+		SELECT namespace.name, namespace.version_format, feature.name,
+			feature.version, feature.version_format, ancestry_layer.ancestry_index,
+			ancestry_feature.feature_detector_id, ancestry_feature.namespace_detector_id
 		FROM namespace, feature, namespaced_feature, ancestry_layer, ancestry_feature
 		WHERE ancestry_layer.ancestry_id = $1
 			AND ancestry_feature.ancestry_layer_id = ancestry_layer.id
@@ -30,88 +32,17 @@ const (
 			AND namespaced_feature.feature_id = feature.id
 			AND namespaced_feature.namespace_id = namespace.id`

-	searchAncestry = `SELECT id FROM ancestry WHERE name = $1`
+	findAncestryID = `SELECT id FROM ancestry WHERE name = $1`
 	removeAncestry = `DELETE FROM ancestry WHERE name = $1`

-	insertAncestryLayer = `
-		INSERT INTO ancestry_layer (ancestry_id, ancestry_index, layer_id) VALUES
-			($1, $2, (SELECT layer.id FROM layer WHERE hash = $3 LIMIT 1))
+	insertAncestryLayers = `
+		INSERT INTO ancestry_layer (ancestry_id, ancestry_index, layer_id) VALUES ($1, $2, $3)
 		RETURNING id`

-	insertAncestryLayerFeature = `
+	insertAncestryFeatures = `
 		INSERT INTO ancestry_feature
 			(ancestry_layer_id, namespaced_feature_id, feature_detector_id, namespace_detector_id) VALUES
 			($1, $2, $3, $4)`
 )

-type ancestryLayerWithID struct {
-	database.AncestryLayer
-
-	layerID int64
-}
-
-func (tx *pgSession) UpsertAncestry(ancestry database.Ancestry) error {
-	if ancestry.Name == "" {
-		log.Error("Empty ancestry name is not allowed")
-		return commonerr.NewBadRequestError("could not insert an ancestry with empty name")
-	}
-
-	if len(ancestry.Layers) == 0 {
-		log.Error("Empty ancestry is not allowed")
-		return commonerr.NewBadRequestError("could not insert an ancestry with 0 layers")
-	}
-
-	if err := tx.deleteAncestry(ancestry.Name); err != nil {
-		return err
-	}
-
-	var ancestryID int64
-	if err := tx.QueryRow(insertAncestry, ancestry.Name).Scan(&ancestryID); err != nil {
-		if isErrUniqueViolation(err) {
-			return handleError("insertAncestry", errors.New("other Go-routine is processing this ancestry (skip)"))
-		}
-		return handleError("insertAncestry", err)
-	}
-
-	if err := tx.insertAncestryLayers(ancestryID, ancestry.Layers); err != nil {
-		return err
-	}
-
-	return tx.persistProcessors(persistAncestryLister,
-		"persistAncestryLister",
-		persistAncestryDetector,
-		"persistAncestryDetector",
-		ancestryID, ancestry.ProcessedBy)
-}
-
-func (tx *pgSession) findAncestryID(name string) (int64, bool, error) {
-	var id sql.NullInt64
-	if err := tx.QueryRow(searchAncestry, name).Scan(&id); err != nil {
-		if err == sql.ErrNoRows {
-			return 0, false, nil
-		}
-
-		return 0, false, handleError("searchAncestry", err)
-	}
-
-	return id.Int64, true, nil
-}
-
-func (tx *pgSession) findAncestryProcessors(id int64) (database.Processors, error) {
-	var (
-		processors database.Processors
-		err        error
-	)
-
-	if processors.Detectors, err = tx.findProcessors(searchAncestryDetectors, id); err != nil {
-		return processors, handleError("searchAncestryDetectors", err)
-	}
-
-	if processors.Listers, err = tx.findProcessors(searchAncestryListers, id); err != nil {
-		return processors, handleError("searchAncestryListers", err)
-	}
-
-	return processors, err
-}
-
 func (tx *pgSession) FindAncestry(name string) (database.Ancestry, bool, error) {
 	var (
 		ancestry = database.Ancestry{Name: name}
@@ -123,7 +54,7 @@ func (tx *pgSession) FindAncestry(name string) (database.Ancestry, bool, error)
 		return ancestry, ok, err
 	}

-	if ancestry.ProcessedBy, err = tx.findAncestryProcessors(id); err != nil {
+	if ancestry.By, err = tx.findAncestryDetectors(id); err != nil {
 		return ancestry, false, err
 	}
@@ -134,99 +65,187 @@ func (tx *pgSession) FindAncestry(name string) (database.Ancestry, bool, error)
 	return ancestry, true, nil
 }

-func (tx *pgSession) deleteAncestry(name string) error {
-	result, err := tx.Exec(removeAncestry, name)
-	if err != nil {
-		return handleError("removeAncestry", err)
-	}
-
-	_, err = result.RowsAffected()
-	if err != nil {
-		return handleError("removeAncestry", err)
-	}
+func (tx *pgSession) UpsertAncestry(ancestry database.Ancestry) error {
+	if !ancestry.Valid() {
+		return database.ErrInvalidParameters
+	}
+
+	if err := tx.removeAncestry(ancestry.Name); err != nil {
+		return err
+	}
+
+	id, err := tx.insertAncestry(ancestry.Name)
+	if err != nil {
+		return err
+	}
+
+	detectorIDs, err := tx.findDetectorIDs(ancestry.By)
+	if err != nil {
+		return err
+	}
+
+	// insert ancestry metadata
+	if err := tx.insertAncestryDetectors(id, detectorIDs); err != nil {
+		return err
+	}
+
+	layers := make([]string, 0, len(ancestry.Layers))
+	for _, layer := range ancestry.Layers {
+		layers = append(layers, layer.Hash)
+	}
+
+	layerIDs, ok, err := tx.findLayerIDs(layers)
+	if err != nil {
+		return err
+	}
+
+	if !ok {
+		log.Error("layer cannot be found, this indicates that the internal logic of calling UpsertAncestry is wrong or the database is corrupted.")
+		return database.ErrMissingEntities
+	}
+
+	ancestryLayerIDs, err := tx.insertAncestryLayers(id, layerIDs)
+	if err != nil {
+		return err
+	}
+
+	for i, id := range ancestryLayerIDs {
+		if err := tx.insertAncestryFeatures(id, ancestry.Layers[i]); err != nil {
+			return err
+		}
+	}

 	return nil
 }

-func (tx *pgSession) findProcessors(query string, id int64) ([]string, error) {
-	var (
-		processors []string
-		processor  string
-	)
-
-	rows, err := tx.Query(query, id)
-	if err != nil {
-		if err == sql.ErrNoRows {
-			return nil, nil
-		}
-
-		return nil, err
-	}
-
-	for rows.Next() {
-		if err := rows.Scan(&processor); err != nil {
-			return nil, err
-		}
-
-		processors = append(processors, processor)
-	}
-
-	return processors, nil
-}
+func (tx *pgSession) insertAncestry(name string) (int64, error) {
+	var id int64
+	err := tx.QueryRow(insertAncestry, name).Scan(&id)
+	if err != nil {
+		if isErrUniqueViolation(err) {
+			return 0, handleError("insertAncestry", errors.New("other Go-routine is processing this ancestry (skip)"))
+		}
+
+		return 0, handleError("insertAncestry", err)
+	}
+
+	log.WithFields(log.Fields{"ancestry": name, "id": id}).Debug("database: inserted ancestry")
+	return id, nil
+}
+
+func (tx *pgSession) findAncestryID(name string) (int64, bool, error) {
+	var id sql.NullInt64
+	if err := tx.QueryRow(findAncestryID, name).Scan(&id); err != nil {
+		if err == sql.ErrNoRows {
+			return 0, false, nil
+		}
+
+		return 0, false, handleError("findAncestryID", err)
+	}
+
+	return id.Int64, true, nil
+}
+
+func (tx *pgSession) removeAncestry(name string) error {
+	result, err := tx.Exec(removeAncestry, name)
+	if err != nil {
+		return handleError("removeAncestry", err)
+	}
+
+	affected, err := result.RowsAffected()
+	if err != nil {
+		return handleError("removeAncestry", err)
+	}
+
+	if affected != 0 {
+		log.WithField("ancestry", name).Debug("removed ancestry")
+	}
+
+	return nil
+}

 func (tx *pgSession) findAncestryLayers(id int64) ([]database.AncestryLayer, error) {
-	var (
-		err  error
-		rows *sql.Rows
-		// layer index -> Ancestry Layer + Layer ID
-		layers = map[int64]ancestryLayerWithID{}
-		// layer index -> layer-wise features
-		features       = map[int64][]database.NamespacedFeature{}
-		ancestryLayers []database.AncestryLayer
-	)
-
-	// retrieve ancestry layer metadata
-	if rows, err = tx.Query(searchAncestryLayer, id); err != nil {
-		return nil, handleError("searchAncestryLayer", err)
-	}
-
-	for rows.Next() {
-		var (
-			layer database.AncestryLayer
-			index sql.NullInt64
-			id    sql.NullInt64
-		)
-
-		if err = rows.Scan(&layer.Hash, &id, &index); err != nil {
-			return nil, handleError("searchAncestryLayer", err)
-		}
-
-		if !index.Valid || !id.Valid {
-			panic("null ancestry ID or ancestry index violates database constraints")
-		}
-
-		if _, ok := layers[index.Int64]; ok {
-			// one ancestry index should correspond to only one layer
-			return nil, database.ErrInconsistent
-		}
-
-		layers[index.Int64] = ancestryLayerWithID{layer, id.Int64}
-	}
-
-	for _, layer := range layers {
-		if layer.ProcessedBy, err = tx.findLayerProcessors(layer.layerID); err != nil {
-			return nil, err
-		}
-	}
-
+	detectors, err := tx.findAllDetectors()
+	if err != nil {
+		return nil, err
+	}
+
+	layerMap, err := tx.findAncestryLayerHashes(id)
+	if err != nil {
+		return nil, err
+	}
+
+	log.WithField("map", layerMap).Debug("found layer hashes")
+	featureMap, err := tx.findAncestryFeatures(id, detectors)
+	if err != nil {
+		return nil, err
+	}
+
+	layers := make([]database.AncestryLayer, len(layerMap))
+	for index, layer := range layerMap {
+		// index MUST match the ancestry layer slice index.
+		if layers[index].Hash == "" && len(layers[index].Features) == 0 {
+			layers[index] = database.AncestryLayer{
+				Hash:     layer,
+				Features: featureMap[index],
+			}
+		} else {
+			log.WithFields(log.Fields{
+				"ancestry ID":               id,
+				"duplicated ancestry index": index,
+			}).WithError(database.ErrInconsistent).Error("ancestry layers with same ancestry_index is not allowed")
+			return nil, database.ErrInconsistent
+		}
+	}
+
+	return layers, nil
+}
+
+func (tx *pgSession) findAncestryLayerHashes(ancestryID int64) (map[int64]string, error) {
+	// retrieve layer indexes and hashes
+	rows, err := tx.Query(findAncestryLayerHashes, ancestryID)
+	if err != nil {
+		return nil, handleError("findAncestryLayerHashes", err)
+	}
+
+	layerHashes := map[int64]string{}
+	for rows.Next() {
+		var (
+			hash  string
+			index int64
+		)
+
+		if err = rows.Scan(&hash, &index); err != nil {
+			return nil, handleError("findAncestryLayerHashes", err)
+		}
+
+		if _, ok := layerHashes[index]; ok {
+			// one ancestry index should correspond to only one layer
+			return nil, database.ErrInconsistent
+		}
+
+		layerHashes[index] = hash
+	}
+
+	return layerHashes, nil
+}
+
+func (tx *pgSession) findAncestryFeatures(ancestryID int64, detectors detectorMap) (map[int64][]database.AncestryFeature, error) {
+	// ancestry_index -> ancestry features
+	featureMap := make(map[int64][]database.AncestryFeature)
 	// retrieve ancestry layer's namespaced features
-	if rows, err = tx.Query(searchAncestryFeatures, id); err != nil {
-		return nil, handleError("searchAncestryFeatures", err)
-	}
-
+	rows, err := tx.Query(findAncestryFeatures, ancestryID)
+	if err != nil {
+		return nil, handleError("findAncestryFeatures", err)
+	}
+
+	defer rows.Close()
+
 	for rows.Next() {
 		var (
-			feature database.NamespacedFeature
+			featureDetectorID   int64
+			namespaceDetectorID int64
+			feature             database.NamespacedFeature
 			// index is used to determine which layer the feature belongs to.
 			index sql.NullInt64
 		)
@@ -238,8 +257,10 @@ func (tx *pgSession) findAncestryLayers(id int64) ([]database.AncestryLayer, err
 			&feature.Feature.Version,
 			&feature.Feature.VersionFormat,
 			&index,
+			&featureDetectorID,
+			&namespaceDetectorID,
 		); err != nil {
-			return nil, handleError("searchAncestryFeatures", err)
+			return nil, handleError("findAncestryFeatures", err)
 		}

 		if feature.Feature.VersionFormat != feature.Namespace.VersionFormat {
@@ -248,59 +269,88 @@ func (tx *pgSession) findAncestryLayers(id int64) ([]database.AncestryLayer, err
 			return nil, database.ErrInconsistent
 		}

-		features[index.Int64] = append(features[index.Int64], feature)
+		fDetector, ok := detectors.byID[featureDetectorID]
+		if !ok {
+			return nil, database.ErrInconsistent
+		}
+
+		nsDetector, ok := detectors.byID[namespaceDetectorID]
+		if !ok {
+			return nil, database.ErrInconsistent
+		}
+
+		featureMap[index.Int64] = append(featureMap[index.Int64], database.AncestryFeature{
+			NamespacedFeature: feature,
+			FeatureBy:         fDetector,
+			NamespaceBy:       nsDetector,
+		})
 	}

-	for index, layer := range layers {
-		layer.DetectedFeatures = features[index]
-		ancestryLayers = append(ancestryLayers, layer.AncestryLayer)
-	}
-
-	return ancestryLayers, nil
+	return featureMap, nil
 }

 // insertAncestryLayers inserts the ancestry layers along with its content into
 // the database. The layers are 0 based indexed in the original order.
-func (tx *pgSession) insertAncestryLayers(ancestryID int64, layers []database.AncestryLayer) error {
-	//TODO(Sida): use bulk insert.
-	stmt, err := tx.Prepare(insertAncestryLayer)
+func (tx *pgSession) insertAncestryLayers(ancestryID int64, layers []int64) ([]int64, error) {
+	stmt, err := tx.Prepare(insertAncestryLayers)
 	if err != nil {
-		return handleError("insertAncestryLayer", err)
+		return nil, handleError("insertAncestryLayers", err)
 	}

-	ancestryLayerIDs := []sql.NullInt64{}
-	for index, layer := range layers {
+	ancestryLayerIDs := []int64{}
+	for index, layerID := range layers {
 		var ancestryLayerID sql.NullInt64
-		if err := stmt.QueryRow(ancestryID, index, layer.Hash).Scan(&ancestryLayerID); err != nil {
-			return handleError("insertAncestryLayer", commonerr.CombineErrors(err, stmt.Close()))
+		if err := stmt.QueryRow(ancestryID, index, layerID).Scan(&ancestryLayerID); err != nil {
+			return nil, handleError("insertAncestryLayers", commonerr.CombineErrors(err, stmt.Close()))
 		}

-		ancestryLayerIDs = append(ancestryLayerIDs, ancestryLayerID)
+		if !ancestryLayerID.Valid {
+			return nil, database.ErrInconsistent
+		}
+
+		ancestryLayerIDs = append(ancestryLayerIDs, ancestryLayerID.Int64)
 	}

 	if err := stmt.Close(); err != nil {
-		return handleError("Failed to close insertAncestryLayer statement", err)
+		return nil, handleError("insertAncestryLayers", err)
 	}

-	stmt, err = tx.Prepare(insertAncestryLayerFeature)
-	defer stmt.Close()
-
-	for i, layer := range layers {
-		var (
-			nsFeatureIDs []sql.NullInt64
-			layerID      = ancestryLayerIDs[i]
-		)
-
-		if nsFeatureIDs, err = tx.findNamespacedFeatureIDs(layer.DetectedFeatures); err != nil {
-			return err
-		}
-
-		for _, id := range nsFeatureIDs {
-			if _, err := stmt.Exec(layerID, id); err != nil {
-				return handleError("insertAncestryLayerFeature", commonerr.CombineErrors(err, stmt.Close()))
-			}
-		}
-	}
-
-	return nil
-}
+	return ancestryLayerIDs, nil
+}
+
+func (tx *pgSession) insertAncestryFeatures(ancestryLayerID int64, layer database.AncestryLayer) error {
+	detectors, err := tx.findAllDetectors()
+	if err != nil {
+		return err
+	}
+
+	nsFeatureIDs, err := tx.findNamespacedFeatureIDs(layer.GetFeatures())
+	if err != nil {
+		return err
+	}
+
+	// find the detectors for each feature
+	stmt, err := tx.Prepare(insertAncestryFeatures)
+	if err != nil {
+		return handleError("insertAncestryFeatures", err)
+	}
+
+	defer stmt.Close()
+
+	for index, id := range nsFeatureIDs {
+		namespaceDetectorID, ok := detectors.byValue[layer.Features[index].NamespaceBy]
+		if !ok {
+			return database.ErrMissingEntities
+		}
+
+		featureDetectorID, ok := detectors.byValue[layer.Features[index].FeatureBy]
+		if !ok {
+			return database.ErrMissingEntities
+		}
+
+		if _, err := stmt.Exec(ancestryLayerID, id, featureDetectorID, namespaceDetectorID); err != nil {
+			return handleError("insertAncestryFeatures", commonerr.CombineErrors(err, stmt.Close()))
+		}
+	}
+
+	return nil
+}
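For reference, a sketch (not from this commit) of how a caller drives the reworked ancestry API, using only types and functions shown above; the names and values are illustrative placeholders.

// upsertAndReload is a hypothetical caller. Every layer, feature and
// detector referenced here must already be persisted, otherwise
// UpsertAncestry fails with database.ErrMissingEntities.
func upsertAndReload(tx *pgSession, nsf database.NamespacedFeature, fDet, nsDet database.Detector) (database.Ancestry, error) {
	ancestry := database.Ancestry{
		Name: "example-ancestry",
		By:   []database.Detector{fDet, nsDet},
		Layers: []database.AncestryLayer{
			{
				Hash: "layer-0",
				// each ancestry feature records both the detector that found
				// the feature and the one that found its namespace
				Features: []database.AncestryFeature{
					{NamespacedFeature: nsf, FeatureBy: fDet, NamespaceBy: nsDet},
				},
			},
		},
	}

	if err := tx.UpsertAncestry(ancestry); err != nil {
		return database.Ancestry{}, err
	}

	found, _, err := tx.FindAncestry("example-ancestry")
	return found, err
}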

database/pgsql/ancestry_test.go

@@ -15,198 +15,125 @@
 package pgsql

 import (
-	"sort"
 	"testing"

 	"github.com/stretchr/testify/assert"

 	"github.com/coreos/clair/database"
+	"github.com/coreos/clair/pkg/testutil"
 )

+var upsertAncestryTests = []struct {
+	in    *database.Ancestry
+	err   string
+	title string
+}{
+	{
+		title: "ancestry with invalid layer",
+		in: &database.Ancestry{
+			Name: "a1",
+			Layers: []database.AncestryLayer{
+				{
+					Hash: "layer-non-existing",
+				},
+			},
+		},
+		err: database.ErrMissingEntities.Error(),
+	},
+	{
+		title: "ancestry with invalid name",
+		in:    &database.Ancestry{},
+		err:   database.ErrInvalidParameters.Error(),
+	},
+	{
+		title: "new valid ancestry",
+		in: &database.Ancestry{
+			Name:   "a",
+			Layers: []database.AncestryLayer{{Hash: "layer-0"}},
+		},
+	},
+	{
+		title: "ancestry with invalid feature",
+		in: &database.Ancestry{
+			Name: "a",
+			By:   []database.Detector{realDetectors[1], realDetectors[2]},
+			Layers: []database.AncestryLayer{{Hash: "layer-1", Features: []database.AncestryFeature{
+				{fakeNamespacedFeatures[1], fakeDetector[1], fakeDetector[2]},
+			}}},
+		},
+		err: database.ErrMissingEntities.Error(),
+	},
+	{
+		title: "replace old ancestry",
+		in: &database.Ancestry{
+			Name: "a",
+			By:   []database.Detector{realDetectors[1], realDetectors[2]},
+			Layers: []database.AncestryLayer{
+				{"layer-1", []database.AncestryFeature{{realNamespacedFeatures[1], realDetectors[2], realDetectors[1]}}},
+			},
+		},
+	},
+}
+
 func TestUpsertAncestry(t *testing.T) {
 	store, tx := openSessionForTest(t, "UpsertAncestry", true)
 	defer closeTest(t, store, tx)
-	a1 := database.Ancestry{
-		Name: "a1",
-		Layers: []database.AncestryLayer{
-			{
-				LayerMetadata: database.LayerMetadata{
-					Hash: "layer-N",
-				},
-			},
-		},
-	}
-
-	a2 := database.Ancestry{}
-
-	a3 := database.Ancestry{
-		Name: "a",
-		Layers: []database.AncestryLayer{
-			{
-				LayerMetadata: database.LayerMetadata{
-					Hash: "layer-0",
-				},
-			},
-		},
-	}
-
-	a4 := database.Ancestry{
-		Name: "a",
-		Layers: []database.AncestryLayer{
-			{
-				LayerMetadata: database.LayerMetadata{
-					Hash: "layer-1",
-				},
-			},
-		},
-	}
-
-	f1 := database.Feature{
-		Name:          "wechat",
-		Version:       "0.5",
-		VersionFormat: "dpkg",
-	}
-
-	// not in database
-	f2 := database.Feature{
-		Name:          "wechat",
-		Version:       "0.6",
-		VersionFormat: "dpkg",
-	}
-
-	n1 := database.Namespace{
-		Name:          "debian:7",
-		VersionFormat: "dpkg",
-	}
-
-	p := database.Processors{
-		Listers:   []string{"dpkg", "non-existing"},
-		Detectors: []string{"os-release", "non-existing"},
-	}
-
-	nsf1 := database.NamespacedFeature{
-		Namespace: n1,
-		Feature:   f1,
-	}
-
-	// not in database
-	nsf2 := database.NamespacedFeature{
-		Namespace: n1,
-		Feature:   f2,
-	}
-
-	a4.ProcessedBy = p
-	// invalid case
-	assert.NotNil(t, tx.UpsertAncestry(a1))
-	assert.NotNil(t, tx.UpsertAncestry(a2))
-	// valid case
-	assert.Nil(t, tx.UpsertAncestry(a3))
-	a4.Layers[0].DetectedFeatures = []database.NamespacedFeature{nsf1, nsf2}
-	// replace invalid case
-	assert.NotNil(t, tx.UpsertAncestry(a4))
-	a4.Layers[0].DetectedFeatures = []database.NamespacedFeature{nsf1}
-	// replace valid case
-	assert.Nil(t, tx.UpsertAncestry(a4))
-	// validate
-	ancestry, ok, err := tx.FindAncestry("a")
-	assert.Nil(t, err)
-	assert.True(t, ok)
-	assertAncestryEqual(t, a4, ancestry)
-}
-
-func assertProcessorsEqual(t *testing.T, expected database.Processors, actual database.Processors) bool {
-	sort.Strings(expected.Detectors)
-	sort.Strings(actual.Detectors)
-	sort.Strings(expected.Listers)
-	sort.Strings(actual.Listers)
-
-	return assert.Equal(t, expected.Detectors, actual.Detectors) && assert.Equal(t, expected.Listers, actual.Listers)
-}
-
-func assertAncestryEqual(t *testing.T, expected database.Ancestry, actual database.Ancestry) bool {
-	assert.Equal(t, expected.Name, actual.Name)
-	assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy)
-	if assert.Equal(t, len(expected.Layers), len(actual.Layers)) {
-		for index, layer := range expected.Layers {
-			if !assertAncestryLayerEqual(t, layer, actual.Layers[index]) {
-				return false
-			}
-		}
-		return true
-	}
-	return false
-}
-
-func assertAncestryLayerEqual(t *testing.T, expected database.AncestryLayer, actual database.AncestryLayer) bool {
-	return assertLayerEqual(t, expected.LayerMetadata, actual.LayerMetadata) &&
-		assertNamespacedFeatureEqual(t, expected.DetectedFeatures, actual.DetectedFeatures)
-}
+	for _, test := range upsertAncestryTests {
+		t.Run(test.title, func(t *testing.T) {
+			err := tx.UpsertAncestry(*test.in)
+			if test.err != "" {
+				assert.EqualError(t, err, test.err, "unexpected error")
+				return
+			}
+
+			assert.Nil(t, err)
+			actual, ok, err := tx.FindAncestry(test.in.Name)
+			assert.Nil(t, err)
+			assert.True(t, ok)
+			testutil.AssertAncestryEqual(t, test.in, &actual)
+		})
+	}
+}
+
+var findAncestryTests = []struct {
+	title    string
+	in       string
+
+	ancestry *database.Ancestry
+	err      string
+	ok       bool
+}{
+	{
+		title:    "missing ancestry",
+		in:       "ancestry-non",
+		err:      "",
+		ancestry: nil,
+		ok:       false,
+	},
+	{
+		title:    "valid ancestry",
+		in:       "ancestry-2",
+		err:      "",
+		ok:       true,
+		ancestry: takeAncestryPointerFromMap(realAncestries, 2),
+	},
+}

 func TestFindAncestry(t *testing.T) {
 	store, tx := openSessionForTest(t, "FindAncestry", true)
 	defer closeTest(t, store, tx)
-
-	// invalid
-	_, ok, err := tx.FindAncestry("ancestry-non")
-	if assert.Nil(t, err) {
-		assert.False(t, ok)
-	}
-
-	expected := database.Ancestry{
-		Name: "ancestry-2",
-		ProcessedBy: database.Processors{
-			Detectors: []string{"os-release"},
-			Listers:   []string{"dpkg"},
-		},
-		Layers: []database.AncestryLayer{
-			{
-				LayerMetadata: database.LayerMetadata{
-					Hash: "layer-0",
-				},
-				DetectedFeatures: []database.NamespacedFeature{
-					{
-						Namespace: database.Namespace{
-							Name:          "debian:7",
-							VersionFormat: "dpkg",
-						},
-						Feature: database.Feature{
-							Name:          "wechat",
-							Version:       "0.5",
-							VersionFormat: "dpkg",
-						},
-					},
-					{
-						Namespace: database.Namespace{
-							Name:          "debian:8",
-							VersionFormat: "dpkg",
-						},
-						Feature: database.Feature{
-							Name:          "openssl",
-							Version:       "1.0",
-							VersionFormat: "dpkg",
-						},
-					},
-				},
-			},
-			{
-				LayerMetadata: database.LayerMetadata{
-					Hash: "layer-1",
-				},
-			},
-			{
-				LayerMetadata: database.LayerMetadata{
-					Hash: "layer-2",
-				},
-			},
-			{
-				LayerMetadata: database.LayerMetadata{
-					Hash: "layer-3b",
-				},
-			},
-		},
-	}
-
-	// valid
-	ancestry, ok, err := tx.FindAncestry("ancestry-2")
-	if assert.Nil(t, err) && assert.True(t, ok) {
-		assertAncestryEqual(t, expected, ancestry)
-	}
+	for _, test := range findAncestryTests {
+		t.Run(test.title, func(t *testing.T) {
+			ancestry, ok, err := tx.FindAncestry(test.in)
+			if test.err != "" {
+				assert.EqualError(t, err, test.err, "unexpected error")
+				return
+			}
+
+			assert.Nil(t, err)
+			assert.Equal(t, test.ok, ok)
+			if test.ok {
+				testutil.AssertAncestryEqual(t, test.ancestry, &ancestry)
+			}
+		})
+	}
 }

database/pgsql/vulnerability_test.go

@@ -220,7 +220,7 @@ func TestCaching(t *testing.T) {
 			actualAffectedNames = append(actualAffectedNames, s.Name)
 		}

-		assert.Len(t, strutil.CompareStringLists(expectedAffectedNames, actualAffectedNames), 0)
-		assert.Len(t, strutil.CompareStringLists(actualAffectedNames, expectedAffectedNames), 0)
+		assert.Len(t, strutil.Difference(expectedAffectedNames, actualAffectedNames), 0)
+		assert.Len(t, strutil.Difference(actualAffectedNames, expectedAffectedNames), 0)
 	}
 }
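For reference, the assumed semantics of the renamed helper (strutil itself is not shown in this commit): Difference(a, b) yields the strings of a that are absent from b, so the two assertions together check that the name lists are equal as sets. A minimal stand-in:

// Difference is a sketch matching the assumed behavior of the renamed helper.
func Difference(a, b []string) []string {
	inB := make(map[string]struct{}, len(b))
	for _, s := range b {
		inB[s] = struct{}{}
	}

	diff := []string{}
	for _, s := range a {
		if _, ok := inB[s]; !ok {
			diff = append(diff, s)
		}
	}

	return diff
}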

database/pgsql/detector.go (new file, 198 lines)
View File

@@ -0,0 +1,198 @@
// Copyright 2018 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"database/sql"
"github.com/deckarep/golang-set"
log "github.com/sirupsen/logrus"
"github.com/coreos/clair/database"
)
const (
soiDetector = `
INSERT INTO detector (name, version, dtype)
SELECT CAST ($1 AS TEXT), CAST ($2 AS TEXT), CAST ($3 AS detector_type )
WHERE NOT EXISTS (SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3);`
selectAncestryDetectors = `
SELECT d.name, d.version, d.dtype
FROM ancestry_detector, detector AS d
WHERE ancestry_detector.detector_id = d.id AND ancestry_detector.ancestry_id = $1;`
selectLayerDetectors = `
SELECT d.name, d.version, d.dtype
FROM layer_detector, detector AS d
WHERE layer_detector.detector_id = d.id AND layer_detector.layer_id = $1;`
insertAncestryDetectors = `
INSERT INTO ancestry_detector (ancestry_id, detector_id)
SELECT $1, $2
WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector_id = $2)`
persistLayerDetector = `
INSERT INTO layer_detector (layer_id, detector_id)
SELECT $1, $2
WHERE NOT EXISTS (SELECT id FROM layer_detector WHERE layer_id = $1 AND detector_id = $2)`
findDetectorID = `SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3`
findAllDetectors = `SELECT id, name, version, dtype FROM detector`
)
type detectorMap struct {
byID map[int64]database.Detector
byValue map[database.Detector]int64
}
func (tx *pgSession) PersistDetectors(detectors []database.Detector) error {
for _, d := range detectors {
if !d.Valid() {
log.WithField("detector", d).Debug("Invalid Detector")
return database.ErrInvalidParameters
}
r, err := tx.Exec(soiDetector, d.Name, d.Version, d.DType)
if err != nil {
return handleError("soiDetector", err)
}
count, err := r.RowsAffected()
if err != nil {
return handleError("soiDetector", err)
}
if count == 0 {
log.Debug("detector already exists: ", d)
}
}
return nil
}
func (tx *pgSession) persistLayerDetector(layerID int64, detectorID int64) error {
if _, err := tx.Exec(persistLayerDetector, layerID, detectorID); err != nil {
return handleError("persistLayerDetector", err)
}
return nil
}
func (tx *pgSession) persistLayerDetectors(layerID int64, detectorIDs []int64) error {
alreadySaved := mapset.NewSet()
for _, id := range detectorIDs {
if alreadySaved.Contains(id) {
continue
}
alreadySaved.Add(id)
if err := tx.persistLayerDetector(layerID, id); err != nil {
return err
}
}
return nil
}
func (tx *pgSession) insertAncestryDetectors(ancestryID int64, detectorIDs []int64) error {
for _, detectorID := range detectorIDs {
if _, err := tx.Exec(insertAncestryDetectors, ancestryID, detectorID); err != nil {
return handleError("insertAncestryDetectors", err)
}
}
return nil
}
func (tx *pgSession) findAncestryDetectors(id int64) ([]database.Detector, error) {
detectors, err := tx.getDetectors(selectAncestryDetectors, id)
log.WithField("detectors", detectors).Debug("found ancestry detectors")
return detectors, err
}
func (tx *pgSession) findLayerDetectors(id int64) ([]database.Detector, error) {
detectors, err := tx.getDetectors(selectLayerDetectors, id)
log.WithField("detectors", detectors).Debug("found layer detectors")
return detectors, err
}
// findDetectorIDs retrieves the IDs of the given detectors from the
// database; if any detector is not found, it returns an error.
func (tx *pgSession) findDetectorIDs(detectors []database.Detector) ([]int64, error) {
ids := []int64{}
for _, d := range detectors {
id := sql.NullInt64{}
err := tx.QueryRow(findDetectorID, d.Name, d.Version, d.DType).Scan(&id)
if err != nil {
return nil, handleError("findDetectorID", err)
}
if !id.Valid {
return nil, database.ErrInconsistent
}
ids = append(ids, id.Int64)
}
return ids, nil
}
func (tx *pgSession) getDetectors(query string, id int64) ([]database.Detector, error) {
rows, err := tx.Query(query, id)
if err != nil {
return nil, handleError("getDetectors", err)
}
detectors := []database.Detector{}
for rows.Next() {
d := database.Detector{}
err := rows.Scan(&d.Name, &d.Version, &d.DType)
if err != nil {
return nil, handleError("getDetectors", err)
}
if !d.Valid() {
return nil, database.ErrInvalidDetector
}
detectors = append(detectors, d)
}
return detectors, nil
}
func (tx *pgSession) findAllDetectors() (detectorMap, error) {
rows, err := tx.Query(findAllDetectors)
if err != nil {
return detectorMap{}, handleError("searchAllDetectors", err)
}
detectors := detectorMap{byID: make(map[int64]database.Detector), byValue: make(map[database.Detector]int64)}
for rows.Next() {
var (
id int64
d database.Detector
)
if err := rows.Scan(&id, &d.Name, &d.Version, &d.DType); err != nil {
return detectorMap{}, handleError("searchAllDetectors", err)
}
detectors.byID[id] = d
detectors.byValue[d] = id
}
return detectors, nil
}
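For reference, a sketch (not from this commit) of how these helpers compose; the layer ID and detector values are placeholders.

// saveDetectors persists a set of detectors and links them to a layer.
func saveDetectors(tx *pgSession, layerID int64) error {
	detectors := []database.Detector{
		database.NewNamespaceDetector("os-release", "1.0"),
		database.NewFeatureDetector("dpkg", "1.0"),
	}

	// idempotent: soiDetector only inserts missing (name, version, dtype) rows
	if err := tx.PersistDetectors(detectors); err != nil {
		return err
	}

	// ids[i] corresponds to detectors[i]
	ids, err := tx.findDetectorIDs(detectors)
	if err != nil {
		return err
	}

	return tx.persistLayerDetectors(layerID, ids)
}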

database/pgsql/detector_test.go (new file, 119 lines)

@@ -0,0 +1,119 @@
// Copyright 2018 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"testing"
"github.com/deckarep/golang-set"
"github.com/stretchr/testify/require"
"github.com/coreos/clair/database"
)
func testGetAllDetectors(tx *pgSession) []database.Detector {
query := `SELECT name, version, dtype FROM detector`
rows, err := tx.Query(query)
if err != nil {
panic(err)
}
detectors := []database.Detector{}
for rows.Next() {
d := database.Detector{}
if err := rows.Scan(&d.Name, &d.Version, &d.DType); err != nil {
panic(err)
}
detectors = append(detectors, d)
}
return detectors
}
var persistDetectorTests = []struct {
title string
in []database.Detector
err string
}{
{
title: "invalid detector",
in: []database.Detector{
{},
database.NewFeatureDetector("name", "2.0"),
},
err: database.ErrInvalidParameters.Error(),
},
{
title: "invalid detector 2",
in: []database.Detector{
database.NewFeatureDetector("name", "2.0"),
{"name", "1.0", "random not valid dtype"},
},
err: database.ErrInvalidParameters.Error(),
},
{
title: "detectors with some different fields",
in: []database.Detector{
database.NewFeatureDetector("name", "2.0"),
database.NewFeatureDetector("name", "1.0"),
database.NewNamespaceDetector("name", "1.0"),
},
},
{
title: "duplicated detectors (parameter level)",
in: []database.Detector{
database.NewFeatureDetector("name", "1.0"),
database.NewFeatureDetector("name", "1.0"),
},
},
{
title: "duplicated detectors (db level)",
in: []database.Detector{
database.NewNamespaceDetector("os-release", "1.0"),
database.NewNamespaceDetector("os-release", "1.0"),
database.NewFeatureDetector("dpkg", "1.0"),
},
},
}
func TestPersistDetector(t *testing.T) {
datastore, tx := openSessionForTest(t, "PersistDetector", true)
defer closeTest(t, datastore, tx)
for _, test := range persistDetectorTests {
t.Run(test.title, func(t *testing.T) {
err := tx.PersistDetectors(test.in)
if test.err != "" {
require.EqualError(t, err, test.err)
return
}
detectors := testGetAllDetectors(tx)
// ensure no duplicated detectors
detectorSet := mapset.NewSet()
for _, d := range detectors {
require.False(t, detectorSet.Contains(d), "duplicated: %v", d)
detectorSet.Add(d)
}
// ensure all persisted detectors are actually saved
for _, d := range test.in {
require.True(t, detectorSet.Contains(d), "detector: %v, detectors: %v", d, detectorSet)
}
})
}
}

database/pgsql/feature.go

@@ -16,7 +16,6 @@ package pgsql
 import (
 	"database/sql"
-	"errors"
 	"sort"

 	"github.com/lib/pq"
@@ -28,7 +27,6 @@ import (
 )

 const (
-	// feature.go
 	soiNamespacedFeature = `
 		WITH new_feature_ns AS (
 			INSERT INTO namespaced_feature(feature_id, namespace_id)
@@ -65,15 +63,6 @@ const (
 			AND v.deleted_at IS NULL`
 )

-var (
-	errFeatureNotFound = errors.New("Feature not found")
-)
-
-type vulnerabilityAffecting struct {
-	vulnerabilityID int64
-	addedByID       int64
-}
-
 func (tx *pgSession) PersistFeatures(features []database.Feature) error {
 	if len(features) == 0 {
 		return nil
@@ -126,7 +115,7 @@ func (tx *pgSession) searchAffectingVulnerabilities(features []database.Namespac
 	fMap := map[int64]database.NamespacedFeature{}
 	for i, f := range features {
 		if !ids[i].Valid {
-			return nil, errFeatureNotFound
+			return nil, database.ErrMissingEntities
 		}
 		fMap[ids[i].Int64] = f
 	}
@@ -218,7 +207,7 @@ func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFea
 	if ids, err := tx.findFeatureIDs(fToFind); err == nil {
 		for i, id := range ids {
 			if !id.Valid {
-				return errFeatureNotFound
+				return database.ErrMissingEntities
 			}
 			fIDs[fToFind[i]] = id
 		}
@@ -234,7 +223,7 @@ func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFea
 	if ids, err := tx.findNamespaceIDs(nsToFind); err == nil {
 		for i, id := range ids {
 			if !id.Valid {
-				return errNamespaceNotFound
+				return database.ErrMissingEntities
 			}
 			nsIDs[nsToFind[i]] = id
 		}
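For reference: the "soi" prefix on queries such as soiNamespacedFeature stands for select-or-insert. The full query text is cut off in this view; the pattern is assumed to look roughly like the sketch below (the soiLayer query in layer.go follows the same shape).

// soiNamespacedFeatureSketch is an illustrative reconstruction, not the
// production query: insert the row only if it does not exist yet, then
// return its id either way.
const soiNamespacedFeatureSketch = `
	WITH new_feature_ns AS (
		INSERT INTO namespaced_feature(feature_id, namespace_id)
		SELECT CAST ($1 AS INT), CAST ($2 AS INT)
		WHERE NOT EXISTS (
			SELECT id FROM namespaced_feature WHERE feature_id = $1 AND namespace_id = $2)
		RETURNING id
	)
	SELECT id FROM namespaced_feature WHERE feature_id = $1 AND namespace_id = $2
	UNION
	SELECT id FROM new_feature_ns`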

database/pgsql/feature_test.go

@@ -52,7 +52,7 @@ func TestPersistNamespacedFeatures(t *testing.T) {
 	// existing features
 	f1 := database.Feature{
-		Name:          "wechat",
+		Name:          "ourchat",
 		Version:       "0.5",
 		VersionFormat: "dpkg",
 	}
@@ -213,27 +213,6 @@ func listFeatures(t *testing.T, tx *pgSession) []database.Feature {
 	return fs
 }

-func assertFeaturesEqual(t *testing.T, expected []database.Feature, actual []database.Feature) bool {
-	if assert.Len(t, actual, len(expected)) {
-		has := map[database.Feature]bool{}
-		for _, nf := range expected {
-			has[nf] = false
-		}
-
-		for _, nf := range actual {
-			has[nf] = true
-		}
-
-		for nf, visited := range has {
-			if !assert.True(t, visited, nf.Name+" is expected") {
-				return false
-			}
-			return true
-		}
-	}
-	return false
-}
-
 func assertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool {
 	if assert.Len(t, actual, len(expected)) {
 		has := map[database.NamespacedFeature]bool{}

database/pgsql/layer.go

@@ -18,6 +18,8 @@ import (
 	"database/sql"
 	"sort"

+	"github.com/deckarep/golang-set"
+
 	"github.com/coreos/clair/database"
 	"github.com/coreos/clair/pkg/commonerr"
 )
@@ -34,300 +36,331 @@ const (
 		UNION
 		SELECT id FROM layer WHERE hash = $1`

-	searchLayerFeatures = `
-		SELECT feature_id, detector_id
-		FROM layer_feature
-		WHERE layer_id = $1`
+	findLayerFeatures = `
+		SELECT f.name, f.version, f.version_format, lf.detector_id
+		FROM layer_feature AS lf, feature AS f
+		WHERE lf.feature_id = f.id
+			AND lf.layer_id = $1`

-	searchLayerNamespaces = `
-		SELECT namespace.Name, namespace.version_format
-		FROM namespace, layer_namespace
-		WHERE layer_namespace.layer_id = $1
-			AND layer_namespace.namespace_id = namespace.id`
+	findLayerNamespaces = `
+		SELECT ns.name, ns.version_format, ln.detector_id
+		FROM layer_namespace AS ln, namespace AS ns
+		WHERE ln.namespace_id = ns.id
+			AND ln.layer_id = $1`

-	searchLayer = `SELECT id FROM layer WHERE hash = $1`
+	findLayerID = `SELECT id FROM layer WHERE hash = $1`
 )

+// dbLayerNamespace represents the layer_namespace table.
+type dbLayerNamespace struct {
+	layerID     int64
+	namespaceID int64
+	detectorID  int64
+}
+
+// dbLayerFeature represents the layer_feature table
+type dbLayerFeature struct {
+	layerID    int64
+	featureID  int64
+	detectorID int64
+}
+
 func (tx *pgSession) FindLayer(hash string) (database.Layer, bool, error) {
+	layer := database.Layer{Hash: hash}
+	if hash == "" {
+		return layer, false, commonerr.NewBadRequestError("non empty layer hash is expected.")
+	}
+
+	layerID, ok, err := tx.findLayerID(hash)
+	if err != nil || !ok {
+		return layer, ok, err
+	}
+
+	detectorMap, err := tx.findAllDetectors()
 	if err != nil {
 		return layer, false, err
 	}

-	if !ok {
-		return layer, false, nil
+	if layer.By, err = tx.findLayerDetectors(layerID); err != nil {
+		return layer, false, err
+	}
+
+	if layer.Features, err = tx.findLayerFeatures(layerID, detectorMap); err != nil {
+		return layer, false, err
+	}
+
+	if layer.Namespaces, err = tx.findLayerNamespaces(layerID, detectorMap); err != nil {
+		return layer, false, err
 	}

-	layer.Features, err = tx.findLayerFeatures(layerID)
-	layer.Namespaces, err = tx.findLayerNamespaces(layerID)
 	return layer, true, nil
 }

-func (tx *pgSession) persistLayer(hash string) (int64, error) {
+func sanitizePersistLayerInput(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error {
 	if hash == "" {
-		return -1, commonerr.NewBadRequestError("Empty Layer Hash is not allowed")
+		return commonerr.NewBadRequestError("expected non-empty layer hash")
 	}

-	id := sql.NullInt64{}
-	if err := tx.QueryRow(soiLayer, hash).Scan(&id); err != nil {
-		return -1, handleError("queryPersistLayer", err)
+	detectedBySet := mapset.NewSet()
+	for _, d := range detectedBy {
+		detectedBySet.Add(d)
 	}

-	if !id.Valid {
-		panic("null layer.id violates database constraint")
+	for _, f := range features {
+		if !detectedBySet.Contains(f.By) {
+			return database.ErrInvalidParameters
+		}
 	}

-	return id.Int64, nil
+	for _, n := range namespaces {
+		if !detectedBySet.Contains(n.By) {
+			return database.ErrInvalidParameters
+		}
+	}
+
+	return nil
 }

-// PersistLayer relates layer identified by hash with namespaces,
-// features and processors provided. If the layer, namespaces, features are not
-// in database, the function returns an error.
-func (tx *pgSession) PersistLayer(hash string, namespaces []database.Namespace, features []database.Feature, processedBy database.Processors) error {
-	if hash == "" {
-		return commonerr.NewBadRequestError("Empty layer hash is not allowed")
-	}
-
+// PersistLayer saves the content of a layer to the database.
+func (tx *pgSession) PersistLayer(hash string, features []database.LayerFeature, namespaces []database.LayerNamespace, detectedBy []database.Detector) error {
 	var (
-		err error
-		id  int64
+		err         error
+		id          int64
+		detectorIDs []int64
 	)

-	if id, err = tx.persistLayer(hash); err != nil {
+	if err = sanitizePersistLayerInput(hash, features, namespaces, detectedBy); err != nil {
 		return err
 	}

-	if err = tx.persistLayerNamespace(id, namespaces); err != nil {
+	if id, err = tx.soiLayer(hash); err != nil {
 		return err
 	}

-	if err = tx.persistLayerFeatures(id, features); err != nil {
+	if detectorIDs, err = tx.findDetectorIDs(detectedBy); err != nil {
+		if err == commonerr.ErrNotFound {
+			return database.ErrMissingEntities
+		}
+
 		return err
 	}

-	if err = tx.persistLayerDetectors(id, processedBy.Detectors); err != nil {
+	if err = tx.persistLayerDetectors(id, detectorIDs); err != nil {
 		return err
 	}

-	if err = tx.persistLayerListers(id, processedBy.Listers); err != nil {
+	if err = tx.persistAllLayerFeatures(id, features); err != nil {
+		return err
+	}
+
+	if err = tx.persistAllLayerNamespaces(id, namespaces); err != nil {
 		return err
 	}

 	return nil
 }

-func (tx *pgSession) persistLayerDetectors(id int64, detectors []string) error {
-	if len(detectors) == 0 {
-		return nil
-	}
-
-	// Sorting is needed before inserting into database to prevent deadlock.
-	sort.Strings(detectors)
-	keys := make([]interface{}, len(detectors)*2)
-	for i, d := range detectors {
-		keys[i*2] = id
-		keys[i*2+1] = d
-	}
-	_, err := tx.Exec(queryPersistLayerDetectors(len(detectors)), keys...)
-	if err != nil {
-		return handleError("queryPersistLayerDetectors", err)
-	}
-	return nil
-}
-
-func (tx *pgSession) persistLayerListers(id int64, listers []string) error {
-	if len(listers) == 0 {
-		return nil
-	}
-
-	sort.Strings(listers)
-	keys := make([]interface{}, len(listers)*2)
-	for i, d := range listers {
-		keys[i*2] = id
-		keys[i*2+1] = d
-	}
-
-	_, err := tx.Exec(queryPersistLayerListers(len(listers)), keys...)
-	if err != nil {
-		return handleError("queryPersistLayerDetectors", err)
-	}
-	return nil
-}
+func (tx *pgSession) persistAllLayerNamespaces(layerID int64, namespaces []database.LayerNamespace) error {
+	detectorMap, err := tx.findAllDetectors()
+	if err != nil {
+		return err
+	}
+
+	// TODO(sidac): This kind of type conversion is very useless and wasteful,
+	// we need interfaces around the database models to reduce these kind of
+	// operations.
+	rawNamespaces := make([]database.Namespace, 0, len(namespaces))
+	for _, ns := range namespaces {
+		rawNamespaces = append(rawNamespaces, ns.Namespace)
+	}
+
+	rawNamespaceIDs, err := tx.findNamespaceIDs(rawNamespaces)
+	if err != nil {
+		return err
+	}
+
+	dbLayerNamespaces := make([]dbLayerNamespace, 0, len(namespaces))
+	for i, ns := range namespaces {
+		detectorID := detectorMap.byValue[ns.By]
+		namespaceID := rawNamespaceIDs[i].Int64
+		if !rawNamespaceIDs[i].Valid {
+			return database.ErrMissingEntities
+		}
+
+		dbLayerNamespaces = append(dbLayerNamespaces, dbLayerNamespace{layerID, namespaceID, detectorID})
+	}
+
+	return tx.persistLayerNamespaces(dbLayerNamespaces)
+}
+
+func (tx *pgSession) persistAllLayerFeatures(layerID int64, features []database.LayerFeature) error {
+	detectorMap, err := tx.findAllDetectors()
+	if err != nil {
+		return err
+	}
+
+	rawFeatures := make([]database.Feature, 0, len(features))
+	for _, f := range features {
+		rawFeatures = append(rawFeatures, f.Feature)
+	}
+
+	featureIDs, err := tx.findFeatureIDs(rawFeatures)
+	if err != nil {
+		return err
+	}
+
+	dbFeatures := make([]dbLayerFeature, 0, len(features))
+	for i, f := range features {
+		detectorID := detectorMap.byValue[f.By]
+		featureID := featureIDs[i].Int64
+		if !featureIDs[i].Valid {
+			return database.ErrMissingEntities
+		}
+
+		dbFeatures = append(dbFeatures, dbLayerFeature{layerID, featureID, detectorID})
+	}
+
+	if err := tx.persistLayerFeatures(dbFeatures); err != nil {
+		return err
+	}
+
+	return nil
+}

-func (tx *pgSession) persistLayerFeatures(id int64, features []database.Feature) error {
+func (tx *pgSession) persistLayerFeatures(features []dbLayerFeature) error {
 	if len(features) == 0 {
 		return nil
 	}

-	fIDs, err := tx.findFeatureIDs(features)
-	if err != nil {
-		return err
-	}
-
-	ids := make([]int, len(fIDs))
-	for i, fID := range fIDs {
-		if !fID.Valid {
-			return errNamespaceNotFound
-		}
-		ids[i] = int(fID.Int64)
-	}
-
-	sort.IntSlice(ids).Sort()
-	keys := make([]interface{}, len(features)*2)
-	for i, fID := range ids {
-		keys[i*2] = id
-		keys[i*2+1] = fID
-	}
-
-	_, err = tx.Exec(queryPersistLayerFeature(len(features)), keys...)
+	sort.Slice(features, func(i, j int) bool {
+		return features[i].featureID < features[j].featureID
+	})
+
+	keys := make([]interface{}, len(features)*3)
+	for i, feature := range features {
+		keys[i*3] = feature.layerID
+		keys[i*3+1] = feature.featureID
+		keys[i*3+2] = feature.detectorID
+	}
+
+	_, err := tx.Exec(queryPersistLayerFeature(len(features)), keys...)
 	if err != nil {
 		return handleError("queryPersistLayerFeature", err)
 	}
 	return nil
 }

-func (tx *pgSession) persistLayerNamespace(id int64, namespaces []database.Namespace) error {
+func (tx *pgSession) persistLayerNamespaces(namespaces []dbLayerNamespace) error {
 	if len(namespaces) == 0 {
 		return nil
 	}

-	nsIDs, err := tx.findNamespaceIDs(namespaces)
-	if err != nil {
-		return err
-	}
-
 	// for every bulk persist operation, the input data should be sorted.
-	ids := make([]int, len(nsIDs))
-	for i, nsID := range nsIDs {
-		if !nsID.Valid {
-			panic(errNamespaceNotFound)
-		}
-		ids[i] = int(nsID.Int64)
-	}
-
-	sort.IntSlice(ids).Sort()
-	keys := make([]interface{}, len(namespaces)*2)
-	for i, nsID := range ids {
-		keys[i*2] = id
-		keys[i*2+1] = nsID
-	}
-
-	_, err = tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...)
+	sort.Slice(namespaces, func(i, j int) bool {
+		return namespaces[i].namespaceID < namespaces[j].namespaceID
+	})
+
+	elementSize := 3
+	keys := make([]interface{}, len(namespaces)*elementSize)
+	for i, row := range namespaces {
+		keys[i*3] = row.layerID
+		keys[i*3+1] = row.namespaceID
+		keys[i*3+2] = row.detectorID
+	}
+
+	_, err := tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...)
 	if err != nil {
 		return handleError("queryPersistLayerNamespace", err)
 	}
+
 	return nil
 }

-func (tx *pgSession) persistProcessors(listerQuery, listerQueryName, detectorQuery, detectorQueryName string, id int64, processors database.Processors) error {
-	stmt, err := tx.Prepare(listerQuery)
-	if err != nil {
-		return handleError(listerQueryName, err)
-	}
-
-	for _, l := range processors.Listers {
-		_, err := stmt.Exec(id, l)
-		if err != nil {
-			stmt.Close()
-			return handleError(listerQueryName, err)
-		}
-	}
-
-	if err := stmt.Close(); err != nil {
-		return handleError(listerQueryName, err)
-	}
-
-	stmt, err = tx.Prepare(detectorQuery)
-	if err != nil {
-		return handleError(detectorQueryName, err)
-	}
-
-	for _, d := range processors.Detectors {
-		_, err := stmt.Exec(id, d)
-		if err != nil {
-			stmt.Close()
-			return handleError(detectorQueryName, err)
-		}
-	}
-
-	if err := stmt.Close(); err != nil {
-		return handleError(detectorQueryName, err)
-	}
-
-	return nil
-}
-
-func (tx *pgSession) findLayerNamespaces(layerID int64) ([]database.Namespace, error) {
-	var namespaces []database.Namespace
-
-	rows, err := tx.Query(searchLayerNamespaces, layerID)
+func (tx *pgSession) findLayerNamespaces(layerID int64, detectors detectorMap) ([]database.LayerNamespace, error) {
+	rows, err := tx.Query(findLayerNamespaces, layerID)
 	if err != nil {
-		return nil, handleError("searchLayerFeatures", err)
+		return nil, handleError("findLayerNamespaces", err)
 	}

+	namespaces := []database.LayerNamespace{}
 	for rows.Next() {
-		ns := database.Namespace{}
-		err := rows.Scan(&ns.Name, &ns.VersionFormat)
-		if err != nil {
+		var (
+			namespace  database.LayerNamespace
+			detectorID int64
+		)
+
+		if err := rows.Scan(&namespace.Name, &namespace.VersionFormat, &detectorID); err != nil {
 			return nil, err
 		}
-		namespaces = append(namespaces, ns)
+
+		namespace.By = detectors.byID[detectorID]
+		namespaces = append(namespaces, namespace)
 	}

 	return namespaces, nil
 }

-func (tx *pgSession) findLayerFeatures(layerID int64) ([]database.Feature, error) {
-	var features []database.Feature
-
-	rows, err := tx.Query(searchLayerFeatures, layerID)
+func (tx *pgSession) findLayerFeatures(layerID int64, detectors detectorMap) ([]database.LayerFeature, error) {
+	rows, err := tx.Query(findLayerFeatures, layerID)
 	if err != nil {
-		return nil, handleError("searchLayerFeatures", err)
+		return nil, handleError("findLayerFeatures", err)
 	}
+	defer rows.Close()

+	features := []database.LayerFeature{}
 	for rows.Next() {
-		f := database.Feature{}
-		err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat)
-		if err != nil {
-			return nil, err
+		var (
+			detectorID int64
+			feature    database.LayerFeature
+		)
+
+		if err := rows.Scan(&feature.Name, &feature.Version, &feature.VersionFormat, &detectorID); err != nil {
+			return nil, handleError("findLayerFeatures", err)
 		}
-		features = append(features, f)
+
+		feature.By = detectors.byID[detectorID]
+		features = append(features, feature)
 	}

 	return features, nil
 }

-func (tx *pgSession) findLayer(hash string) (database.LayerMetadata, int64, bool, error) {
-	var (
-		layerID int64
-		layer   = database.LayerMetadata{Hash: hash, ProcessedBy: database.Processors{}}
-	)
-
-	if hash == "" {
-		return layer, layerID, false, commonerr.NewBadRequestError("Empty Layer Hash is not allowed")
-	}
-
-	err := tx.QueryRow(searchLayer, hash).Scan(&layerID)
+func (tx *pgSession) findLayerID(hash string) (int64, bool, error) {
+	var layerID int64
+	err := tx.QueryRow(findLayerID, hash).Scan(&layerID)
 	if err != nil {
 		if err == sql.ErrNoRows {
-			return layer, layerID, false, nil
+			return layerID, false, nil
 		}

-		return layer, layerID, false, err
+		return layerID, false, handleError("findLayerID", err)
 	}

-	layer.ProcessedBy, err = tx.findLayerProcessors(layerID)
-	return layer, layerID, true, err
+	return layerID, true, nil
 }

-func (tx *pgSession) findLayerProcessors(id int64) (database.Processors, error) {
-	var (
-		err        error
-		processors database.Processors
-	)
-
-	if processors.Detectors, err = tx.findProcessors(searchLayerDetectors, id); err != nil {
-		return processors, handleError("searchLayerDetectors", err)
-	}
-
-	if processors.Listers, err = tx.findProcessors(searchLayerListers, id); err != nil {
-		return processors, handleError("searchLayerListers", err)
-	}
-
-	return processors, nil
+func (tx *pgSession) findLayerIDs(hashes []string) ([]int64, bool, error) {
+	layerIDs := make([]int64, 0, len(hashes))
+	for _, hash := range hashes {
+		id, ok, err := tx.findLayerID(hash)
+		if !ok {
+			return nil, false, nil
+		}
+
+		if err != nil {
+			return nil, false, err
+		}
+
+		layerIDs = append(layerIDs, id)
+	}
+
+	return layerIDs, true, nil
+}
+
+func (tx *pgSession) soiLayer(hash string) (int64, error) {
+	var id int64
+	if err := tx.QueryRow(soiLayer, hash).Scan(&id); err != nil {
+		return 0, handleError("soiLayer", err)
+	}
+
+	return id, nil
 }

database/pgsql/layer_test.go

@@ -20,107 +20,172 @@ import (
 	"github.com/stretchr/testify/assert"

 	"github.com/coreos/clair/database"
+	"github.com/coreos/clair/pkg/testutil"
 )

+var persistLayerTests = []struct {
+	title      string
+	name       string
+	by         []database.Detector
+	features   []database.LayerFeature
+	namespaces []database.LayerNamespace
+	layer      *database.Layer
+	err        string
+}{
+	{
+		title: "invalid layer name",
+		name:  "",
+		err:   "expected non-empty layer hash",
+	},
+	{
+		title: "layer with inconsistent feature and detectors",
+		name:  "random-forest",
+		by:    []database.Detector{realDetectors[2]},
+		features: []database.LayerFeature{
+			{realFeatures[1], realDetectors[1]},
+		},
+		err: "database: parameters are not valid",
+	},
+	{
+		title: "layer with non-existing feature",
+		name:  "random-forest",
+		err:   "database: associated immutable entities are missing in the database",
+		by:    []database.Detector{realDetectors[2]},
+		features: []database.LayerFeature{
+			{fakeFeatures[1], realDetectors[2]},
+		},
+	},
+	{
+		title: "layer with non-existing namespace",
+		name:  "random-forest2",
+		err:   "database: associated immutable entities are missing in the database",
+		by:    []database.Detector{realDetectors[1]},
+		namespaces: []database.LayerNamespace{
+			{fakeNamespaces[1], realDetectors[1]},
+		},
+	},
+	{
+		title: "layer with non-existing detector",
+		name:  "random-forest3",
+		err:   "database: associated immutable entities are missing in the database",
+		by:    []database.Detector{fakeDetector[1]},
+	},
+	{
+		title: "valid layer",
+		name:  "hamsterhouse",
+		by:    []database.Detector{realDetectors[1], realDetectors[2]},
+		features: []database.LayerFeature{
+			{realFeatures[1], realDetectors[2]},
+			{realFeatures[2], realDetectors[2]},
+		},
+		namespaces: []database.LayerNamespace{
+			{realNamespaces[1], realDetectors[1]},
+		},
+		layer: &database.Layer{
+			Hash: "hamsterhouse",
+			By:   []database.Detector{realDetectors[1], realDetectors[2]},
+			Features: []database.LayerFeature{
+				{realFeatures[1], realDetectors[2]},
+				{realFeatures[2], realDetectors[2]},
+			},
+			Namespaces: []database.LayerNamespace{
+				{realNamespaces[1], realDetectors[1]},
+			},
+		},
+	},
+	{
+		title: "update existing layer",
+		name:  "layer-1",
+		by:    []database.Detector{realDetectors[3], realDetectors[4]},
+		features: []database.LayerFeature{
+			{realFeatures[4], realDetectors[3]},
+		},
+		namespaces: []database.LayerNamespace{
+			{realNamespaces[3], realDetectors[4]},
+		},
+		layer: &database.Layer{
+			Hash: "layer-1",
+			By:   []database.Detector{realDetectors[1], realDetectors[2], realDetectors[3], realDetectors[4]},
+			Features: []database.LayerFeature{
+				{realFeatures[1], realDetectors[2]},
+				{realFeatures[2], realDetectors[2]},
+				{realFeatures[4], realDetectors[3]},
+			},
+			Namespaces: []database.LayerNamespace{
+				{realNamespaces[1], realDetectors[1]},
+				{realNamespaces[3], realDetectors[4]},
+			},
+		},
+	},
+}
+
 func TestPersistLayer(t *testing.T) {
-	datastore, tx := openSessionForTest(t, "PersistLayer", false)
+	datastore, tx := openSessionForTest(t, "PersistLayer", true)
 	defer closeTest(t, datastore, tx)

-	// invalid
-	assert.NotNil(t, tx.PersistLayer("", nil, nil, database.Processors{}))
-
-	// insert namespaces + features to
-	namespaces := []database.Namespace{
-		{
-			Name:          "sushi shop",
-			VersionFormat: "apk",
-		},
-	}
-
-	features := []database.Feature{
-		{
-			Name:          "blue fin sashimi",
-			Version:       "v1.0",
-			VersionFormat: "apk",
-		},
-	}
-
-	processors := database.Processors{
-		Listers:   []string{"release"},
-		Detectors: []string{"apk"},
-	}
-
-	assert.Nil(t, tx.PersistNamespaces(namespaces))
-	assert.Nil(t, tx.PersistFeatures(features))
-
-	// Valid
-	assert.Nil(t, tx.PersistLayer("RANDOM_FOREST", namespaces, features, processors))
-
-	nonExistingFeature := []database.Feature{{Name: "lobster sushi", Version: "v0.1", VersionFormat: "apk"}}
-	// Invalid:
-	assert.NotNil(t, tx.PersistLayer("RANDOM_FOREST", namespaces, nonExistingFeature, processors))
-
-	assert.Nil(t, tx.PersistFeatures(nonExistingFeature))
-	// Update the layer
-	assert.Nil(t, tx.PersistLayer("RANDOM_FOREST", namespaces, nonExistingFeature, processors))
-
-	// confirm update
-	layer, ok, err := tx.FindLayer("RANDOM_FOREST")
-	assert.Nil(t, err)
-	assert.True(t, ok)
-
-	expectedLayer := database.Layer{
-		LayerMetadata: database.LayerMetadata{
-			Hash:        "RANDOM_FOREST",
-			ProcessedBy: processors,
-		},
-		Features:   append(features, nonExistingFeature...),
-		Namespaces: namespaces,
-	}
-
-	assertLayerWithContentEqual(t, expectedLayer, layer)
-}
-
+	for _, test := range persistLayerTests {
+		t.Run(test.title, func(t *testing.T) {
+			err := tx.PersistLayer(test.name, test.features, test.namespaces, test.by)
+			if test.err != "" {
+				assert.EqualError(t, err, test.err, "unexpected error")
+				return
+			}
+
+			assert.Nil(t, err)
+			if test.layer != nil {
+				layer, ok, err := tx.FindLayer(test.name)
+				assert.Nil(t, err)
+				assert.True(t, ok)
+				testutil.AssertLayerEqual(t, test.layer, &layer)
+			}
+		})
+	}
+}
+
+var findLayerTests = []struct {
+	title string
+	in    string
+
+	out *database.Layer
+	err string
+	ok  bool
+}{
+	{
+		title: "invalid layer name",
+		in:    "",
+		err:   "non empty layer hash is expected.",
+	},
+	{
+		title: "non-existing layer",
+		in:    "layer-non-existing",
+		ok:    false,
+		out:   nil,
+	},
+	{
+		title: "existing layer",
+		in:    "layer-4",
+		ok:    true,
+		out:   takeLayerPointerFromMap(realLayers, 6),
+	},
+}
+
 func TestFindLayer(t *testing.T) {
 	datastore, tx := openSessionForTest(t, "FindLayer", true)
 	defer closeTest(t, datastore, tx)

-	_, _, err := tx.FindLayer("")
-	assert.NotNil(t, err)
-	_, ok, err := tx.FindLayer("layer-non")
-	assert.Nil(t, err)
-	assert.False(t, ok)
-
-	expectedL := database.Layer{
-		LayerMetadata: database.LayerMetadata{
-			Hash: "layer-4",
-			ProcessedBy: database.Processors{
-				Detectors: []string{"os-release", "apt-sources"},
-				Listers:   []string{"dpkg", "rpm"},
-			},
-		},
-		Features: []database.Feature{
-			{Name: "fake", Version: "2.0", VersionFormat: "rpm"},
-			{Name: "openssl", Version: "2.0", VersionFormat: "dpkg"},
-		},
-		Namespaces: []database.Namespace{
-			{Name: "debian:7", VersionFormat: "dpkg"},
-			{Name: "fake:1.0", VersionFormat: "rpm"},
-		},
-	}
-
-	layer, ok2, err := tx.FindLayer("layer-4")
-	if assert.Nil(t, err) && assert.True(t, ok2) {
-		assertLayerWithContentEqual(t, expectedL, layer)
-	}
-}
-
-func assertLayerWithContentEqual(t *testing.T, expected database.Layer, actual database.Layer) bool {
-	return assertLayerEqual(t, expected.LayerMetadata, actual.LayerMetadata) &&
-		assertFeaturesEqual(t, expected.Features, actual.Features) &&
-		assertNamespacesEqual(t, expected.Namespaces, actual.Namespaces)
-}
-
-func assertLayerEqual(t *testing.T, expected database.LayerMetadata, actual database.LayerMetadata) bool {
-	return assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy) &&
-		assert.Equal(t, expected.Hash, actual.Hash)
-}
+	for _, test := range findLayerTests {
+		t.Run(test.title, func(t *testing.T) {
+			layer, ok, err := tx.FindLayer(test.in)
+			if test.err != "" {
+				assert.EqualError(t, err, test.err, "unexpected error")
+				return
+			}
+
+			assert.Nil(t, err)
+			assert.Equal(t, test.ok, ok)
+			if test.ok {
+				testutil.AssertLayerEqual(t, test.out, &layer)
+			}
+		})
+	}
+}


@ -38,8 +38,8 @@ var (
 		`CREATE TABLE IF NOT EXISTS namespaced_feature (
 			id SERIAL PRIMARY KEY,
-			namespace_id INT REFERENCES namespace,
-			feature_id INT REFERENCES feature,
+			namespace_id INT REFERENCES namespace ON DELETE CASCADE,
+			feature_id INT REFERENCES feature ON DELETE CASCADE,
 			UNIQUE (namespace_id, feature_id));`,
 	},
 	Down: []string{
@ -116,7 +116,7 @@ var (
 			id SERIAL PRIMARY KEY,
 			ancestry_id INT REFERENCES ancestry ON DELETE CASCADE,
 			ancestry_index INT NOT NULL,
-			layer_id INT REFERENCES layer ON DELETE RESTRICT,
+			layer_id INT NOT NULL REFERENCES layer ON DELETE RESTRICT,
 			UNIQUE (ancestry_id, ancestry_index));`,
 		`CREATE INDEX ON ancestry_layer(ancestry_id);`,
@ -130,7 +130,7 @@ var (
 		`CREATE TABLE IF NOT EXISTS ancestry_detector(
 			id SERIAL PRIMARY KEY,
-			ancestry_id INT REFERENCES layer ON DELETE CASCADE,
+			ancestry_id INT REFERENCES ancestry ON DELETE CASCADE,
 			detector_id INT REFERENCES detector ON DELETE CASCADE,
 			UNIQUE(ancestry_id, detector_id));`,
 		`CREATE INDEX ON ancestry_detector(ancestry_id);`,

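A note on the two deletion policies picked above: the namespaced_feature links cascade away with their parent namespace or feature, while ancestry_layer keeps RESTRICT on layer_id so an ancestry can never be left pointing at a deleted layer. The following is a minimal sketch of the observable behaviour, assuming a reachable test database; the DSN and the package main wrapper are illustrative, not part of this commit:

// fk_semantics_sketch.go - hypothetical demonstration of the migration's
// ON DELETE CASCADE vs. ON DELETE RESTRICT choices.
package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "github.com/lib/pq"
)

func main() {
	// Assumed DSN; point this at a database that has the schema loaded.
	db, err := sql.Open("postgres", "postgres://localhost/clair_test?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Deleting a feature now cascades: its namespaced_feature rows vanish too.
	if _, err := db.Exec(`DELETE FROM feature WHERE name = 'ourchat'`); err != nil {
		log.Fatal(err)
	}

	// Deleting a layer that an ancestry still references is rejected by
	// RESTRICT, so the ancestry graph stays consistent.
	if _, err := db.Exec(`DELETE FROM layer WHERE hash = 'layer-4'`); err != nil {
		fmt.Println("expected restriction:", err)
	}
}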

@ -16,7 +16,6 @@ package pgsql
 import (
 	"database/sql"
-	"errors"
 	"sort"

 	"github.com/coreos/clair/database"
@ -27,10 +26,6 @@ const (
 	searchNamespaceID = `SELECT id FROM Namespace WHERE name = $1 AND version_format = $2`
 )

-var (
-	errNamespaceNotFound = errors.New("Requested Namespace is not in database")
-)
-
 // PersistNamespaces persists namespaces into the database.
 func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error {
 	if len(namespaces) == 0 {


@ -42,42 +42,3 @@ func TestPersistNamespaces(t *testing.T) {
 	assert.Len(t, nsList, 1)
 	assert.Equal(t, ns2, nsList[0])
 }
-
-func assertNamespacesEqual(t *testing.T, expected []database.Namespace, actual []database.Namespace) bool {
-	if assert.Len(t, actual, len(expected)) {
-		has := map[database.Namespace]bool{}
-		for _, i := range expected {
-			has[i] = false
-		}
-		for _, i := range actual {
-			has[i] = true
-		}
-		for key, v := range has {
-			if !assert.True(t, v, key.Name+"is expected") {
-				return false
-			}
-		}
-		return true
-	}
-	return false
-}
-
-func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace {
-	rows, err := tx.Query("SELECT name, version_format FROM namespace")
-	if err != nil {
-		t.FailNow()
-	}
-	defer rows.Close()
-
-	namespaces := []database.Namespace{}
-	for rows.Next() {
-		var ns database.Namespace
-		err := rows.Scan(&ns.Name, &ns.VersionFormat)
-		if err != nil {
-			t.FailNow()
-		}
-		namespaces = append(namespaces, ns)
-	}
-
-	return namespaces
-}


@ -27,7 +27,6 @@ import (
 )

 const (
-	// notification.go
 	insertNotification = `
 		INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id)
 		VALUES ($1, $2, $3, $4)`
@ -60,9 +59,10 @@ const (
 		SELECT DISTINCT ON (a.id)
 			a.id, a.name
 		FROM vulnerability_affected_namespaced_feature AS vanf,
-			ancestry_layer AS al, ancestry_feature AS af
+			ancestry_layer AS al, ancestry_feature AS af, ancestry AS a
 		WHERE vanf.vulnerability_id = $1
-			AND al.ancestry_id >= $2
+			AND a.id >= $2
+			AND al.ancestry_id = a.id
 			AND al.id = af.ancestry_layer_id
 			AND af.namespaced_feature_id = vanf.namespaced_feature_id
 		ORDER BY a.id ASC
@ -211,14 +211,12 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
 	vulnPage := database.PagedVulnerableAncestries{Limit: limit}
 	currentPage := Page{0}
 	if currentToken != pagination.FirstPageToken {
-		var err error
-		err = tx.key.UnmarshalToken(currentToken, &currentPage)
-		if err != nil {
+		if err := tx.key.UnmarshalToken(currentToken, &currentPage); err != nil {
 			return vulnPage, err
 		}
 	}

-	err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan(
+	if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan(
 		&vulnPage.Name,
 		&vulnPage.Description,
 		&vulnPage.Link,
@ -226,8 +224,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
 		&vulnPage.Metadata,
 		&vulnPage.Namespace.Name,
 		&vulnPage.Namespace.VersionFormat,
-	)
-	if err != nil {
+	); err != nil {
 		return vulnPage, handleError("searchVulnerabilityByID", err)
 	}
@ -290,7 +287,6 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
 	}

 	noti.Name = name
-
 	err := tx.QueryRow(searchNotification, name).Scan(&created, &notified,
 		&deleted, &oldVulnID, &newVulnID)

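The rewritten query pages affected ancestries by primary key: `a.id >= $2` together with `ORDER BY a.id ASC` is a keyset cursor, and the cursor travels between calls as an opaque token wrapping the Page struct. Below is a sketch of that pattern, not the commit's exact code: the query constant name `searchAffectedAncestries` and its assumed trailing `LIMIT $3` are illustrative; Page, pagination.Token, and tx.key come from the diff itself.

// pageAffected fetches limit+1 rows starting at the token's StartID; the
// extra row, if present, seeds the next page's token.
func pageAffected(tx *pgSession, vulnID int64, limit int, current pagination.Token) (affected map[int]string, next pagination.Token, end bool, err error) {
	page := Page{0}
	if current != pagination.FirstPageToken {
		if err = tx.key.UnmarshalToken(current, &page); err != nil {
			return
		}
	}

	// searchAffectedAncestries is assumed to be the SELECT above with
	// "LIMIT $3" appended.
	rows, err := tx.Query(searchAffectedAncestries, vulnID, page.StartID, limit+1)
	if err != nil {
		return
	}
	defer rows.Close()

	affected = map[int]string{}
	ids := []int64{}
	for rows.Next() {
		var (
			id   int64
			name string
		)
		if err = rows.Scan(&id, &name); err != nil {
			return
		}
		if len(affected) < limit {
			affected[int(id)] = name
		}
		ids = append(ids, id)
	}
	if err = rows.Err(); err != nil {
		return
	}

	if len(ids) <= limit {
		return affected, pagination.Token(""), true, nil
	}

	// The (limit+1)-th id is where the next page starts.
	next, err = tx.key.MarshalToken(Page{ids[limit]})
	return affected, next, false, nil
}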

@ -19,121 +19,144 @@ import (
 	"time"

 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"

 	"github.com/coreos/clair/database"
+	"github.com/coreos/clair/pkg/pagination"
 )

-func TestPagination(t *testing.T) {
-	datastore, tx := openSessionForTest(t, "Pagination", true)
-	defer closeTest(t, datastore, tx)
-
-	ns := database.Namespace{
-		Name:          "debian:7",
-		VersionFormat: "dpkg",
-	}
-
-	vNew := database.Vulnerability{
-		Namespace:   ns,
-		Name:        "CVE-OPENSSL-1-DEB7",
-		Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
-		Link:        "http://google.com/#q=CVE-OPENSSL-1-DEB7",
-		Severity:    database.HighSeverity,
-	}
-
-	vOld := database.Vulnerability{
-		Namespace:   ns,
-		Name:        "CVE-NOPE",
-		Description: "A vulnerability affecting nothing",
-		Severity:    database.UnknownSeverity,
-	}
-
-	noti, ok, err := tx.FindVulnerabilityNotification("test", 1, "", "")
-	oldPage := database.PagedVulnerableAncestries{
-		Vulnerability: vOld,
-		Limit:         1,
-		Affected:      make(map[int]string),
-		End:           true,
-	}
-
-	newPage1 := database.PagedVulnerableAncestries{
-		Vulnerability: vNew,
-		Limit:         1,
-		Affected:      map[int]string{3: "ancestry-3"},
-		End:           false,
-	}
-
-	newPage2 := database.PagedVulnerableAncestries{
-		Vulnerability: vNew,
-		Limit:         1,
-		Affected:      map[int]string{4: "ancestry-4"},
-		Next:          "",
-		End:           true,
-	}
-
-	if assert.Nil(t, err) && assert.True(t, ok) {
-		assert.Equal(t, "test", noti.Name)
-		if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
-			var oldPage Page
-			err := tx.key.UnmarshalToken(noti.Old.Current, &oldPage)
-			if !assert.Nil(t, err) {
-				assert.FailNow(t, "")
-			}
-			assert.Equal(t, int64(0), oldPage.StartID)
-
-			var newPage Page
-			err = tx.key.UnmarshalToken(noti.New.Current, &newPage)
-			if !assert.Nil(t, err) {
-				assert.FailNow(t, "")
-			}
-
-			var newPageNext Page
-			err = tx.key.UnmarshalToken(noti.New.Next, &newPageNext)
-			if !assert.Nil(t, err) {
-				assert.FailNow(t, "")
-			}
-			assert.Equal(t, int64(0), newPage.StartID)
-			assert.Equal(t, int64(4), newPageNext.StartID)
-
-			noti.Old.Current = ""
-			noti.New.Current = ""
-			noti.New.Next = ""
-			assert.Equal(t, oldPage, *noti.Old)
-			assert.Equal(t, newPage1, *noti.New)
-		}
-	}
-
-	pageNum1, err := tx.key.MarshalToken(Page{0})
-	if !assert.Nil(t, err) {
-		assert.FailNow(t, "")
-	}
-
-	pageNum2, err := tx.key.MarshalToken(Page{4})
-	if !assert.Nil(t, err) {
-		assert.FailNow(t, "")
-	}
-
-	noti, ok, err = tx.FindVulnerabilityNotification("test", 1, pageNum1, pageNum2)
-	if assert.Nil(t, err) && assert.True(t, ok) {
-		assert.Equal(t, "test", noti.Name)
-		if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
-			var oldCurrentPage Page
-			err = tx.key.UnmarshalToken(noti.Old.Current, &oldCurrentPage)
-			if !assert.Nil(t, err) {
-				assert.FailNow(t, "")
-			}
-
-			var newCurrentPage Page
-			err = tx.key.UnmarshalToken(noti.New.Current, &newCurrentPage)
-			if !assert.Nil(t, err) {
-				assert.FailNow(t, "")
-			}
-			assert.Equal(t, int64(0), oldCurrentPage.StartID)
-			assert.Equal(t, int64(4), newCurrentPage.StartID)
-			noti.Old.Current = ""
-			noti.New.Current = ""
-			assert.Equal(t, oldPage, *noti.Old)
-			assert.Equal(t, newPage2, *noti.New)
-		}
-	}
-}
+type findVulnerabilityNotificationIn struct {
+	notificationName        string
+	pageSize                int
+	oldAffectedAncestryPage pagination.Token
+	newAffectedAncestryPage pagination.Token
+}
+
+type findVulnerabilityNotificationOut struct {
+	notification *database.VulnerabilityNotificationWithVulnerable
+	ok           bool
+	err          string
+}
+
+var findVulnerabilityNotificationTests = []struct {
+	title string
+	in    findVulnerabilityNotificationIn
+	out   findVulnerabilityNotificationOut
+}{
+	{
+		title: "find notification with invalid page",
+		in: findVulnerabilityNotificationIn{
+			notificationName:        "test",
+			pageSize:                1,
+			oldAffectedAncestryPage: pagination.FirstPageToken,
+			newAffectedAncestryPage: pagination.Token("random non sense"),
+		},
+		out: findVulnerabilityNotificationOut{
+			err: pagination.ErrInvalidToken.Error(),
+		},
+	},
+	{
+		title: "find non-existing notification",
+		in: findVulnerabilityNotificationIn{
+			notificationName:        "non-existing",
+			pageSize:                1,
+			oldAffectedAncestryPage: pagination.FirstPageToken,
+			newAffectedAncestryPage: pagination.FirstPageToken,
+		},
+		out: findVulnerabilityNotificationOut{
+			ok: false,
+		},
+	},
+	{
+		title: "find existing notification first page",
+		in: findVulnerabilityNotificationIn{
+			notificationName:        "test",
+			pageSize:                1,
+			oldAffectedAncestryPage: pagination.FirstPageToken,
+			newAffectedAncestryPage: pagination.FirstPageToken,
+		},
+		out: findVulnerabilityNotificationOut{
+			&database.VulnerabilityNotificationWithVulnerable{
+				NotificationHook: realNotification[1].NotificationHook,
+				Old: &database.PagedVulnerableAncestries{
+					Vulnerability: realVulnerability[2],
+					Limit:         1,
+					Affected:      make(map[int]string),
+					Current:       mustMarshalToken(testPaginationKey, Page{0}),
+					Next:          mustMarshalToken(testPaginationKey, Page{0}),
+					End:           true,
+				},
+				New: &database.PagedVulnerableAncestries{
+					Vulnerability: realVulnerability[1],
+					Limit:         1,
+					Affected:      map[int]string{3: "ancestry-3"},
+					Current:       mustMarshalToken(testPaginationKey, Page{0}),
+					Next:          mustMarshalToken(testPaginationKey, Page{4}),
+					End:           false,
+				},
+			},
+			true,
+			"",
+		},
+	},
+	{
+		title: "find existing notification of second page of new affected ancestry",
+		in: findVulnerabilityNotificationIn{
+			notificationName:        "test",
+			pageSize:                1,
+			oldAffectedAncestryPage: pagination.FirstPageToken,
+			newAffectedAncestryPage: mustMarshalToken(testPaginationKey, Page{4}),
+		},
+		out: findVulnerabilityNotificationOut{
+			&database.VulnerabilityNotificationWithVulnerable{
+				NotificationHook: realNotification[1].NotificationHook,
+				Old: &database.PagedVulnerableAncestries{
+					Vulnerability: realVulnerability[2],
+					Limit:         1,
+					Affected:      make(map[int]string),
+					Current:       mustMarshalToken(testPaginationKey, Page{0}),
+					Next:          mustMarshalToken(testPaginationKey, Page{0}),
+					End:           true,
+				},
+				New: &database.PagedVulnerableAncestries{
+					Vulnerability: realVulnerability[1],
+					Limit:         1,
+					Affected:      map[int]string{4: "ancestry-4"},
+					Current:       mustMarshalToken(testPaginationKey, Page{4}),
+					Next:          mustMarshalToken(testPaginationKey, Page{0}),
+					End:           true,
+				},
+			},
+			true,
+			"",
+		},
+	},
+}
+
+func TestFindVulnerabilityNotification(t *testing.T) {
+	datastore, tx := openSessionForTest(t, "pagination", true)
+	defer closeTest(t, datastore, tx)
+
+	for _, test := range findVulnerabilityNotificationTests {
+		t.Run(test.title, func(t *testing.T) {
+			notification, ok, err := tx.FindVulnerabilityNotification(test.in.notificationName, test.in.pageSize, test.in.oldAffectedAncestryPage, test.in.newAffectedAncestryPage)
+			if test.out.err != "" {
+				require.EqualError(t, err, test.out.err)
+				return
+			}
+
+			require.Nil(t, err)
+			if !test.out.ok {
+				require.Equal(t, test.out.ok, ok)
+				return
+			}
+
+			require.True(t, ok)
+			assertVulnerabilityNotificationWithVulnerableEqual(t, testPaginationKey, test.out.notification, &notification)
+		})
+	}
+}


@ -270,6 +270,7 @@ func migrateDatabase(db *sql.DB) error {
 // createDatabase creates a new database.
 // The source parameter should not contain a dbname.
 func createDatabase(source, dbName string) error {
+	log.WithFields(log.Fields{"source": source, "dbName": dbName}).Debug("creating database...")
 	// Open database.
 	db, err := sql.Open("postgres", source)
 	if err != nil {
@ -325,7 +326,7 @@ func handleError(desc string, err error) error {
 		return commonerr.ErrNotFound
 	}

-	log.WithError(err).WithField("Description", desc).Error("Handled Database Error")
+	log.WithError(err).WithField("Description", desc).Error("database: handled database error")
 	promErrorsTotal.WithLabelValues(desc).Inc()

 	if _, o := err.(*pq.Error); o || err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") {


@ -37,6 +37,8 @@ var (
 	withFixtureName, withoutFixtureName string
 )

+var testPaginationKey = pagination.Must(pagination.NewKey())
+
 func genTemplateDatabase(name string, loadFixture bool) (sourceURL string, dbName string) {
 	config := generateTestConfig(name, loadFixture, false)
 	source := config.Options["source"].(string)
@ -215,13 +217,15 @@ func generateTestConfig(testName string, loadFixture bool, manageLife bool) data
 		source = fmt.Sprintf(sourceEnv, dbName)
 	}

+	log.Infof("pagination key for current test: %s", testPaginationKey.String())
+
 	return database.RegistrableComponentConfig{
 		Options: map[string]interface{}{
 			"source":                  source,
 			"cachesize":               0,
 			"managedatabaselifecycle": manageLife,
 			"fixturepath":             fixturePath,
-			"paginationkey":           pagination.Must(pagination.NewKey()).String(),
+			"paginationkey":           testPaginationKey.String(),
 		},
 	}
 }
@ -247,6 +251,8 @@ func openSessionForTest(t *testing.T, name string, loadFixture bool) (*pgSQL, *p
 		t.Error(err)
 		t.FailNow()
 	}
+
+	log.Infof("transaction pagination key: '%s'", tx.(*pgSession).key.String())
+
 	return store, tx.(*pgSession)
 }

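Hoisting the key into the shared testPaginationKey is what makes expected tokens computable in fixtures: a token is opaque to anyone without the key that minted it, so the session under test and the assertions must use the same one. A small sketch of the round trip (mustMarshalToken is defined in the new testutil.go later in this commit; the fmt import is assumed):

func demoPaginationKeyRoundTrip() {
	// Marshal with the same key the test session was configured with...
	token := mustMarshalToken(testPaginationKey, Page{4})

	// ...and only that key can decode it back into a Page.
	var p Page
	if err := testPaginationKey.UnmarshalToken(token, &p); err != nil {
		panic(err)
	}
	fmt.Println(p.StartID) // prints 4
}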

@ -121,7 +121,8 @@ func queryPersistLayerFeature(count int) string {
 		"layer_feature",
 		"layer_feature_layer_id_feature_id_key",
 		"layer_id",
-		"feature_id")
+		"feature_id",
+		"detector_id")
 }

 func queryPersistNamespace(count int) string {
@ -132,28 +133,13 @@ func queryPersistNamespace(count int) string {
 		"version_format")
 }

-func queryPersistLayerListers(count int) string {
-	return queryPersist(count,
-		"layer_lister",
-		"layer_lister_layer_id_lister_key",
-		"layer_id",
-		"lister")
-}
-
-func queryPersistLayerDetectors(count int) string {
-	return queryPersist(count,
-		"layer_detector",
-		"layer_detector_layer_id_detector_key",
-		"layer_id",
-		"detector")
-}
-
 func queryPersistLayerNamespace(count int) string {
 	return queryPersist(count,
 		"layer_namespace",
 		"layer_namespace_layer_id_namespace_id_key",
 		"layer_id",
-		"namespace_id")
+		"namespace_id",
+		"detector_id")
 }

 // size of key and array should both be greater than 0

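For reference, queryPersist presumably expands the (table, constraint, columns...) arguments into a multi-row upsert guarded by the named unique constraint; under that assumption, queryPersistLayerFeature(2) would render roughly as the constant below. This is an illustration only; the real builder's output may be spelled differently.

// Hypothetical rendering of queryPersistLayerFeature(2).
const exampleLayerFeatureUpsert = `
	INSERT INTO layer_feature(layer_id, feature_id, detector_id)
	VALUES ($1, $2, $3), ($4, $5, $6)
	ON CONFLICT ON CONSTRAINT layer_feature_layer_id_feature_id_key DO NOTHING`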

@ -1,57 +1,69 @@
+-- initialize entities
 INSERT INTO namespace (id, name, version_format) VALUES
 	(1, 'debian:7', 'dpkg'),
 	(2, 'debian:8', 'dpkg'),
 	(3, 'fake:1.0', 'rpm');

 INSERT INTO feature (id, name, version, version_format) VALUES
-	(1, 'wechat', '0.5', 'dpkg'),
+	(1, 'ourchat', '0.5', 'dpkg'),
 	(2, 'openssl', '1.0', 'dpkg'),
 	(3, 'openssl', '2.0', 'dpkg'),
 	(4, 'fake', '2.0', 'rpm');

+INSERT INTO namespaced_feature(id, feature_id, namespace_id) VALUES
+	(1, 1, 1), -- ourchat 0.5, debian:7
+	(2, 2, 1), -- openssl 1.0, debian:7
+	(3, 2, 2), -- openssl 1.0, debian:8
+	(4, 3, 1); -- openssl 2.0, debian:7
+
+INSERT INTO detector(id, name, version, dtype) VALUES
+	(1, 'os-release', '1.0', 'namespace'),
+	(2, 'dpkg', '1.0', 'feature'),
+	(3, 'rpm', '1.0', 'feature'),
+	(4, 'apt-sources', '1.0', 'namespace');
+
+-- initialize layers
 INSERT INTO layer (id, hash) VALUES
 	(1, 'layer-0'), -- blank
-	(2, 'layer-1'), -- debian:7; wechat 0.5, openssl 1.0
-	(3, 'layer-2'), -- debian:7; wechat 0.5, openssl 2.0
+	(2, 'layer-1'), -- debian:7; ourchat 0.5, openssl 1.0
+	(3, 'layer-2'), -- debian:7; ourchat 0.5, openssl 2.0
 	(4, 'layer-3a'),-- debian:7;
-	(5, 'layer-3b'),-- debian:8; wechat 0.5, openssl 1.0
+	(5, 'layer-3b'),-- debian:8; ourchat 0.5, openssl 1.0
 	(6, 'layer-4'); -- debian:7, fake:1.0; openssl 2.0 (debian), fake 2.0 (fake)

-INSERT INTO layer_namespace(id, layer_id, namespace_id) VALUES
-	(1, 2, 1),
-	(2, 3, 1),
-	(3, 4, 1),
-	(4, 5, 2),
-	(5, 6, 1),
-	(6, 6, 3);
+INSERT INTO layer_namespace(id, layer_id, namespace_id, detector_id) VALUES
+	(1, 2, 1, 1), -- layer-1: debian:7
+	(2, 3, 1, 1), -- layer-2: debian:7
+	(3, 4, 1, 1), -- layer-3a: debian:7
+	(4, 5, 2, 1), -- layer-3b: debian:8
+	(5, 6, 1, 1), -- layer-4: debian:7
+	(6, 6, 3, 4); -- layer-4: fake:1.0

-INSERT INTO layer_feature(id, layer_id, feature_id) VALUES
-	(1, 2, 1),
-	(2, 2, 2),
-	(3, 3, 1),
-	(4, 3, 3),
-	(5, 5, 1),
-	(6, 5, 2),
-	(7, 6, 4),
-	(8, 6, 3);
+INSERT INTO layer_feature(id, layer_id, feature_id, detector_id) VALUES
+	(1, 2, 1, 2), -- layer-1: ourchat 0.5
+	(2, 2, 2, 2), -- layer-1: openssl 1.0
+	(3, 3, 1, 2), -- layer-2: ourchat 0.5
+	(4, 3, 3, 2), -- layer-2: openssl 2.0
+	(5, 5, 1, 2), -- layer-3b: ourchat 0.5
+	(6, 5, 2, 2), -- layer-3b: openssl 1.0
+	(7, 6, 4, 3), -- layer-4: fake 2.0
+	(8, 6, 3, 2); -- layer-4: openssl 2.0

-INSERT INTO layer_lister(id, layer_id, lister) VALUES
-	(1, 1, 'dpkg'),
-	(2, 2, 'dpkg'),
-	(3, 3, 'dpkg'),
-	(4, 4, 'dpkg'),
-	(5, 5, 'dpkg'),
-	(6, 6, 'dpkg'),
-	(7, 6, 'rpm');
-
-INSERT INTO layer_detector(id, layer_id, detector) VALUES
-	(1, 1, 'os-release'),
-	(2, 2, 'os-release'),
-	(3, 3, 'os-release'),
-	(4, 4, 'os-release'),
-	(5, 5, 'os-release'),
-	(6, 6, 'os-release'),
-	(7, 6, 'apt-sources');
+INSERT INTO layer_detector(layer_id, detector_id) VALUES
+	(1, 1),
+	(2, 1),
+	(3, 1),
+	(4, 1),
+	(5, 1),
+	(6, 1),
+	(6, 4),
+	(1, 2),
+	(2, 2),
+	(3, 2),
+	(4, 2),
+	(5, 2),
+	(6, 2),
+	(6, 3);

 INSERT INTO ancestry (id, name) VALUES
 	(1, 'ancestry-1'), -- layer-0, layer-1, layer-2, layer-3a
@ -59,32 +71,39 @@ INSERT INTO ancestry (id, name) VALUES
 	(3, 'ancestry-3'), -- layer-0
 	(4, 'ancestry-4'); -- layer-0

-INSERT INTO ancestry_lister (id, ancestry_id, lister) VALUES
-	(1, 1, 'dpkg'),
-	(2, 2, 'dpkg');
-
-INSERT INTO ancestry_detector (id, ancestry_id, detector) VALUES
-	(1, 1, 'os-release'),
-	(2, 2, 'os-release');
+INSERT INTO ancestry_detector (ancestry_id, detector_id) VALUES
+	(1, 2),
+	(2, 2),
+	(1, 1),
+	(2, 1);

 INSERT INTO ancestry_layer (id, ancestry_id, layer_id, ancestry_index) VALUES
+	-- ancestry-1: layer-0, layer-1, layer-2, layer-3a
 	(1, 1, 1, 0),(2, 1, 2, 1),(3, 1, 3, 2),(4, 1, 4, 3),
+	-- ancestry-2: layer-0, layer-1, layer-2, layer-3b
 	(5, 2, 1, 0),(6, 2, 2, 1),(7, 2, 3, 2),(8, 2, 5, 3),
-	(9, 3, 1, 0),
-	(10, 4, 1, 0);
-
-INSERT INTO namespaced_feature(id, feature_id, namespace_id) VALUES
-	(1, 1, 1), -- wechat 0.5, debian:7
-	(2, 2, 1), -- openssl 1.0, debian:7
-	(3, 2, 2), -- openssl 1.0, debian:8
-	(4, 3, 1); -- openssl 2.0, debian:7
+	-- ancestry-3: layer-1
+	(9, 3, 2, 0),
+	-- ancestry-4: layer-1
+	(10, 4, 2, 0);

 -- assume that ancestry-3 and ancestry-4 are vulnerable.
-INSERT INTO ancestry_feature (id, ancestry_layer_id, namespaced_feature_id) VALUES
-	(1, 1, 1), (2, 1, 4), -- ancestry-1, layer 0 introduces 1, 4
-	(3, 5, 1), (4, 5, 3), -- ancestry-2, layer 0 introduces 1, 3
-	(5, 9, 2), -- ancestry-3, layer 0 introduces 2
-	(6, 10, 2); -- ancestry-4, layer 0 introduces 2
+INSERT INTO ancestry_feature (id, ancestry_layer_id, namespaced_feature_id, feature_detector_id, namespace_detector_id) VALUES
+	-- ancestry-1:
+	-- layer-2: ourchat 0.5 <- detected by dpkg 1.0 (2); debian:7 <- detected by os-release 1.0 (1)
+	-- layer-2: openssl 2.0, debian:7
+	(1, 3, 1, 2, 1), (2, 3, 4, 2, 1),
+	-- ancestry-2:
+	-- 1(ourchat 0.5; debian:7 layer-2)
+	-- 3(openssl 1.0; debian:8 layer-3b)
+	(3, 7, 1, 2, 1), (4, 8, 3, 2, 1),
+	-- ancestry-3:
+	-- 2(openssl 1.0, debian:7 layer-1)
+	-- 1(ourchat 0.5, debian:7 layer-1)
+	(5, 9, 2, 2, 1), (6, 9, 1, 2, 1), -- vulnerable
+	-- ancestry-4:
+	-- same as ancestry-3
+	(7, 10, 2, 2, 1), (8, 10, 1, 2, 1); -- vulnerable

 INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES
 	(1, 1, 'CVE-OPENSSL-1-DEB7', 'A vulnerability affecting OpenSSL < 2.0 on Debian 7.0', 'http://google.com/#q=CVE-OPENSSL-1-DEB7', 'High'),
@ -103,19 +122,23 @@ INSERT INTO vulnerability_affected_namespaced_feature(id, vulnerability_id, name
 INSERT INTO vulnerability_notification(id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id) VALUES
 	(1, 'test', NULL, NULL, NULL, 2, 1); -- 'CVE-NOPE' -> 'CVE-OPENSSL-1-DEB7'

+SELECT pg_catalog.setval(pg_get_serial_sequence('feature', 'id'), (SELECT MAX(id) FROM feature)+1);
 SELECT pg_catalog.setval(pg_get_serial_sequence('namespace', 'id'), (SELECT MAX(id) FROM namespace)+1);
+SELECT pg_catalog.setval(pg_get_serial_sequence('namespaced_feature', 'id'), (SELECT MAX(id) FROM namespaced_feature)+1);
+SELECT pg_catalog.setval(pg_get_serial_sequence('detector', 'id'), (SELECT MAX(id) FROM detector)+1);
 SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry', 'id'), (SELECT MAX(id) FROM ancestry)+1);
 SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_layer', 'id'), (SELECT MAX(id) FROM ancestry_layer)+1);
 SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_feature', 'id'), (SELECT MAX(id) FROM ancestry_feature)+1);
 SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_detector', 'id'), (SELECT MAX(id) FROM ancestry_detector)+1);
-SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_lister', 'id'), (SELECT MAX(id) FROM ancestry_lister)+1);
-SELECT pg_catalog.setval(pg_get_serial_sequence('feature', 'id'), (SELECT MAX(id) FROM feature)+1);
-SELECT pg_catalog.setval(pg_get_serial_sequence('namespaced_feature', 'id'), (SELECT MAX(id) FROM namespaced_feature)+1);
 SELECT pg_catalog.setval(pg_get_serial_sequence('layer', 'id'), (SELECT MAX(id) FROM layer)+1);
-SELECT pg_catalog.setval(pg_get_serial_sequence('layer_feature', 'id'), (SELECT MAX(id) FROM layer_feature)+1);
 SELECT pg_catalog.setval(pg_get_serial_sequence('layer_namespace', 'id'), (SELECT MAX(id) FROM layer_namespace)+1);
 SELECT pg_catalog.setval(pg_get_serial_sequence('layer_detector', 'id'), (SELECT MAX(id) FROM layer_detector)+1);
-SELECT pg_catalog.setval(pg_get_serial_sequence('layer_lister', 'id'), (SELECT MAX(id) FROM layer_lister)+1);
 SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability', 'id'), (SELECT MAX(id) FROM vulnerability)+1);
 SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_feature)+1);
 SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_namespaced_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_namespaced_feature)+1);
 SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_notification', 'id'), (SELECT MAX(id) FROM vulnerability_notification)+1);

database/pgsql/testutil.go (new file, 263 lines)

@ -0,0 +1,263 @@
// Copyright 2018 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/pagination"
"github.com/coreos/clair/pkg/testutil"
)
// int keys must be consistent with the database IDs.
var (
realFeatures = map[int]database.Feature{
1: {"ourchat", "0.5", "dpkg"},
2: {"openssl", "1.0", "dpkg"},
3: {"openssl", "2.0", "dpkg"},
4: {"fake", "2.0", "rpm"},
}
realNamespaces = map[int]database.Namespace{
1: {"debian:7", "dpkg"},
2: {"debian:8", "dpkg"},
3: {"fake:1.0", "rpm"},
}
realNamespacedFeatures = map[int]database.NamespacedFeature{
1: {realFeatures[1], realNamespaces[1]},
2: {realFeatures[2], realNamespaces[1]},
3: {realFeatures[2], realNamespaces[2]},
4: {realFeatures[3], realNamespaces[1]},
}
realDetectors = map[int]database.Detector{
1: database.NewNamespaceDetector("os-release", "1.0"),
2: database.NewFeatureDetector("dpkg", "1.0"),
3: database.NewFeatureDetector("rpm", "1.0"),
4: database.NewNamespaceDetector("apt-sources", "1.0"),
}
realLayers = map[int]database.Layer{
2: {
Hash: "layer-1",
By: []database.Detector{realDetectors[1], realDetectors[2]},
Features: []database.LayerFeature{
{realFeatures[1], realDetectors[2]},
{realFeatures[2], realDetectors[2]},
},
Namespaces: []database.LayerNamespace{
{realNamespaces[1], realDetectors[1]},
},
},
6: {
Hash: "layer-4",
By: []database.Detector{realDetectors[1], realDetectors[2], realDetectors[3], realDetectors[4]},
Features: []database.LayerFeature{
{realFeatures[4], realDetectors[3]},
{realFeatures[3], realDetectors[2]},
},
Namespaces: []database.LayerNamespace{
{realNamespaces[1], realDetectors[1]},
{realNamespaces[3], realDetectors[4]},
},
},
}
realAncestries = map[int]database.Ancestry{
2: {
Name: "ancestry-2",
By: []database.Detector{realDetectors[2], realDetectors[1]},
Layers: []database.AncestryLayer{
{
"layer-0",
[]database.AncestryFeature{},
},
{
"layer-1",
[]database.AncestryFeature{},
},
{
"layer-2",
[]database.AncestryFeature{
{
realNamespacedFeatures[1],
realDetectors[2],
realDetectors[1],
},
},
},
{
"layer-3b",
[]database.AncestryFeature{
{
realNamespacedFeatures[3],
realDetectors[2],
realDetectors[1],
},
},
},
},
},
}
realVulnerability = map[int]database.Vulnerability{
1: {
Name: "CVE-OPENSSL-1-DEB7",
Namespace: realNamespaces[1],
Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7",
Severity: database.HighSeverity,
},
2: {
Name: "CVE-NOPE",
Namespace: realNamespaces[1],
Description: "A vulnerability affecting nothing",
Severity: database.UnknownSeverity,
},
}
realNotification = map[int]database.VulnerabilityNotification{
1: {
NotificationHook: database.NotificationHook{
Name: "test",
},
Old: takeVulnerabilityPointerFromMap(realVulnerability, 2),
New: takeVulnerabilityPointerFromMap(realVulnerability, 1),
},
}
fakeFeatures = map[int]database.Feature{
1: {
Name: "ourchat",
Version: "0.6",
VersionFormat: "dpkg",
},
}
fakeNamespaces = map[int]database.Namespace{
1: {"green hat", "rpm"},
}
fakeNamespacedFeatures = map[int]database.NamespacedFeature{
1: {
Feature: fakeFeatures[1],
Namespace: realNamespaces[1],
},
}
fakeDetector = map[int]database.Detector{
1: {
Name: "fake",
Version: "1.0",
DType: database.FeatureDetectorType,
},
2: {
Name: "fake2",
Version: "2.0",
DType: database.NamespaceDetectorType,
},
}
)
func takeVulnerabilityPointerFromMap(m map[int]database.Vulnerability, id int) *database.Vulnerability {
x := m[id]
return &x
}
func takeAncestryPointerFromMap(m map[int]database.Ancestry, id int) *database.Ancestry {
x := m[id]
return &x
}
func takeLayerPointerFromMap(m map[int]database.Layer, id int) *database.Layer {
x := m[id]
return &x
}
func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace {
rows, err := tx.Query("SELECT name, version_format FROM namespace")
if err != nil {
t.FailNow()
}
defer rows.Close()
namespaces := []database.Namespace{}
for rows.Next() {
var ns database.Namespace
err := rows.Scan(&ns.Name, &ns.VersionFormat)
if err != nil {
t.FailNow()
}
namespaces = append(namespaces, ns)
}
return namespaces
}
func assertVulnerabilityNotificationWithVulnerableEqual(t *testing.T, key pagination.Key, expected, actual *database.VulnerabilityNotificationWithVulnerable) bool {
if expected == actual {
return true
}
if expected == nil || actual == nil {
return assert.Equal(t, expected, actual)
}
return assert.Equal(t, expected.NotificationHook, actual.NotificationHook) &&
AssertPagedVulnerableAncestriesEqual(t, key, expected.Old, actual.Old) &&
AssertPagedVulnerableAncestriesEqual(t, key, expected.New, actual.New)
}
func AssertPagedVulnerableAncestriesEqual(t *testing.T, key pagination.Key, expected, actual *database.PagedVulnerableAncestries) bool {
if expected == actual {
return true
}
if expected == nil || actual == nil {
return assert.Equal(t, expected, actual)
}
return testutil.AssertVulnerabilityEqual(t, &expected.Vulnerability, &actual.Vulnerability) &&
assert.Equal(t, expected.Limit, actual.Limit) &&
assert.Equal(t, mustUnmarshalToken(key, expected.Current), mustUnmarshalToken(key, actual.Current)) &&
assert.Equal(t, mustUnmarshalToken(key, expected.Next), mustUnmarshalToken(key, actual.Next)) &&
assert.Equal(t, expected.End, actual.End) &&
testutil.AssertIntStringMapEqual(t, expected.Affected, actual.Affected)
}
func mustUnmarshalToken(key pagination.Key, token pagination.Token) Page {
if token == pagination.FirstPageToken {
return Page{}
}
p := Page{}
if err := key.UnmarshalToken(token, &p); err != nil {
panic(err)
}
return p
}
func mustMarshalToken(key pagination.Key, v interface{}) pagination.Token {
token, err := key.MarshalToken(v)
if err != nil {
panic(err)
}
return token
}
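Usage note: because the fixture maps above are keyed by the database ids in data.sql, table-driven tests can reference expected rows by id, as findLayerTests does earlier in this commit:

// 'layer-4' has id 6 in data.sql, so its expected Layer is realLayers[6].
var exampleFindLayerCase = struct {
	in  string
	out *database.Layer
}{
	in:  "layer-4",
	out: takeLayerPointerFromMap(realLayers, 6),
}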


@ -306,14 +306,14 @@ func TestFindVulnerabilityIDs(t *testing.T) {
 	ids, err := tx.findLatestDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}})
 	if assert.Nil(t, err) {
-		if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 3, ids[0].Int64)) {
+		if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 3, int(ids[0].Int64))) {
 			assert.Fail(t, "")
 		}
 	}

 	ids, err = tx.findNotDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}})
 	if assert.Nil(t, err) {
-		if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 2, ids[0].Int64)) {
+		if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 2, int(ids[0].Int64))) {
 			assert.Fail(t, "")
 		}
 	}

pkg/testutil/testutil.go (new file, 285 lines)

@ -0,0 +1,285 @@
package testutil
import (
"encoding/json"
"sort"
"testing"
"github.com/deckarep/golang-set"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database"
)
// AssertDetectorsEqual asserts that the actual detectors are content-wise
// equal to the expected detectors, regardless of ordering.
func AssertDetectorsEqual(t *testing.T, expected, actual []database.Detector) bool {
if len(expected) != len(actual) {
return assert.Fail(t, "detectors are not equal", "expected: '%v', actual: '%v'", expected, actual)
}
sort.Slice(expected, func(i, j int) bool {
return expected[i].String() < expected[j].String()
})
sort.Slice(actual, func(i, j int) bool {
return actual[i].String() < actual[j].String()
})
for i := range expected {
if expected[i] != actual[i] {
return assert.Fail(t, "detectors are not equal", "expected: '%v', actual: '%v'", expected, actual)
}
}
return true
}
// AssertAncestryEqual asserts that the actual ancestry equals the expected
// ancestry, content-wise.
func AssertAncestryEqual(t *testing.T, expected, actual *database.Ancestry) bool {
if expected == actual {
return true
}
if actual == nil || expected == nil {
return assert.Equal(t, expected, actual)
}
if !assert.Equal(t, expected.Name, actual.Name) || !AssertDetectorsEqual(t, expected.By, actual.By) {
return false
}
if assert.Equal(t, len(expected.Layers), len(actual.Layers)) {
for index := range expected.Layers {
if !AssertAncestryLayerEqual(t, &expected.Layers[index], &actual.Layers[index]) {
return false
}
}
return true
}
return false
}
// AssertAncestryLayerEqual asserts that the actual ancestry layer equals the
// expected ancestry layer, content-wise.
func AssertAncestryLayerEqual(t *testing.T, expected, actual *database.AncestryLayer) bool {
if !assert.Equal(t, expected.Hash, actual.Hash) {
return false
}
if !assert.Equal(t, len(expected.Features), len(actual.Features),
"layer: %s\nExpected: %v\n Actual: %v",
expected.Hash, expected.Features, actual.Features,
) {
return false
}
// feature -> is in actual layer
hitCounter := map[database.AncestryFeature]bool{}
for _, f := range expected.Features {
hitCounter[f] = false
}
// Since expected and actual have the same length, if actual contains no
// extra and no duplicated features, the two must be equal.
for _, f := range actual.Features {
v, ok := hitCounter[f]
assert.True(t, ok, "unexpected feature %s", f)
assert.False(t, v, "duplicated feature %s", f)
hitCounter[f] = true
}
for f, visited := range hitCounter {
assert.True(t, visited, "missing feature %s", f)
}
return true
}
// AssertElementsEqual asserts that the content of actual equals the content
// of expected, regardless of ordering.
//
// Note: this function compares elements as interface values.
func AssertElementsEqual(t *testing.T, expected, actual []interface{}) bool {
counter := map[interface{}]bool{}
for _, f := range expected {
counter[f] = false
}
for _, f := range actual {
v, ok := counter[f]
if !assert.True(t, ok, "unexpected element %v\nExpected: %v\n Actual: %v\n", f, expected, actual) {
return false
}
if !assert.False(t, v, "duplicated element %v\nExpected: %v\n Actual: %v\n", f, expected, actual) {
return false
}
counter[f] = true
}
for f, visited := range counter {
if !assert.True(t, visited, "missing feature %v\nExpected: %v\n Actual: %v\n", f, expected, actual) {
return false
}
}
return true
}
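A quick illustration of the semantics with hypothetical calls (not part of the commit): comparison is order-insensitive but unexpected elements still fail.

func demoAssertElementsEqual(t *testing.T) {
	AssertElementsEqual(t, []interface{}{1, 2}, []interface{}{2, 1}) // passes: same elements, different order
	AssertElementsEqual(t, []interface{}{1, 1}, []interface{}{1, 2}) // fails: 2 was never expected
}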
// AssertFeaturesEqual asserts content in actual equals content in expected
// regardless of ordering.
func AssertFeaturesEqual(t *testing.T, expected, actual []database.Feature) bool {
if assert.Len(t, actual, len(expected)) {
has := map[database.Feature]bool{}
for _, nf := range expected {
has[nf] = false
}
for _, nf := range actual {
has[nf] = true
}
for nf, visited := range has {
if !assert.True(t, visited, nf.Name+" is expected") {
return false
}
}
return true
}
return false
}
// AssertLayerFeaturesEqual asserts that content in actual equals content in
// expected, regardless of ordering.
func AssertLayerFeaturesEqual(t *testing.T, expected, actual []database.LayerFeature) bool {
if !assert.Len(t, actual, len(expected)) {
return false
}
expectedInterfaces := []interface{}{}
for _, e := range expected {
expectedInterfaces = append(expectedInterfaces, e)
}
actualInterfaces := []interface{}{}
for _, a := range actual {
actualInterfaces = append(actualInterfaces, a)
}
return AssertElementsEqual(t, expectedInterfaces, actualInterfaces)
}
// AssertNamespacesEqual asserts that content in actual equals content in
// expected, regardless of ordering.
func AssertNamespacesEqual(t *testing.T, expected, actual []database.Namespace) bool {
expectedInterfaces := []interface{}{}
for _, e := range expected {
expectedInterfaces = append(expectedInterfaces, e)
}
actualInterfaces := []interface{}{}
for _, a := range actual {
actualInterfaces = append(actualInterfaces, a)
}
return AssertElementsEqual(t, expectedInterfaces, actualInterfaces)
}
// AssertLayerNamespacesEqual asserts that content in actual equals content in
// expected, regardless of ordering.
func AssertLayerNamespacesEqual(t *testing.T, expected, actual []database.LayerNamespace) bool {
expectedInterfaces := []interface{}{}
for _, e := range expected {
expectedInterfaces = append(expectedInterfaces, e)
}
actualInterfaces := []interface{}{}
for _, a := range actual {
actualInterfaces = append(actualInterfaces, a)
}
return AssertElementsEqual(t, expectedInterfaces, actualInterfaces)
}
// AssertLayerEqual asserts that the actual layer equals the expected layer, content-wise.
func AssertLayerEqual(t *testing.T, expected, actual *database.Layer) bool {
if expected == actual {
return true
}
if expected == nil || actual == nil {
return assert.Equal(t, expected, actual)
}
return assert.Equal(t, expected.Hash, actual.Hash) &&
AssertDetectorsEqual(t, expected.By, actual.By) &&
AssertLayerFeaturesEqual(t, expected.Features, actual.Features) &&
AssertLayerNamespacesEqual(t, expected.Namespaces, actual.Namespaces)
}
// AssertIntStringMapEqual asserts two maps with integer as key and string as
// value are equal.
func AssertIntStringMapEqual(t *testing.T, expected, actual map[int]string) bool {
checked := mapset.NewSet()
for k, v := range expected {
assert.Equal(t, v, actual[k])
checked.Add(k)
}
for k := range actual {
if !assert.True(t, checked.Contains(k)) {
return false
}
}
return true
}
// AssertVulnerabilityEqual asserts two vulnerabilities are equal.
func AssertVulnerabilityEqual(t *testing.T, expected, actual *database.Vulnerability) bool {
return assert.Equal(t, expected.Name, actual.Name) &&
assert.Equal(t, expected.Link, actual.Link) &&
assert.Equal(t, expected.Description, actual.Description) &&
assert.Equal(t, expected.Namespace, actual.Namespace) &&
assert.Equal(t, expected.Severity, actual.Severity) &&
AssertMetadataMapEqual(t, expected.Metadata, actual.Metadata)
}
func castMetadataMapToInterface(metadata database.MetadataMap) map[string]interface{} {
content, err := json.Marshal(metadata)
if err != nil {
panic(err)
}
data := make(map[string]interface{})
if err := json.Unmarshal(content, &data); err != nil {
panic(err)
}
return data
}
// AssertMetadataMapEqual asserts two metadata maps are equal.
func AssertMetadataMapEqual(t *testing.T, expected, actual database.MetadataMap) bool {
expectedMap := castMetadataMapToInterface(expected)
actualMap := castMetadataMapToInterface(actual)
checked := mapset.NewSet()
for k, v := range expectedMap {
if !assert.Equal(t, v, (actualMap)[k]) {
return false
}
checked.Add(k)
}
for k := range actual {
if !assert.True(t, checked.Contains(k)) {
return false
}
}
return true
}
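The JSON round-trip in castMetadataMapToInterface is what makes this comparison robust: values read back from Postgres arrive as generic JSON types (all numbers decode as float64), so the expected map is pushed through the same normalization first. A hypothetical illustration, assuming MetadataMap is a map[string]interface{} and fmt is imported:

func demoMetadataNormalization() {
	m := castMetadataMapToInterface(database.MetadataMap{
		"NVD": map[string]interface{}{"CVSSv2": 5}, // an int before the round-trip
	})
	nvd := m["NVD"].(map[string]interface{})
	fmt.Printf("%T\n", nvd["CVSSv2"]) // float64, matching a value scanned from the database
}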