* Refactor layer and ancestry * Add tests * Fix bugs introduced when the queries were movedmaster
parent
028324014b
commit
0c1b80b2ed
@ -0,0 +1,198 @@
|
||||
// Copyright 2018 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/deckarep/golang-set"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
)
|
||||
|
||||
const (
	// soiDetector ("select or insert") inserts a detector row only when no
	// identical (name, version, dtype) row already exists.
	soiDetector = `
	INSERT INTO detector (name, version, dtype)
	SELECT CAST ($1 AS TEXT), CAST ($2 AS TEXT), CAST ($3 AS detector_type )
	WHERE NOT EXISTS (SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3);`

	// selectAncestryDetectors fetches all detectors used on an ancestry ($1 = ancestry id).
	selectAncestryDetectors = `
	SELECT d.name, d.version, d.dtype
	FROM ancestry_detector, detector AS d
	WHERE ancestry_detector.detector_id = d.id AND ancestry_detector.ancestry_id = $1;`

	// selectLayerDetectors fetches all detectors used on a layer ($1 = layer id).
	selectLayerDetectors = `
	SELECT d.name, d.version, d.dtype
	FROM layer_detector, detector AS d
	WHERE layer_detector.detector_id = d.id AND layer_detector.layer_id = $1;`

	// insertAncestryDetectors links a detector to an ancestry unless the link exists.
	insertAncestryDetectors = `
	INSERT INTO ancestry_detector (ancestry_id, detector_id)
	SELECT $1, $2
	WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector_id = $2)`

	// persistLayerDetector links a detector to a layer unless the link exists.
	persistLayerDetector = `
	INSERT INTO layer_detector (layer_id, detector_id)
	SELECT $1, $2
	WHERE NOT EXISTS (SELECT id FROM layer_detector WHERE layer_id = $1 AND detector_id = $2)`

	// findDetectorID resolves a detector value to its database id.
	findDetectorID = `SELECT id FROM detector WHERE name = $1 AND version = $2 AND dtype = $3`
	// findAllDetectors lists every detector with its id.
	findAllDetectors = `SELECT id, name, version, dtype FROM detector`
)
|
||||
|
||||
// detectorMap is a bidirectional in-memory index between a detector's
// database ID and its value.
type detectorMap struct {
	// byID maps a detector's database ID to the detector value.
	byID map[int64]database.Detector
	// byValue maps a detector value back to its database ID.
	byValue map[database.Detector]int64
}
|
||||
|
||||
func (tx *pgSession) PersistDetectors(detectors []database.Detector) error {
|
||||
for _, d := range detectors {
|
||||
if !d.Valid() {
|
||||
log.WithField("detector", d).Debug("Invalid Detector")
|
||||
return database.ErrInvalidParameters
|
||||
}
|
||||
|
||||
r, err := tx.Exec(soiDetector, d.Name, d.Version, d.DType)
|
||||
if err != nil {
|
||||
return handleError("soiDetector", err)
|
||||
}
|
||||
|
||||
count, err := r.RowsAffected()
|
||||
if err != nil {
|
||||
return handleError("soiDetector", err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
log.Debug("detector already exists: ", d)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) persistLayerDetector(layerID int64, detectorID int64) error {
|
||||
if _, err := tx.Exec(persistLayerDetector, layerID, detectorID); err != nil {
|
||||
return handleError("persistLayerDetector", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) persistLayerDetectors(layerID int64, detectorIDs []int64) error {
|
||||
alreadySaved := mapset.NewSet()
|
||||
for _, id := range detectorIDs {
|
||||
if alreadySaved.Contains(id) {
|
||||
continue
|
||||
}
|
||||
|
||||
alreadySaved.Add(id)
|
||||
if err := tx.persistLayerDetector(layerID, id); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) insertAncestryDetectors(ancestryID int64, detectorIDs []int64) error {
|
||||
for _, detectorID := range detectorIDs {
|
||||
if _, err := tx.Exec(insertAncestryDetectors, ancestryID, detectorID); err != nil {
|
||||
return handleError("insertAncestryDetectors", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findAncestryDetectors(id int64) ([]database.Detector, error) {
|
||||
detectors, err := tx.getDetectors(selectAncestryDetectors, id)
|
||||
log.WithField("detectors", detectors).Debug("found ancestry detectors")
|
||||
return detectors, err
|
||||
}
|
||||
|
||||
func (tx *pgSession) findLayerDetectors(id int64) ([]database.Detector, error) {
|
||||
detectors, err := tx.getDetectors(selectLayerDetectors, id)
|
||||
log.WithField("detectors", detectors).Debug("found layer detectors")
|
||||
return detectors, err
|
||||
}
|
||||
|
||||
// findDetectorIDs retrieve ids of the detectors from the database, if any is not
|
||||
// found, return the error.
|
||||
func (tx *pgSession) findDetectorIDs(detectors []database.Detector) ([]int64, error) {
|
||||
ids := []int64{}
|
||||
for _, d := range detectors {
|
||||
id := sql.NullInt64{}
|
||||
err := tx.QueryRow(findDetectorID, d.Name, d.Version, d.DType).Scan(&id)
|
||||
if err != nil {
|
||||
return nil, handleError("findDetectorID", err)
|
||||
}
|
||||
|
||||
if !id.Valid {
|
||||
return nil, database.ErrInconsistent
|
||||
}
|
||||
|
||||
ids = append(ids, id.Int64)
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) getDetectors(query string, id int64) ([]database.Detector, error) {
|
||||
rows, err := tx.Query(query, id)
|
||||
if err != nil {
|
||||
return nil, handleError("getDetectors", err)
|
||||
}
|
||||
|
||||
detectors := []database.Detector{}
|
||||
for rows.Next() {
|
||||
d := database.Detector{}
|
||||
err := rows.Scan(&d.Name, &d.Version, &d.DType)
|
||||
if err != nil {
|
||||
return nil, handleError("getDetectors", err)
|
||||
}
|
||||
|
||||
if !d.Valid() {
|
||||
return nil, database.ErrInvalidDetector
|
||||
}
|
||||
|
||||
detectors = append(detectors, d)
|
||||
}
|
||||
|
||||
return detectors, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findAllDetectors() (detectorMap, error) {
|
||||
rows, err := tx.Query(findAllDetectors)
|
||||
if err != nil {
|
||||
return detectorMap{}, handleError("searchAllDetectors", err)
|
||||
}
|
||||
|
||||
detectors := detectorMap{byID: make(map[int64]database.Detector), byValue: make(map[database.Detector]int64)}
|
||||
for rows.Next() {
|
||||
var (
|
||||
id int64
|
||||
d database.Detector
|
||||
)
|
||||
if err := rows.Scan(&id, &d.Name, &d.Version, &d.DType); err != nil {
|
||||
return detectorMap{}, handleError("searchAllDetectors", err)
|
||||
}
|
||||
|
||||
detectors.byID[id] = d
|
||||
detectors.byValue[d] = id
|
||||
}
|
||||
|
||||
return detectors, nil
|
||||
}
|
@ -0,0 +1,119 @@
|
||||
// Copyright 2018 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/deckarep/golang-set"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
)
|
||||
|
||||
func testGetAllDetectors(tx *pgSession) []database.Detector {
|
||||
query := `SELECT name, version, dtype FROM detector`
|
||||
rows, err := tx.Query(query)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
detectors := []database.Detector{}
|
||||
for rows.Next() {
|
||||
d := database.Detector{}
|
||||
if err := rows.Scan(&d.Name, &d.Version, &d.DType); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
detectors = append(detectors, d)
|
||||
}
|
||||
|
||||
return detectors
|
||||
}
|
||||
|
||||
// persistDetectorTests drives TestPersistDetector: each case lists the
// detectors to persist and, when non-empty, the exact error message expected
// from PersistDetectors.
var persistDetectorTests = []struct {
	title string
	in    []database.Detector
	err   string
}{
	{
		title: "invalid detector",
		in: []database.Detector{
			{}, // zero-value detector fails Valid()
			database.NewFeatureDetector("name", "2.0"),
		},
		err: database.ErrInvalidParameters.Error(),
	},
	{
		title: "invalid detector 2",
		in: []database.Detector{
			database.NewFeatureDetector("name", "2.0"),
			{"name", "1.0", "random not valid dtype"},
		},
		err: database.ErrInvalidParameters.Error(),
	},
	{
		title: "detectors with some different fields",
		in: []database.Detector{
			database.NewFeatureDetector("name", "2.0"),
			database.NewFeatureDetector("name", "1.0"),
			database.NewNamespaceDetector("name", "1.0"),
		},
	},
	{
		// Same detector passed twice in a single call.
		title: "duplicated detectors (parameter level)",
		in: []database.Detector{
			database.NewFeatureDetector("name", "1.0"),
			database.NewFeatureDetector("name", "1.0"),
		},
	},
	{
		// Detectors that presumably already exist in the seeded database —
		// TODO confirm against the test fixture.
		title: "duplicated detectors (db level)",
		in: []database.Detector{
			database.NewNamespaceDetector("os-release", "1.0"),
			database.NewNamespaceDetector("os-release", "1.0"),
			database.NewFeatureDetector("dpkg", "1.0"),
		},
	},
}
|
||||
|
||||
// TestPersistDetector runs the persistDetectorTests table against a real
// database session, verifying both the error cases and that persisted
// detectors end up stored exactly once.
//
// NOTE: all subtests share the single transaction opened here, so detectors
// inserted by earlier cases remain visible to later ones.
func TestPersistDetector(t *testing.T) {
	datastore, tx := openSessionForTest(t, "PersistDetector", true)
	defer closeTest(t, datastore, tx)

	for _, test := range persistDetectorTests {
		t.Run(test.title, func(t *testing.T) {
			err := tx.PersistDetectors(test.in)
			if test.err != "" {
				require.EqualError(t, err, test.err)
				return
			}

			detectors := testGetAllDetectors(tx)

			// ensure no duplicated detectors
			detectorSet := mapset.NewSet()
			for _, d := range detectors {
				require.False(t, detectorSet.Contains(d), "duplicated: %v", d)
				detectorSet.Add(d)
			}

			// ensure all persisted detectors are actually saved
			for _, d := range test.in {
				require.True(t, detectorSet.Contains(d), "detector: %v, detectors: %v", d, detectorSet)
			}
		})
	}
}
|
@ -0,0 +1,263 @@
|
||||
// Copyright 2018 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/pkg/pagination"
|
||||
"github.com/coreos/clair/pkg/testutil"
|
||||
)
|
||||
|
||||
// int keys must be consistent with the database IDs.
var (
	// Features present in the seeded test database ("real" fixtures).
	realFeatures = map[int]database.Feature{
		1: {"ourchat", "0.5", "dpkg"},
		2: {"openssl", "1.0", "dpkg"},
		3: {"openssl", "2.0", "dpkg"},
		4: {"fake", "2.0", "rpm"},
	}

	realNamespaces = map[int]database.Namespace{
		1: {"debian:7", "dpkg"},
		2: {"debian:8", "dpkg"},
		3: {"fake:1.0", "rpm"},
	}

	// Combinations of the features and namespaces above.
	realNamespacedFeatures = map[int]database.NamespacedFeature{
		1: {realFeatures[1], realNamespaces[1]},
		2: {realFeatures[2], realNamespaces[1]},
		3: {realFeatures[2], realNamespaces[2]},
		4: {realFeatures[3], realNamespaces[1]},
	}

	realDetectors = map[int]database.Detector{
		1: database.NewNamespaceDetector("os-release", "1.0"),
		2: database.NewFeatureDetector("dpkg", "1.0"),
		3: database.NewFeatureDetector("rpm", "1.0"),
		4: database.NewNamespaceDetector("apt-sources", "1.0"),
	}

	realLayers = map[int]database.Layer{
		2: {
			Hash: "layer-1",
			By:   []database.Detector{realDetectors[1], realDetectors[2]},
			Features: []database.LayerFeature{
				{realFeatures[1], realDetectors[2]},
				{realFeatures[2], realDetectors[2]},
			},
			Namespaces: []database.LayerNamespace{
				{realNamespaces[1], realDetectors[1]},
			},
		},
		6: {
			Hash: "layer-4",
			By:   []database.Detector{realDetectors[1], realDetectors[2], realDetectors[3], realDetectors[4]},
			Features: []database.LayerFeature{
				{realFeatures[4], realDetectors[3]},
				{realFeatures[3], realDetectors[2]},
			},
			Namespaces: []database.LayerNamespace{
				{realNamespaces[1], realDetectors[1]},
				{realNamespaces[3], realDetectors[4]},
			},
		},
	}

	realAncestries = map[int]database.Ancestry{
		2: {
			Name: "ancestry-2",
			By:   []database.Detector{realDetectors[2], realDetectors[1]},
			Layers: []database.AncestryLayer{
				{
					"layer-0",
					[]database.AncestryFeature{},
				},
				{
					"layer-1",
					[]database.AncestryFeature{},
				},
				{
					"layer-2",
					[]database.AncestryFeature{
						{
							realNamespacedFeatures[1],
							realDetectors[2],
							realDetectors[1],
						},
					},
				},
				{
					"layer-3b",
					[]database.AncestryFeature{
						{
							realNamespacedFeatures[3],
							realDetectors[2],
							realDetectors[1],
						},
					},
				},
			},
		},
	}

	realVulnerability = map[int]database.Vulnerability{
		1: {
			Name:        "CVE-OPENSSL-1-DEB7",
			Namespace:   realNamespaces[1],
			Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
			Link:        "http://google.com/#q=CVE-OPENSSL-1-DEB7",
			Severity:    database.HighSeverity,
		},
		2: {
			Name:        "CVE-NOPE",
			Namespace:   realNamespaces[1],
			Description: "A vulnerability affecting nothing",
			Severity:    database.UnknownSeverity,
		},
	}

	realNotification = map[int]database.VulnerabilityNotification{
		1: {
			NotificationHook: database.NotificationHook{
				Name: "test",
			},
			// Old/New point at copies of the vulnerabilities above.
			Old: takeVulnerabilityPointerFromMap(realVulnerability, 2),
			New: takeVulnerabilityPointerFromMap(realVulnerability, 1),
		},
	}

	// "fake" fixtures: entities not present in the seeded database.
	fakeFeatures = map[int]database.Feature{
		1: {
			Name:          "ourchat",
			Version:       "0.6",
			VersionFormat: "dpkg",
		},
	}

	fakeNamespaces = map[int]database.Namespace{
		1: {"green hat", "rpm"},
	}
	fakeNamespacedFeatures = map[int]database.NamespacedFeature{
		1: {
			// NOTE(review): key 0 exists in neither fakeFeatures nor
			// realNamespaces (both start at 1), so these lookups yield zero
			// values — confirm this is intentional.
			Feature:   fakeFeatures[0],
			Namespace: realNamespaces[0],
		},
	}

	fakeDetector = map[int]database.Detector{
		1: {
			Name:    "fake",
			Version: "1.0",
			DType:   database.FeatureDetectorType,
		},
		2: {
			Name:    "fake2",
			Version: "2.0",
			DType:   database.NamespaceDetectorType,
		},
	}
)
|
||||
|
||||
func takeVulnerabilityPointerFromMap(m map[int]database.Vulnerability, id int) *database.Vulnerability {
|
||||
x := m[id]
|
||||
return &x
|
||||
}
|
||||
|
||||
func takeAncestryPointerFromMap(m map[int]database.Ancestry, id int) *database.Ancestry {
|
||||
x := m[id]
|
||||
return &x
|
||||
}
|
||||
|
||||
func takeLayerPointerFromMap(m map[int]database.Layer, id int) *database.Layer {
|
||||
x := m[id]
|
||||
return &x
|
||||
}
|
||||
|
||||
func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace {
|
||||
rows, err := tx.Query("SELECT name, version_format FROM namespace")
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
namespaces := []database.Namespace{}
|
||||
for rows.Next() {
|
||||
var ns database.Namespace
|
||||
err := rows.Scan(&ns.Name, &ns.VersionFormat)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
namespaces = append(namespaces, ns)
|
||||
}
|
||||
|
||||
return namespaces
|
||||
}
|
||||
|
||||
// assertVulnerabilityNotificationWithVulnerableEqual asserts that two
// notifications are equal, comparing their paged ancestries with the given
// pagination key.
func assertVulnerabilityNotificationWithVulnerableEqual(t *testing.T, key pagination.Key, expected, actual *database.VulnerabilityNotificationWithVulnerable) bool {
	// Identical pointers (including both nil) are trivially equal.
	if expected == actual {
		return true
	}

	// Exactly one side is nil here; let testify report the mismatch.
	if expected == nil || actual == nil {
		return assert.Equal(t, expected, actual)
	}

	return assert.Equal(t, expected.NotificationHook, actual.NotificationHook) &&
		AssertPagedVulnerableAncestriesEqual(t, key, expected.Old, actual.Old) &&
		AssertPagedVulnerableAncestriesEqual(t, key, expected.New, actual.New)
}
|
||||
|
||||
// AssertPagedVulnerableAncestriesEqual asserts that two paged vulnerable
// ancestry values are equal. Pagination tokens are unmarshalled with the
// given key before comparison so that equivalent pages compare equal even if
// their encoded tokens differ.
func AssertPagedVulnerableAncestriesEqual(t *testing.T, key pagination.Key, expected, actual *database.PagedVulnerableAncestries) bool {
	// Identical pointers (including both nil) are trivially equal.
	if expected == actual {
		return true
	}

	// Exactly one side is nil here; let testify report the mismatch.
	if expected == nil || actual == nil {
		return assert.Equal(t, expected, actual)
	}

	return testutil.AssertVulnerabilityEqual(t, &expected.Vulnerability, &actual.Vulnerability) &&
		assert.Equal(t, expected.Limit, actual.Limit) &&
		assert.Equal(t, mustUnmarshalToken(key, expected.Current), mustUnmarshalToken(key, actual.Current)) &&
		assert.Equal(t, mustUnmarshalToken(key, expected.Next), mustUnmarshalToken(key, actual.Next)) &&
		assert.Equal(t, expected.End, actual.End) &&
		testutil.AssertIntStringMapEqual(t, expected.Affected, actual.Affected)
}
|
||||
|
||||
func mustUnmarshalToken(key pagination.Key, token pagination.Token) Page {
|
||||
if token == pagination.FirstPageToken {
|
||||
return Page{}
|
||||
}
|
||||
|
||||
p := Page{}
|
||||
if err := key.UnmarshalToken(token, &p); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func mustMarshalToken(key pagination.Key, v interface{}) pagination.Token {
|
||||
token, err := key.MarshalToken(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return token
|
||||
}
|
@ -0,0 +1,285 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/deckarep/golang-set"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
)
|
||||
|
||||
// AssertDetectorsEqual asserts actual detectors are content wise equal to
|
||||
// expected detectors regardless of the ordering.
|
||||
func AssertDetectorsEqual(t *testing.T, expected, actual []database.Detector) bool {
|
||||
if len(expected) != len(actual) {
|
||||
return assert.Fail(t, "detectors are not equal", "expected: '%v', actual: '%v'", expected, actual)
|
||||
}
|
||||
|
||||
sort.Slice(expected, func(i, j int) bool {
|
||||
return expected[i].String() < expected[j].String()
|
||||
})
|
||||
|
||||
sort.Slice(actual, func(i, j int) bool {
|
||||
return actual[i].String() < actual[j].String()
|
||||
})
|
||||
|
||||
for i := range expected {
|
||||
if expected[i] != actual[i] {
|
||||
return assert.Fail(t, "detectors are not equal", "expected: '%v', actual: '%v'", expected, actual)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AssertAncestryEqual asserts actual ancestry equals to expected ancestry
// content wise.
func AssertAncestryEqual(t *testing.T, expected, actual *database.Ancestry) bool {
	// Identical pointers (including both nil) are trivially equal.
	if expected == actual {
		return true
	}

	// Exactly one side is nil here; let testify report the mismatch.
	if actual == nil || expected == nil {
		return assert.Equal(t, expected, actual)
	}

	if !assert.Equal(t, expected.Name, actual.Name) || !AssertDetectorsEqual(t, expected.By, actual.By) {
		return false
	}

	// Layers are compared positionally: ancestry layer order is significant.
	if assert.Equal(t, len(expected.Layers), len(actual.Layers)) {
		for index := range expected.Layers {
			if !AssertAncestryLayerEqual(t, &expected.Layers[index], &actual.Layers[index]) {
				return false
			}
		}
		return true
	}
	return false
}
|
||||
|
||||
// AssertAncestryLayerEqual asserts actual ancestry layer equals to expected
// ancestry layer content wise. Feature order within a layer is ignored.
func AssertAncestryLayerEqual(t *testing.T, expected, actual *database.AncestryLayer) bool {
	if !assert.Equal(t, expected.Hash, actual.Hash) {
		return false
	}

	if !assert.Equal(t, len(expected.Features), len(actual.Features),
		"layer: %s\nExpected: %v\n Actual: %v",
		expected.Hash, expected.Features, actual.Features,
	) {
		return false
	}

	// feature -> is in actual layer
	hitCounter := map[database.AncestryFeature]bool{}
	for _, f := range expected.Features {
		hitCounter[f] = false
	}

	// if there's no extra features and no duplicated features, since expected
	// and actual have the same length, their result must equal.
	for _, f := range actual.Features {
		v, ok := hitCounter[f]
		assert.True(t, ok, "unexpected feature %s", f)
		assert.False(t, v, "duplicated feature %s", f)
		hitCounter[f] = true
	}

	for f, visited := range hitCounter {
		assert.True(t, visited, "missing feature %s", f)
	}

	// NOTE(review): the three asserts above mark the test failed but their
	// results are not propagated, so this function returns true even when
	// features mismatch — confirm whether that is intentional.
	return true
}
|
||||
|
||||
// AssertElementsEqual asserts that content in actual equals to content in
// expected array regardless of ordering.
//
// Note: This function uses interface wise comparison, so elements must be
// comparable values for map-key use.
func AssertElementsEqual(t *testing.T, expected, actual []interface{}) bool {
	// element -> seen in actual
	counter := map[interface{}]bool{}
	for _, f := range expected {
		counter[f] = false
	}

	for _, f := range actual {
		v, ok := counter[f]
		if !assert.True(t, ok, "unexpected element %v\nExpected: %v\n Actual: %v\n", f, expected, actual) {
			return false
		}

		if !assert.False(t, v, "duplicated element %v\nExpected: %v\n Actual: %v\n", f, expected, actual) {
			return false
		}

		counter[f] = true
	}

	// Any entry still false was expected but never produced.
	for f, visited := range counter {
		if !assert.True(t, visited, "missing feature %v\nExpected: %v\n Actual: %v\n", f, expected, actual) {
			return false
		}
	}

	return true
}
|
||||
|
||||
// AssertFeaturesEqual asserts content in actual equals content in expected
|
||||
// regardless of ordering.
|
||||
func AssertFeaturesEqual(t *testing.T, expected, actual []database.Feature) bool {
|
||||
if assert.Len(t, actual, len(expected)) {
|
||||
has := map[database.Feature]bool{}
|
||||
for _, nf := range expected {
|
||||
has[nf] = false
|
||||
}
|
||||
|
||||
for _, nf := range actual {
|
||||
has[nf] = true
|
||||
}
|
||||
|
||||
for nf, visited := range has {
|
||||
if !assert.True(t, visited, nf.Name+" is expected") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// AssertLayerFeaturesEqual asserts content in actual equals to content in
|
||||
// expected regardless of ordering.
|
||||
func AssertLayerFeaturesEqual(t *testing.T, expected, actual []database.LayerFeature) bool {
|
||||
if !assert.Len(t, actual, len(expected)) {
|
||||
return false
|
||||
}
|
||||
|
||||
expectedInterfaces := []interface{}{}
|
||||
for _, e := range expected {
|
||||
expectedInterfaces = append(expectedInterfaces, e)
|
||||
}
|
||||
|
||||
actualInterfaces := []interface{}{}
|
||||
for _, a := range actual {
|
||||
actualInterfaces = append(actualInterfaces, a)
|
||||
}
|
||||
|
||||
return AssertElementsEqual(t, expectedInterfaces, actualInterfaces)
|
||||
}
|
||||
|
||||
// AssertNamespacesEqual asserts content in actual equals to content in
|
||||
// expected regardless of ordering.
|
||||
func AssertNamespacesEqual(t *testing.T, expected, actual []database.Namespace) bool {
|
||||
expectedInterfaces := []interface{}{}
|
||||
for _, e := range expected {
|
||||
expectedInterfaces = append(expectedInterfaces, e)
|
||||
}
|
||||
|
||||
actualInterfaces := []interface{}{}
|
||||
for _, a := range actual {
|
||||
actualInterfaces = append(actualInterfaces, a)
|
||||
}
|
||||
|
||||
return AssertElementsEqual(t, expectedInterfaces, actualInterfaces)
|
||||
}
|
||||
|
||||
// AssertLayerNamespacesEqual asserts content in actual equals to content in
|
||||
// expected regardless of ordering.
|
||||
func AssertLayerNamespacesEqual(t *testing.T, expected, actual []database.LayerNamespace) bool {
|
||||
expectedInterfaces := []interface{}{}
|
||||
for _, e := range expected {
|
||||
expectedInterfaces = append(expectedInterfaces, e)
|
||||
}
|
||||
|
||||
actualInterfaces := []interface{}{}
|
||||
for _, a := range actual {
|
||||
actualInterfaces = append(actualInterfaces, a)
|
||||
}
|
||||
|
||||
return AssertElementsEqual(t, expectedInterfaces, actualInterfaces)
|
||||
}
|
||||
|
||||
// AssertLayerEqual asserts actual layer equals to expected layer content wise.
func AssertLayerEqual(t *testing.T, expected, actual *database.Layer) bool {
	// Identical pointers (including both nil) are trivially equal.
	if expected == actual {
		return true
	}

	// Exactly one side is nil here; let testify report the mismatch.
	if expected == nil || actual == nil {
		return assert.Equal(t, expected, actual)
	}

	return assert.Equal(t, expected.Hash, actual.Hash) &&
		AssertDetectorsEqual(t, expected.By, actual.By) &&
		AssertLayerFeaturesEqual(t, expected.Features, actual.Features) &&
		AssertLayerNamespacesEqual(t, expected.Namespaces, actual.Namespaces)
}
|
||||
|
||||
// AssertIntStringMapEqual asserts two maps with integer as key and string as
|
||||
// value are equal.
|
||||
func AssertIntStringMapEqual(t *testing.T, expected, actual map[int]string) bool {
|
||||
checked := mapset.NewSet()
|
||||
for k, v := range expected {
|
||||
assert.Equal(t, v, actual[k])
|
||||
checked.Add(k)
|
||||
}
|
||||
|
||||
for k := range actual {
|
||||
if !assert.True(t, checked.Contains(k)) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// AssertVulnerabilityEqual asserts two vulnerabilities are equal.
|
||||
func AssertVulnerabilityEqual(t *testing.T, expected, actual *database.Vulnerability) bool {
|
||||
return assert.Equal(t, expected.Name, actual.Name) &&
|
||||
assert.Equal(t, expected.Link, actual.Link) &&
|
||||
assert.Equal(t, expected.Description, actual.Description) &&
|
||||
assert.Equal(t, expected.Namespace, actual.Namespace) &&
|
||||
assert.Equal(t, expected.Severity, actual.Severity) &&
|
||||
AssertMetadataMapEqual(t, expected.Metadata, actual.Metadata)
|
||||
}
|
||||
|
||||
func castMetadataMapToInterface(metadata database.MetadataMap) map[string]interface{} {
|
||||
content, err := json.Marshal(metadata)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
data := make(map[string]interface{})
|
||||
if err := json.Unmarshal(content, &data); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return data
|
||||
}
|
||||
|
||||
// AssertMetadataMapEqual asserts two metadata maps are equal. Both maps are
// round-tripped through JSON first so that values of different in-memory
// types with the same JSON encoding compare equal.
func AssertMetadataMapEqual(t *testing.T, expected, actual database.MetadataMap) bool {
	expectedMap := castMetadataMapToInterface(expected)
	actualMap := castMetadataMapToInterface(actual)
	checked := mapset.NewSet()
	for k, v := range expectedMap {
		if !assert.Equal(t, v, (actualMap)[k]) {
			return false
		}

		checked.Add(k)
	}

	// Keys of the original map and its JSON round-trip are the same strings,
	// so ranging over actual here covers actualMap's keys as well.
	for k := range actual {
		if !assert.True(t, checked.Contains(k)) {
			return false
		}
	}

	return true
}
|
Loading…
Reference in new issue