database: postgres implementation with tests.
This commit is contained in:
parent
fb32dcfa58
commit
a5c6400065
261
database/pgsql/ancestry.go
Normal file
261
database/pgsql/ancestry.go
Normal file
@ -0,0 +1,261 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
func (tx *pgSession) UpsertAncestry(ancestry database.Ancestry, features []database.NamespacedFeature, processedBy database.Processors) error {
|
||||
if ancestry.Name == "" {
|
||||
log.Warning("Empty ancestry name is not allowed")
|
||||
return commonerr.NewBadRequestError("could not insert an ancestry with empty name")
|
||||
}
|
||||
|
||||
if len(ancestry.Layers) == 0 {
|
||||
log.Warning("Empty ancestry is not allowed")
|
||||
return commonerr.NewBadRequestError("could not insert an ancestry with 0 layers")
|
||||
}
|
||||
|
||||
err := tx.deleteAncestry(ancestry.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ancestryID int64
|
||||
err = tx.QueryRow(insertAncestry, ancestry.Name).Scan(&ancestryID)
|
||||
if err != nil {
|
||||
if isErrUniqueViolation(err) {
|
||||
return handleError("insertAncestry", errors.New("Other Go-routine is processing this ancestry (skip)."))
|
||||
}
|
||||
return handleError("insertAncestry", err)
|
||||
}
|
||||
|
||||
err = tx.insertAncestryLayers(ancestryID, ancestry.Layers)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tx.insertAncestryFeatures(ancestryID, features)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.persistProcessors(persistAncestryLister,
|
||||
"persistAncestryLister",
|
||||
persistAncestryDetector,
|
||||
"persistAncestryDetector",
|
||||
ancestryID, processedBy)
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindAncestry(name string) (database.Ancestry, database.Processors, bool, error) {
|
||||
ancestry := database.Ancestry{Name: name}
|
||||
processed := database.Processors{}
|
||||
|
||||
var ancestryID int64
|
||||
err := tx.QueryRow(searchAncestry, name).Scan(&ancestryID)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return ancestry, processed, false, nil
|
||||
}
|
||||
return ancestry, processed, false, handleError("searchAncestry", err)
|
||||
}
|
||||
|
||||
ancestry.Layers, err = tx.findAncestryLayers(ancestryID)
|
||||
if err != nil {
|
||||
return ancestry, processed, false, err
|
||||
}
|
||||
|
||||
processed.Detectors, err = tx.findProcessors(searchAncestryDetectors, "searchAncestryDetectors", "detector", ancestryID)
|
||||
if err != nil {
|
||||
return ancestry, processed, false, err
|
||||
}
|
||||
|
||||
processed.Listers, err = tx.findProcessors(searchAncestryListers, "searchAncestryListers", "lister", ancestryID)
|
||||
if err != nil {
|
||||
return ancestry, processed, false, err
|
||||
}
|
||||
|
||||
return ancestry, processed, true, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) FindAncestryFeatures(name string) (database.AncestryWithFeatures, bool, error) {
|
||||
var (
|
||||
awf database.AncestryWithFeatures
|
||||
ok bool
|
||||
err error
|
||||
)
|
||||
awf.Ancestry, awf.ProcessedBy, ok, err = tx.FindAncestry(name)
|
||||
if err != nil {
|
||||
return awf, false, err
|
||||
}
|
||||
|
||||
if !ok {
|
||||
return awf, false, nil
|
||||
}
|
||||
|
||||
rows, err := tx.Query(searchAncestryFeatures, name)
|
||||
if err != nil {
|
||||
return awf, false, handleError("searchAncestryFeatures", err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
nf := database.NamespacedFeature{}
|
||||
err := rows.Scan(&nf.Namespace.Name, &nf.Namespace.VersionFormat, &nf.Feature.Name, &nf.Feature.Version)
|
||||
if err != nil {
|
||||
return awf, false, handleError("searchAncestryFeatures", err)
|
||||
}
|
||||
nf.Feature.VersionFormat = nf.Namespace.VersionFormat
|
||||
awf.Features = append(awf.Features, nf)
|
||||
}
|
||||
|
||||
return awf, true, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) deleteAncestry(name string) error {
|
||||
result, err := tx.Exec(removeAncestry, name)
|
||||
if err != nil {
|
||||
return handleError("removeAncestry", err)
|
||||
}
|
||||
|
||||
_, err = result.RowsAffected()
|
||||
if err != nil {
|
||||
return handleError("removeAncestry", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findProcessors(query, queryName, processorType string, id int64) ([]string, error) {
|
||||
rows, err := tx.Query(query, id)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
log.Warning("No " + processorType + " are used")
|
||||
return nil, nil
|
||||
}
|
||||
return nil, handleError(queryName, err)
|
||||
}
|
||||
|
||||
var (
|
||||
processors []string
|
||||
processor string
|
||||
)
|
||||
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&processor)
|
||||
if err != nil {
|
||||
return nil, handleError(queryName, err)
|
||||
}
|
||||
processors = append(processors, processor)
|
||||
}
|
||||
|
||||
return processors, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findAncestryLayers(ancestryID int64) ([]database.Layer, error) {
|
||||
rows, err := tx.Query(searchAncestryLayer, ancestryID)
|
||||
if err != nil {
|
||||
return nil, handleError("searchAncestryLayer", err)
|
||||
}
|
||||
layers := []database.Layer{}
|
||||
for rows.Next() {
|
||||
var layer database.Layer
|
||||
err := rows.Scan(&layer.Hash)
|
||||
if err != nil {
|
||||
return nil, handleError("searchAncestryLayer", err)
|
||||
}
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) insertAncestryLayers(ancestryID int64, layers []database.Layer) error {
|
||||
layerIDs := map[string]sql.NullInt64{}
|
||||
for _, l := range layers {
|
||||
layerIDs[l.Hash] = sql.NullInt64{}
|
||||
}
|
||||
|
||||
layerHashes := []string{}
|
||||
for hash := range layerIDs {
|
||||
layerHashes = append(layerHashes, hash)
|
||||
}
|
||||
|
||||
rows, err := tx.Query(searchLayerIDs, pq.Array(layerHashes))
|
||||
if err != nil {
|
||||
return handleError("searchLayerIDs", err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
layerID sql.NullInt64
|
||||
layerName string
|
||||
)
|
||||
err := rows.Scan(&layerID, &layerName)
|
||||
if err != nil {
|
||||
return handleError("searchLayerIDs", err)
|
||||
}
|
||||
layerIDs[layerName] = layerID
|
||||
}
|
||||
|
||||
notFound := []string{}
|
||||
for hash, id := range layerIDs {
|
||||
if !id.Valid {
|
||||
notFound = append(notFound, hash)
|
||||
}
|
||||
}
|
||||
|
||||
if len(notFound) > 0 {
|
||||
return handleError("searchLayerIDs", fmt.Errorf("Layer %s is not found in database", strings.Join(notFound, ",")))
|
||||
}
|
||||
|
||||
//TODO(Sida): use bulk insert.
|
||||
stmt, err := tx.Prepare(insertAncestryLayer)
|
||||
if err != nil {
|
||||
return handleError("insertAncestryLayer", err)
|
||||
}
|
||||
|
||||
defer stmt.Close()
|
||||
for index, layer := range layers {
|
||||
_, err := stmt.Exec(ancestryID, index, layerIDs[layer.Hash].Int64)
|
||||
if err != nil {
|
||||
return handleError("insertAncestryLayer", commonerr.CombineErrors(err, stmt.Close()))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) insertAncestryFeatures(ancestryID int64, features []database.NamespacedFeature) error {
|
||||
featureIDs, err := tx.findNamespacedFeatureIDs(features)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//TODO(Sida): use bulk insert.
|
||||
stmtFeatures, err := tx.Prepare(insertAncestryFeature)
|
||||
if err != nil {
|
||||
return handleError("insertAncestryFeature", err)
|
||||
}
|
||||
|
||||
defer stmtFeatures.Close()
|
||||
|
||||
for _, id := range featureIDs {
|
||||
if !id.Valid {
|
||||
return errors.New("requested namespaced feature is not in database")
|
||||
}
|
||||
|
||||
_, err := stmtFeatures.Exec(ancestryID, id)
|
||||
if err != nil {
|
||||
return handleError("insertAncestryFeature", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
207
database/pgsql/ancestry_test.go
Normal file
207
database/pgsql/ancestry_test.go
Normal file
@ -0,0 +1,207 @@
|
||||
// Copyright 2017 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
)
|
||||
|
||||
func TestUpsertAncestry(t *testing.T) {
|
||||
store, tx := openSessionForTest(t, "UpsertAncestry", true)
|
||||
defer closeTest(t, store, tx)
|
||||
a1 := database.Ancestry{
|
||||
Name: "a1",
|
||||
Layers: []database.Layer{
|
||||
{Hash: "layer-N"},
|
||||
},
|
||||
}
|
||||
|
||||
a2 := database.Ancestry{}
|
||||
|
||||
a3 := database.Ancestry{
|
||||
Name: "a",
|
||||
Layers: []database.Layer{
|
||||
{Hash: "layer-0"},
|
||||
},
|
||||
}
|
||||
|
||||
a4 := database.Ancestry{
|
||||
Name: "a",
|
||||
Layers: []database.Layer{
|
||||
{Hash: "layer-1"},
|
||||
},
|
||||
}
|
||||
|
||||
f1 := database.Feature{
|
||||
Name: "wechat",
|
||||
Version: "0.5",
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
|
||||
// not in database
|
||||
f2 := database.Feature{
|
||||
Name: "wechat",
|
||||
Version: "0.6",
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
|
||||
n1 := database.Namespace{
|
||||
Name: "debian:7",
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
|
||||
p := database.Processors{
|
||||
Listers: []string{"dpkg", "non-existing"},
|
||||
Detectors: []string{"os-release", "non-existing"},
|
||||
}
|
||||
|
||||
nsf1 := database.NamespacedFeature{
|
||||
Namespace: n1,
|
||||
Feature: f1,
|
||||
}
|
||||
|
||||
// not in database
|
||||
nsf2 := database.NamespacedFeature{
|
||||
Namespace: n1,
|
||||
Feature: f2,
|
||||
}
|
||||
|
||||
// invalid case
|
||||
assert.NotNil(t, tx.UpsertAncestry(a1, nil, database.Processors{}))
|
||||
assert.NotNil(t, tx.UpsertAncestry(a2, nil, database.Processors{}))
|
||||
// valid case
|
||||
assert.Nil(t, tx.UpsertAncestry(a3, nil, database.Processors{}))
|
||||
// replace invalid case
|
||||
assert.NotNil(t, tx.UpsertAncestry(a4, []database.NamespacedFeature{nsf1, nsf2}, p))
|
||||
// replace valid case
|
||||
assert.Nil(t, tx.UpsertAncestry(a4, []database.NamespacedFeature{nsf1}, p))
|
||||
// validate
|
||||
ancestry, ok, err := tx.FindAncestryFeatures("a")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, a4, ancestry.Ancestry)
|
||||
}
|
||||
|
||||
func assertProcessorsEqual(t *testing.T, expected database.Processors, actual database.Processors) bool {
|
||||
sort.Strings(expected.Detectors)
|
||||
sort.Strings(actual.Detectors)
|
||||
sort.Strings(expected.Listers)
|
||||
sort.Strings(actual.Listers)
|
||||
return assert.Equal(t, expected.Detectors, actual.Detectors) && assert.Equal(t, expected.Listers, actual.Listers)
|
||||
}
|
||||
|
||||
func TestFindAncestry(t *testing.T) {
|
||||
store, tx := openSessionForTest(t, "FindAncestry", true)
|
||||
defer closeTest(t, store, tx)
|
||||
|
||||
// not found
|
||||
_, _, ok, err := tx.FindAncestry("ancestry-non")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
|
||||
expected := database.Ancestry{
|
||||
Name: "ancestry-1",
|
||||
Layers: []database.Layer{
|
||||
{Hash: "layer-0"},
|
||||
{Hash: "layer-1"},
|
||||
{Hash: "layer-2"},
|
||||
{Hash: "layer-3a"},
|
||||
},
|
||||
}
|
||||
|
||||
expectedProcessors := database.Processors{
|
||||
Detectors: []string{"os-release"},
|
||||
Listers: []string{"dpkg"},
|
||||
}
|
||||
|
||||
// found
|
||||
a, p, ok2, err := tx.FindAncestry("ancestry-1")
|
||||
if assert.Nil(t, err) && assert.True(t, ok2) {
|
||||
assertAncestryEqual(t, expected, a)
|
||||
assertProcessorsEqual(t, expectedProcessors, p)
|
||||
}
|
||||
}
|
||||
|
||||
func assertAncestryWithFeatureEqual(t *testing.T, expected database.AncestryWithFeatures, actual database.AncestryWithFeatures) bool {
|
||||
return assertAncestryEqual(t, expected.Ancestry, actual.Ancestry) &&
|
||||
assertNamespacedFeatureEqual(t, expected.Features, actual.Features) &&
|
||||
assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy)
|
||||
}
|
||||
func assertAncestryEqual(t *testing.T, expected database.Ancestry, actual database.Ancestry) bool {
|
||||
return assert.Equal(t, expected.Name, actual.Name) && assert.Equal(t, expected.Layers, actual.Layers)
|
||||
}
|
||||
|
||||
func TestFindAncestryFeatures(t *testing.T) {
|
||||
store, tx := openSessionForTest(t, "FindAncestryFeatures", true)
|
||||
defer closeTest(t, store, tx)
|
||||
|
||||
// invalid
|
||||
_, ok, err := tx.FindAncestryFeatures("ancestry-non")
|
||||
if assert.Nil(t, err) {
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
expected := database.AncestryWithFeatures{
|
||||
Ancestry: database.Ancestry{
|
||||
Name: "ancestry-2",
|
||||
Layers: []database.Layer{
|
||||
{Hash: "layer-0"},
|
||||
{Hash: "layer-1"},
|
||||
{Hash: "layer-2"},
|
||||
{Hash: "layer-3b"},
|
||||
},
|
||||
},
|
||||
ProcessedBy: database.Processors{
|
||||
Detectors: []string{"os-release"},
|
||||
Listers: []string{"dpkg"},
|
||||
},
|
||||
Features: []database.NamespacedFeature{
|
||||
{
|
||||
Namespace: database.Namespace{
|
||||
Name: "debian:7",
|
||||
VersionFormat: "dpkg",
|
||||
},
|
||||
Feature: database.Feature{
|
||||
Name: "wechat",
|
||||
Version: "0.5",
|
||||
VersionFormat: "dpkg",
|
||||
},
|
||||
},
|
||||
{
|
||||
Namespace: database.Namespace{
|
||||
Name: "debian:8",
|
||||
VersionFormat: "dpkg",
|
||||
},
|
||||
Feature: database.Feature{
|
||||
Name: "openssl",
|
||||
Version: "1.0",
|
||||
VersionFormat: "dpkg",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
// valid
|
||||
ancestry, ok, err := tx.FindAncestryFeatures("ancestry-2")
|
||||
if assert.Nil(t, err) && assert.True(t, ok) {
|
||||
assertAncestryEqual(t, expected.Ancestry, ancestry.Ancestry)
|
||||
assertNamespacedFeatureEqual(t, expected.Features, ancestry.Features)
|
||||
assertProcessorsEqual(t, expected.ProcessedBy, ancestry.ProcessedBy)
|
||||
}
|
||||
}
|
@ -27,135 +27,200 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/ext/versionfmt"
|
||||
"github.com/coreos/clair/ext/versionfmt/dpkg"
|
||||
"github.com/coreos/clair/pkg/strutil"
|
||||
)
|
||||
|
||||
const (
|
||||
numVulnerabilities = 100
|
||||
numFeatureVersions = 100
|
||||
numFeatures = 100
|
||||
)
|
||||
|
||||
func TestRaceAffects(t *testing.T) {
|
||||
datastore, err := openDatabaseForTest("RaceAffects", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
func testGenRandomVulnerabilityAndNamespacedFeature(t *testing.T, store database.Datastore) ([]database.NamespacedFeature, []database.VulnerabilityWithAffected) {
|
||||
tx, err := store.Begin()
|
||||
if !assert.Nil(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
defer datastore.Close()
|
||||
|
||||
// Insert the Feature on which we'll work.
|
||||
feature := database.Feature{
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestRaceAffectsFeatureNamespace1",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
Name: "TestRaceAffecturesFeature1",
|
||||
featureName := "TestFeature"
|
||||
featureVersionFormat := dpkg.ParserName
|
||||
// Insert the namespace on which we'll work.
|
||||
namespace := database.Namespace{
|
||||
Name: "TestRaceAffectsFeatureNamespace1",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
}
|
||||
_, err = datastore.insertFeature(feature)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
|
||||
if !assert.Nil(t, tx.PersistNamespaces([]database.Namespace{namespace})) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// Initialize random generator and enforce max procs.
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||
|
||||
// Generate FeatureVersions.
|
||||
featureVersions := make([]database.FeatureVersion, numFeatureVersions)
|
||||
for i := 0; i < numFeatureVersions; i++ {
|
||||
version := rand.Intn(numFeatureVersions)
|
||||
// Generate Distinct random features
|
||||
features := make([]database.Feature, numFeatures)
|
||||
nsFeatures := make([]database.NamespacedFeature, numFeatures)
|
||||
for i := 0; i < numFeatures; i++ {
|
||||
version := rand.Intn(numFeatures)
|
||||
|
||||
featureVersions[i] = database.FeatureVersion{
|
||||
Feature: feature,
|
||||
Version: strconv.Itoa(version),
|
||||
features[i] = database.Feature{
|
||||
Name: featureName,
|
||||
VersionFormat: featureVersionFormat,
|
||||
Version: strconv.Itoa(version),
|
||||
}
|
||||
|
||||
nsFeatures[i] = database.NamespacedFeature{
|
||||
Namespace: namespace,
|
||||
Feature: features[i],
|
||||
}
|
||||
}
|
||||
|
||||
// insert features
|
||||
if !assert.Nil(t, tx.PersistFeatures(features)) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// Generate vulnerabilities.
|
||||
// They are mapped by fixed version, which will make verification really easy afterwards.
|
||||
vulnerabilities := make(map[int][]database.Vulnerability)
|
||||
vulnerabilities := []database.VulnerabilityWithAffected{}
|
||||
for i := 0; i < numVulnerabilities; i++ {
|
||||
version := rand.Intn(numFeatureVersions) + 1
|
||||
// any version less than this is vulnerable
|
||||
version := rand.Intn(numFeatures) + 1
|
||||
|
||||
// if _, ok := vulnerabilities[version]; !ok {
|
||||
// vulnerabilities[version] = make([]database.Vulnerability)
|
||||
// }
|
||||
|
||||
vulnerability := database.Vulnerability{
|
||||
Name: uuid.New(),
|
||||
Namespace: feature.Namespace,
|
||||
FixedIn: []database.FeatureVersion{
|
||||
vulnerability := database.VulnerabilityWithAffected{
|
||||
Vulnerability: database.Vulnerability{
|
||||
Name: uuid.New(),
|
||||
Namespace: namespace,
|
||||
Severity: database.UnknownSeverity,
|
||||
},
|
||||
Affected: []database.AffectedFeature{
|
||||
{
|
||||
Feature: feature,
|
||||
Version: strconv.Itoa(version),
|
||||
Namespace: namespace,
|
||||
FeatureName: featureName,
|
||||
AffectedVersion: strconv.Itoa(version),
|
||||
FixedInVersion: strconv.Itoa(version),
|
||||
},
|
||||
},
|
||||
Severity: database.UnknownSeverity,
|
||||
}
|
||||
|
||||
vulnerabilities[version] = append(vulnerabilities[version], vulnerability)
|
||||
vulnerabilities = append(vulnerabilities, vulnerability)
|
||||
}
|
||||
tx.Commit()
|
||||
|
||||
return nsFeatures, vulnerabilities
|
||||
}
|
||||
|
||||
func TestConcurrency(t *testing.T) {
|
||||
store, err := openDatabaseForTest("Concurrency", false)
|
||||
if !assert.Nil(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
start := time.Now()
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(100)
|
||||
for i := 0; i < 100; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
nsNamespaces := genRandomNamespaces(t, 100)
|
||||
tx, err := store.Begin()
|
||||
if !assert.Nil(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
assert.Nil(t, tx.PersistNamespaces(nsNamespaces))
|
||||
tx.Commit()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
fmt.Println("total", time.Since(start))
|
||||
}
|
||||
|
||||
func genRandomNamespaces(t *testing.T, count int) []database.Namespace {
|
||||
r := make([]database.Namespace, count)
|
||||
for i := 0; i < count; i++ {
|
||||
r[i] = database.Namespace{
|
||||
Name: uuid.New(),
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func TestCaching(t *testing.T) {
|
||||
store, err := openDatabaseForTest("Caching", false)
|
||||
if !assert.Nil(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
defer store.Close()
|
||||
|
||||
nsFeatures, vulnerabilities := testGenRandomVulnerabilityAndNamespacedFeature(t, store)
|
||||
|
||||
fmt.Printf("%d features, %d vulnerabilities are generated", len(nsFeatures), len(vulnerabilities))
|
||||
|
||||
// Insert featureversions and vulnerabilities in parallel.
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for _, vulnerabilitiesM := range vulnerabilities {
|
||||
for _, vulnerability := range vulnerabilitiesM {
|
||||
err = datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
tx, err := store.Begin()
|
||||
if !assert.Nil(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
fmt.Println("finished to insert vulnerabilities")
|
||||
|
||||
assert.Nil(t, tx.PersistNamespacedFeatures(nsFeatures))
|
||||
fmt.Println("finished to insert namespaced features")
|
||||
|
||||
tx.Commit()
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for i := 0; i < len(featureVersions); i++ {
|
||||
featureVersions[i].ID, err = datastore.insertFeatureVersion(featureVersions[i])
|
||||
assert.Nil(t, err)
|
||||
tx, err := store.Begin()
|
||||
if !assert.Nil(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
fmt.Println("finished to insert featureVersions")
|
||||
|
||||
assert.Nil(t, tx.InsertVulnerabilities(vulnerabilities))
|
||||
fmt.Println("finished to insert vulnerabilities")
|
||||
tx.Commit()
|
||||
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
tx, err := store.Begin()
|
||||
if !assert.Nil(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
// Verify consistency now.
|
||||
var actualAffectedNames []string
|
||||
var expectedAffectedNames []string
|
||||
affected, err := tx.FindAffectedNamespacedFeatures(nsFeatures)
|
||||
if !assert.Nil(t, err) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
for _, featureVersion := range featureVersions {
|
||||
featureVersionVersion, _ := strconv.Atoi(featureVersion.Version)
|
||||
|
||||
// Get actual affects.
|
||||
rows, err := datastore.Query(searchComplexTestFeatureVersionAffects,
|
||||
featureVersion.ID)
|
||||
assert.Nil(t, err)
|
||||
defer rows.Close()
|
||||
|
||||
var vulnName string
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&vulnName)
|
||||
if !assert.Nil(t, err) {
|
||||
continue
|
||||
}
|
||||
actualAffectedNames = append(actualAffectedNames, vulnName)
|
||||
}
|
||||
if assert.Nil(t, rows.Err()) {
|
||||
rows.Close()
|
||||
for _, ansf := range affected {
|
||||
if !assert.True(t, ansf.Valid) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// Get expected affects.
|
||||
for i := numVulnerabilities; i > featureVersionVersion; i-- {
|
||||
for _, vulnerability := range vulnerabilities[i] {
|
||||
expectedAffectedNames = append(expectedAffectedNames, vulnerability.Name)
|
||||
expectedAffectedNames := []string{}
|
||||
for _, vuln := range vulnerabilities {
|
||||
if ok, err := versionfmt.InRange(dpkg.ParserName, ansf.Version, vuln.Affected[0].AffectedVersion); err == nil {
|
||||
if ok {
|
||||
expectedAffectedNames = append(expectedAffectedNames, vuln.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert.Len(t, compareStringLists(expectedAffectedNames, actualAffectedNames), 0)
|
||||
assert.Len(t, compareStringLists(actualAffectedNames, expectedAffectedNames), 0)
|
||||
actualAffectedNames := []string{}
|
||||
for _, s := range ansf.AffectedBy {
|
||||
actualAffectedNames = append(actualAffectedNames, s.Name)
|
||||
}
|
||||
|
||||
assert.Len(t, strutil.CompareStringLists(expectedAffectedNames, actualAffectedNames), 0)
|
||||
assert.Len(t, strutil.CompareStringLists(actualAffectedNames, expectedAffectedNames), 0)
|
||||
}
|
||||
}
|
||||
|
@ -16,230 +16,366 @@ package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
"github.com/lib/pq"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/ext/versionfmt"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
func (pgSQL *pgSQL) insertFeature(feature database.Feature) (int, error) {
|
||||
if feature.Name == "" {
|
||||
return 0, commonerr.NewBadRequestError("could not find/insert invalid Feature")
|
||||
}
|
||||
var (
|
||||
errFeatureNotFound = errors.New("Feature not found")
|
||||
)
|
||||
|
||||
// Do cache lookup.
|
||||
if pgSQL.cache != nil {
|
||||
promCacheQueriesTotal.WithLabelValues("feature").Inc()
|
||||
id, found := pgSQL.cache.Get("feature:" + feature.Namespace.Name + ":" + feature.Name)
|
||||
if found {
|
||||
promCacheHitsTotal.WithLabelValues("feature").Inc()
|
||||
return id.(int), nil
|
||||
}
|
||||
}
|
||||
|
||||
// We do `defer observeQueryTime` here because we don't want to observe cached features.
|
||||
defer observeQueryTime("insertFeature", "all", time.Now())
|
||||
|
||||
// Find or create Namespace.
|
||||
namespaceID, err := pgSQL.insertNamespace(feature.Namespace)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Find or create Feature.
|
||||
var id int
|
||||
err = pgSQL.QueryRow(soiFeature, feature.Name, namespaceID).Scan(&id)
|
||||
if err != nil {
|
||||
return 0, handleError("soiFeature", err)
|
||||
}
|
||||
|
||||
if pgSQL.cache != nil {
|
||||
pgSQL.cache.Add("feature:"+feature.Namespace.Name+":"+feature.Name, id)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
type vulnerabilityAffecting struct {
|
||||
vulnerabilityID int64
|
||||
addedByID int64
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) insertFeatureVersion(fv database.FeatureVersion) (id int, err error) {
|
||||
err = versionfmt.Valid(fv.Feature.Namespace.VersionFormat, fv.Version)
|
||||
if err != nil {
|
||||
return 0, commonerr.NewBadRequestError("could not find/insert invalid FeatureVersion")
|
||||
func (tx *pgSession) PersistFeatures(features []database.Feature) error {
|
||||
if len(features) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Do cache lookup.
|
||||
cacheIndex := strings.Join([]string{"featureversion", fv.Feature.Namespace.Name, fv.Feature.Name, fv.Version}, ":")
|
||||
if pgSQL.cache != nil {
|
||||
promCacheQueriesTotal.WithLabelValues("featureversion").Inc()
|
||||
id, found := pgSQL.cache.Get(cacheIndex)
|
||||
if found {
|
||||
promCacheHitsTotal.WithLabelValues("featureversion").Inc()
|
||||
return id.(int), nil
|
||||
// Sorting is needed before inserting into database to prevent deadlock.
|
||||
sort.Slice(features, func(i, j int) bool {
|
||||
return features[i].Name < features[j].Name ||
|
||||
features[i].Version < features[j].Version ||
|
||||
features[i].VersionFormat < features[j].VersionFormat
|
||||
})
|
||||
|
||||
// TODO(Sida): A better interface for bulk insertion is needed.
|
||||
keys := make([]interface{}, len(features)*3)
|
||||
for i, f := range features {
|
||||
keys[i*3] = f.Name
|
||||
keys[i*3+1] = f.Version
|
||||
keys[i*3+2] = f.VersionFormat
|
||||
if f.Name == "" || f.Version == "" || f.VersionFormat == "" {
|
||||
return commonerr.NewBadRequestError("Empty feature name, version or version format is not allowed")
|
||||
}
|
||||
}
|
||||
|
||||
// We do `defer observeQueryTime` here because we don't want to observe cached featureversions.
|
||||
defer observeQueryTime("insertFeatureVersion", "all", time.Now())
|
||||
|
||||
// Find or create Feature first.
|
||||
t := time.Now()
|
||||
featureID, err := pgSQL.insertFeature(fv.Feature)
|
||||
observeQueryTime("insertFeatureVersion", "insertFeature", t)
|
||||
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
fv.Feature.ID = featureID
|
||||
|
||||
// Try to find the FeatureVersion.
|
||||
//
|
||||
// In a populated database, the likelihood of the FeatureVersion already being there is high.
|
||||
// If we can find it here, we then avoid using a transaction and locking the database.
|
||||
err = pgSQL.QueryRow(searchFeatureVersion, featureID, fv.Version).Scan(&fv.ID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return 0, handleError("searchFeatureVersion", err)
|
||||
}
|
||||
if err == nil {
|
||||
if pgSQL.cache != nil {
|
||||
pgSQL.cache.Add(cacheIndex, fv.ID)
|
||||
}
|
||||
|
||||
return fv.ID, nil
|
||||
}
|
||||
|
||||
// Begin transaction.
|
||||
tx, err := pgSQL.Begin()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return 0, handleError("insertFeatureVersion.Begin()", err)
|
||||
}
|
||||
|
||||
// Lock Vulnerability_Affects_FeatureVersion exclusively.
|
||||
// We want to prevent InsertVulnerability to modify it.
|
||||
promConcurrentLockVAFV.Inc()
|
||||
defer promConcurrentLockVAFV.Dec()
|
||||
t = time.Now()
|
||||
_, err = tx.Exec(lockVulnerabilityAffects)
|
||||
observeQueryTime("insertFeatureVersion", "lock", t)
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return 0, handleError("insertFeatureVersion.lockVulnerabilityAffects", err)
|
||||
}
|
||||
|
||||
// Find or create FeatureVersion.
|
||||
var created bool
|
||||
|
||||
t = time.Now()
|
||||
err = tx.QueryRow(soiFeatureVersion, featureID, fv.Version).Scan(&created, &fv.ID)
|
||||
observeQueryTime("insertFeatureVersion", "soiFeatureVersion", t)
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return 0, handleError("soiFeatureVersion", err)
|
||||
}
|
||||
|
||||
if !created {
|
||||
// The featureVersion already existed, no need to link it to
|
||||
// vulnerabilities.
|
||||
tx.Commit()
|
||||
|
||||
if pgSQL.cache != nil {
|
||||
pgSQL.cache.Add(cacheIndex, fv.ID)
|
||||
}
|
||||
|
||||
return fv.ID, nil
|
||||
}
|
||||
|
||||
// Link the new FeatureVersion with every vulnerabilities that affect it, by inserting in
|
||||
// Vulnerability_Affects_FeatureVersion.
|
||||
t = time.Now()
|
||||
err = linkFeatureVersionToVulnerabilities(tx, fv)
|
||||
observeQueryTime("insertFeatureVersion", "linkFeatureVersionToVulnerabilities", t)
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// Commit transaction.
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return 0, handleError("insertFeatureVersion.Commit()", err)
|
||||
}
|
||||
|
||||
if pgSQL.cache != nil {
|
||||
pgSQL.cache.Add(cacheIndex, fv.ID)
|
||||
}
|
||||
|
||||
return fv.ID, nil
|
||||
_, err := tx.Exec(queryPersistFeature(len(features)), keys...)
|
||||
return handleError("queryPersistFeature", err)
|
||||
}
|
||||
|
||||
// TODO(Quentin-M): Batch me
|
||||
func (pgSQL *pgSQL) insertFeatureVersions(featureVersions []database.FeatureVersion) ([]int, error) {
|
||||
IDs := make([]int, 0, len(featureVersions))
|
||||
type namespacedFeatureWithID struct {
|
||||
database.NamespacedFeature
|
||||
|
||||
for i := 0; i < len(featureVersions); i++ {
|
||||
id, err := pgSQL.insertFeatureVersion(featureVersions[i])
|
||||
if err != nil {
|
||||
return IDs, err
|
||||
}
|
||||
IDs = append(IDs, id)
|
||||
ID int64
|
||||
}
|
||||
|
||||
type vulnerabilityCache struct {
|
||||
nsFeatureID int64
|
||||
vulnID int64
|
||||
vulnAffectingID int64
|
||||
}
|
||||
|
||||
func (tx *pgSession) searchAffectingVulnerabilities(features []database.NamespacedFeature) ([]vulnerabilityCache, error) {
|
||||
if len(features) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return IDs, nil
|
||||
}
|
||||
|
||||
type vulnerabilityAffectsFeatureVersion struct {
|
||||
vulnerabilityID int
|
||||
fixedInID int
|
||||
fixedInVersion string
|
||||
}
|
||||
|
||||
func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.FeatureVersion) error {
|
||||
// Select every vulnerability and the fixed version that affect this Feature.
|
||||
// TODO(Quentin-M): LIMIT
|
||||
rows, err := tx.Query(searchVulnerabilityFixedInFeature, featureVersion.Feature.ID)
|
||||
ids, err := tx.findNamespacedFeatureIDs(features)
|
||||
if err != nil {
|
||||
return handleError("searchVulnerabilityFixedInFeature", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
fMap := map[int64]database.NamespacedFeature{}
|
||||
for i, f := range features {
|
||||
if !ids[i].Valid {
|
||||
return nil, errFeatureNotFound
|
||||
}
|
||||
fMap[ids[i].Int64] = f
|
||||
}
|
||||
|
||||
cacheTable := []vulnerabilityCache{}
|
||||
rows, err := tx.Query(searchPotentialAffectingVulneraibilities, pq.Array(ids))
|
||||
if err != nil {
|
||||
return nil, handleError("searchPotentialAffectingVulneraibilities", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var affects []vulnerabilityAffectsFeatureVersion
|
||||
for rows.Next() {
|
||||
var affect vulnerabilityAffectsFeatureVersion
|
||||
var (
|
||||
cache vulnerabilityCache
|
||||
affected string
|
||||
)
|
||||
|
||||
err := rows.Scan(&affect.fixedInID, &affect.vulnerabilityID, &affect.fixedInVersion)
|
||||
err := rows.Scan(&cache.nsFeatureID, &cache.vulnID, &affected, &cache.vulnAffectingID)
|
||||
if err != nil {
|
||||
return handleError("searchVulnerabilityFixedInFeature.Scan()", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cmp, err := versionfmt.Compare(featureVersion.Feature.Namespace.VersionFormat, featureVersion.Version, affect.fixedInVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cmp < 0 {
|
||||
// The version of the FeatureVersion we are inserting is lower than the fixed version on this
|
||||
// Vulnerability, thus, this FeatureVersion is affected by it.
|
||||
affects = append(affects, affect)
|
||||
if ok, err := versionfmt.InRange(fMap[cache.nsFeatureID].VersionFormat, fMap[cache.nsFeatureID].Version, affected); err != nil {
|
||||
return nil, err
|
||||
} else if ok {
|
||||
cacheTable = append(cacheTable, cache)
|
||||
}
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return handleError("searchVulnerabilityFixedInFeature.Rows()", err)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
// Insert into Vulnerability_Affects_FeatureVersion.
|
||||
for _, affect := range affects {
|
||||
// TODO(Quentin-M): Batch me.
|
||||
_, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, affect.vulnerabilityID,
|
||||
featureVersion.ID, affect.fixedInID)
|
||||
if err != nil {
|
||||
return handleError("insertVulnerabilityAffectsFeatureVersion", err)
|
||||
return cacheTable, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) CacheAffectedNamespacedFeatures(features []database.NamespacedFeature) error {
|
||||
if len(features) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := tx.Exec(lockVulnerabilityAffects)
|
||||
if err != nil {
|
||||
return handleError("lockVulnerabilityAffects", err)
|
||||
}
|
||||
|
||||
cache, err := tx.searchAffectingVulnerabilities(features)
|
||||
|
||||
keys := make([]interface{}, len(cache)*3)
|
||||
for i, c := range cache {
|
||||
keys[i*3] = c.vulnID
|
||||
keys[i*3+1] = c.nsFeatureID
|
||||
keys[i*3+2] = c.vulnAffectingID
|
||||
}
|
||||
|
||||
if len(cache) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
affected, err := tx.Exec(queryPersistVulnerabilityAffectedNamespacedFeature(len(cache)), keys...)
|
||||
if err != nil {
|
||||
return handleError("persistVulnerabilityAffectedNamespacedFeature", err)
|
||||
}
|
||||
if count, err := affected.RowsAffected(); err != nil {
|
||||
log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", count)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) PersistNamespacedFeatures(features []database.NamespacedFeature) error {
|
||||
if len(features) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
nsIDs := map[database.Namespace]sql.NullInt64{}
|
||||
fIDs := map[database.Feature]sql.NullInt64{}
|
||||
for _, f := range features {
|
||||
nsIDs[f.Namespace] = sql.NullInt64{}
|
||||
fIDs[f.Feature] = sql.NullInt64{}
|
||||
}
|
||||
|
||||
fToFind := []database.Feature{}
|
||||
for f := range fIDs {
|
||||
fToFind = append(fToFind, f)
|
||||
}
|
||||
|
||||
sort.Slice(fToFind, func(i, j int) bool {
|
||||
return fToFind[i].Name < fToFind[j].Name ||
|
||||
fToFind[i].Version < fToFind[j].Version ||
|
||||
fToFind[i].VersionFormat < fToFind[j].VersionFormat
|
||||
})
|
||||
|
||||
if ids, err := tx.findFeatureIDs(fToFind); err == nil {
|
||||
for i, id := range ids {
|
||||
if !id.Valid {
|
||||
return errFeatureNotFound
|
||||
}
|
||||
fIDs[fToFind[i]] = id
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
nsToFind := []database.Namespace{}
|
||||
for ns := range nsIDs {
|
||||
nsToFind = append(nsToFind, ns)
|
||||
}
|
||||
|
||||
if ids, err := tx.findNamespaceIDs(nsToFind); err == nil {
|
||||
for i, id := range ids {
|
||||
if !id.Valid {
|
||||
return errNamespaceNotFound
|
||||
}
|
||||
nsIDs[nsToFind[i]] = id
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
|
||||
keys := make([]interface{}, len(features)*2)
|
||||
for i, f := range features {
|
||||
keys[i*2] = fIDs[f.Feature]
|
||||
keys[i*2+1] = nsIDs[f.Namespace]
|
||||
}
|
||||
|
||||
_, err := tx.Exec(queryPersistNamespacedFeature(len(features)), keys...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// FindAffectedNamespacedFeatures looks up cache table and retrieves all
|
||||
// vulnerabilities associated with the features.
|
||||
func (tx *pgSession) FindAffectedNamespacedFeatures(features []database.NamespacedFeature) ([]database.NullableAffectedNamespacedFeature, error) {
|
||||
if len(features) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
returnFeatures := make([]database.NullableAffectedNamespacedFeature, len(features))
|
||||
|
||||
// featureMap is used to keep track of duplicated features.
|
||||
featureMap := map[database.NamespacedFeature][]*database.NullableAffectedNamespacedFeature{}
|
||||
// initialize return value and generate unique feature request queries.
|
||||
for i, f := range features {
|
||||
returnFeatures[i] = database.NullableAffectedNamespacedFeature{
|
||||
AffectedNamespacedFeature: database.AffectedNamespacedFeature{
|
||||
NamespacedFeature: f,
|
||||
},
|
||||
}
|
||||
|
||||
featureMap[f] = append(featureMap[f], &returnFeatures[i])
|
||||
}
|
||||
|
||||
// query unique namespaced features
|
||||
distinctFeatures := []database.NamespacedFeature{}
|
||||
for f := range featureMap {
|
||||
distinctFeatures = append(distinctFeatures, f)
|
||||
}
|
||||
|
||||
nsFeatureIDs, err := tx.findNamespacedFeatureIDs(distinctFeatures)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
toQuery := []int64{}
|
||||
featureIDMap := map[int64][]*database.NullableAffectedNamespacedFeature{}
|
||||
for i, id := range nsFeatureIDs {
|
||||
if id.Valid {
|
||||
toQuery = append(toQuery, id.Int64)
|
||||
for _, f := range featureMap[distinctFeatures[i]] {
|
||||
f.Valid = id.Valid
|
||||
featureIDMap[id.Int64] = append(featureIDMap[id.Int64], f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := tx.Query(searchNamespacedFeaturesVulnerabilities, pq.Array(toQuery))
|
||||
if err != nil {
|
||||
return nil, handleError("searchNamespacedFeaturesVulnerabilities", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
featureID int64
|
||||
vuln database.VulnerabilityWithFixedIn
|
||||
)
|
||||
err := rows.Scan(&featureID,
|
||||
&vuln.Name,
|
||||
&vuln.Description,
|
||||
&vuln.Link,
|
||||
&vuln.Severity,
|
||||
&vuln.Metadata,
|
||||
&vuln.FixedInVersion,
|
||||
&vuln.Namespace.Name,
|
||||
&vuln.Namespace.VersionFormat,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, handleError("searchNamespacedFeaturesVulnerabilities", err)
|
||||
}
|
||||
|
||||
for _, f := range featureIDMap[featureID] {
|
||||
f.AffectedBy = append(f.AffectedBy, vuln)
|
||||
}
|
||||
}
|
||||
|
||||
return returnFeatures, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findNamespacedFeatureIDs(nfs []database.NamespacedFeature) ([]sql.NullInt64, error) {
|
||||
if len(nfs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
nfsMap := map[database.NamespacedFeature]sql.NullInt64{}
|
||||
keys := make([]interface{}, len(nfs)*4)
|
||||
for i, nf := range nfs {
|
||||
keys[i*4] = nfs[i].Name
|
||||
keys[i*4+1] = nfs[i].Version
|
||||
keys[i*4+2] = nfs[i].VersionFormat
|
||||
keys[i*4+3] = nfs[i].Namespace.Name
|
||||
nfsMap[nf] = sql.NullInt64{}
|
||||
}
|
||||
|
||||
rows, err := tx.Query(querySearchNamespacedFeature(len(nfs)), keys...)
|
||||
if err != nil {
|
||||
return nil, handleError("searchNamespacedFeature", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
var (
|
||||
id sql.NullInt64
|
||||
nf database.NamespacedFeature
|
||||
)
|
||||
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&id, &nf.Name, &nf.Version, &nf.VersionFormat, &nf.Namespace.Name)
|
||||
nf.Namespace.VersionFormat = nf.VersionFormat
|
||||
if err != nil {
|
||||
return nil, handleError("searchNamespacedFeature", err)
|
||||
}
|
||||
nfsMap[nf] = id
|
||||
}
|
||||
|
||||
ids := make([]sql.NullInt64, len(nfs))
|
||||
for i, nf := range nfs {
|
||||
ids[i] = nfsMap[nf]
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findFeatureIDs(fs []database.Feature) ([]sql.NullInt64, error) {
|
||||
if len(fs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
fMap := map[database.Feature]sql.NullInt64{}
|
||||
|
||||
keys := make([]interface{}, len(fs)*3)
|
||||
for i, f := range fs {
|
||||
keys[i*3] = f.Name
|
||||
keys[i*3+1] = f.Version
|
||||
keys[i*3+2] = f.VersionFormat
|
||||
fMap[f] = sql.NullInt64{}
|
||||
}
|
||||
|
||||
rows, err := tx.Query(querySearchFeatureID(len(fs)), keys...)
|
||||
if err != nil {
|
||||
return nil, handleError("querySearchFeatureID", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var (
|
||||
id sql.NullInt64
|
||||
f database.Feature
|
||||
)
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&id, &f.Name, &f.Version, &f.VersionFormat)
|
||||
if err != nil {
|
||||
return nil, handleError("querySearchFeatureID", err)
|
||||
}
|
||||
fMap[f] = id
|
||||
}
|
||||
|
||||
ids := make([]sql.NullInt64, len(fs))
|
||||
for i, f := range fs {
|
||||
ids[i] = fMap[f]
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
@ -20,96 +20,237 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/ext/versionfmt/dpkg"
|
||||
|
||||
// register dpkg feature lister for testing
|
||||
_ "github.com/coreos/clair/ext/featurefmt/dpkg"
|
||||
)
|
||||
|
||||
func TestInsertFeature(t *testing.T) {
|
||||
datastore, err := openDatabaseForTest("InsertFeature", false)
|
||||
func TestPersistFeatures(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "PersistFeatures", false)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
f1 := database.Feature{}
|
||||
f2 := database.Feature{Name: "n", Version: "v", VersionFormat: "vf"}
|
||||
|
||||
// empty
|
||||
assert.Nil(t, tx.PersistFeatures([]database.Feature{}))
|
||||
// invalid
|
||||
assert.NotNil(t, tx.PersistFeatures([]database.Feature{f1}))
|
||||
// duplicated
|
||||
assert.Nil(t, tx.PersistFeatures([]database.Feature{f2, f2}))
|
||||
// existing
|
||||
assert.Nil(t, tx.PersistFeatures([]database.Feature{f2}))
|
||||
|
||||
fs := listFeatures(t, tx)
|
||||
assert.Len(t, fs, 1)
|
||||
assert.Equal(t, f2, fs[0])
|
||||
}
|
||||
|
||||
func TestPersistNamespacedFeatures(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "PersistNamespacedFeatures", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
// existing features
|
||||
f1 := database.Feature{
|
||||
Name: "wechat",
|
||||
Version: "0.5",
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
|
||||
// non-existing features
|
||||
f2 := database.Feature{
|
||||
Name: "fake!",
|
||||
}
|
||||
|
||||
f3 := database.Feature{
|
||||
Name: "openssl",
|
||||
Version: "2.0",
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
|
||||
// exising namespace
|
||||
n1 := database.Namespace{
|
||||
Name: "debian:7",
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
|
||||
n3 := database.Namespace{
|
||||
Name: "debian:8",
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
|
||||
// non-existing namespace
|
||||
n2 := database.Namespace{
|
||||
Name: "debian:non",
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
|
||||
// existing namespaced feature
|
||||
nf1 := database.NamespacedFeature{
|
||||
Namespace: n1,
|
||||
Feature: f1,
|
||||
}
|
||||
|
||||
// invalid namespaced feature
|
||||
nf2 := database.NamespacedFeature{
|
||||
Namespace: n2,
|
||||
Feature: f2,
|
||||
}
|
||||
|
||||
// new namespaced feature affected by vulnerability
|
||||
nf3 := database.NamespacedFeature{
|
||||
Namespace: n3,
|
||||
Feature: f3,
|
||||
}
|
||||
|
||||
// namespaced features with namespaces or features not in the database will
|
||||
// generate error.
|
||||
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{}))
|
||||
|
||||
assert.NotNil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf2}))
|
||||
// valid case: insert nf3
|
||||
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1, nf3}))
|
||||
|
||||
all := listNamespacedFeatures(t, tx)
|
||||
assert.Contains(t, all, nf1)
|
||||
assert.Contains(t, all, nf3)
|
||||
}
|
||||
|
||||
func TestVulnerableFeature(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "VulnerableFeature", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
f1 := database.Feature{
|
||||
Name: "openssl",
|
||||
Version: "1.3",
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
|
||||
n1 := database.Namespace{
|
||||
Name: "debian:7",
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
|
||||
nf1 := database.NamespacedFeature{
|
||||
Namespace: n1,
|
||||
Feature: f1,
|
||||
}
|
||||
assert.Nil(t, tx.PersistFeatures([]database.Feature{f1}))
|
||||
assert.Nil(t, tx.PersistNamespacedFeatures([]database.NamespacedFeature{nf1}))
|
||||
assert.Nil(t, tx.CacheAffectedNamespacedFeatures([]database.NamespacedFeature{nf1}))
|
||||
// ensure the namespaced feature is affected correctly
|
||||
anf, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{nf1})
|
||||
if assert.Nil(t, err) &&
|
||||
assert.Len(t, anf, 1) &&
|
||||
assert.True(t, anf[0].Valid) &&
|
||||
assert.Len(t, anf[0].AffectedBy, 1) {
|
||||
assert.Equal(t, "CVE-OPENSSL-1-DEB7", anf[0].AffectedBy[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindAffectedNamespacedFeatures(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "FindAffectedNamespacedFeatures", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
ns := database.NamespacedFeature{
|
||||
Feature: database.Feature{
|
||||
Name: "openssl",
|
||||
Version: "1.0",
|
||||
VersionFormat: "dpkg",
|
||||
},
|
||||
Namespace: database.Namespace{
|
||||
Name: "debian:7",
|
||||
VersionFormat: "dpkg",
|
||||
},
|
||||
}
|
||||
|
||||
ans, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{ns})
|
||||
if assert.Nil(t, err) &&
|
||||
assert.Len(t, ans, 1) &&
|
||||
assert.True(t, ans[0].Valid) &&
|
||||
assert.Len(t, ans[0].AffectedBy, 1) {
|
||||
assert.Equal(t, "CVE-OPENSSL-1-DEB7", ans[0].AffectedBy[0].Name)
|
||||
}
|
||||
}
|
||||
|
||||
func listNamespacedFeatures(t *testing.T, tx *pgSession) []database.NamespacedFeature {
|
||||
rows, err := tx.Query(`SELECT f.name, f.version, f.version_format, n.name, n.version_format
|
||||
FROM feature AS f, namespace AS n, namespaced_feature AS nf
|
||||
WHERE nf.feature_id = f.id AND nf.namespace_id = n.id`)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer datastore.Close()
|
||||
|
||||
// Invalid Feature.
|
||||
id0, err := datastore.insertFeature(database.Feature{})
|
||||
assert.NotNil(t, err)
|
||||
assert.Zero(t, id0)
|
||||
|
||||
id0, err = datastore.insertFeature(database.Feature{
|
||||
Namespace: database.Namespace{},
|
||||
Name: "TestInsertFeature0",
|
||||
})
|
||||
assert.NotNil(t, err)
|
||||
assert.Zero(t, id0)
|
||||
|
||||
// Insert Feature and ensure we can find it.
|
||||
feature := database.Feature{
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestInsertFeatureNamespace1",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
Name: "TestInsertFeature1",
|
||||
}
|
||||
id1, err := datastore.insertFeature(feature)
|
||||
assert.Nil(t, err)
|
||||
id2, err := datastore.insertFeature(feature)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, id1, id2)
|
||||
|
||||
// Insert invalid FeatureVersion.
|
||||
for _, invalidFeatureVersion := range []database.FeatureVersion{
|
||||
{
|
||||
Feature: database.Feature{},
|
||||
Version: "1.0",
|
||||
},
|
||||
{
|
||||
Feature: database.Feature{
|
||||
Namespace: database.Namespace{},
|
||||
Name: "TestInsertFeature2",
|
||||
},
|
||||
Version: "1.0",
|
||||
},
|
||||
{
|
||||
Feature: database.Feature{
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestInsertFeatureNamespace2",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
Name: "TestInsertFeature2",
|
||||
},
|
||||
Version: "",
|
||||
},
|
||||
{
|
||||
Feature: database.Feature{
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestInsertFeatureNamespace2",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
Name: "TestInsertFeature2",
|
||||
},
|
||||
Version: "bad version",
|
||||
},
|
||||
} {
|
||||
id3, err := datastore.insertFeatureVersion(invalidFeatureVersion)
|
||||
assert.Error(t, err)
|
||||
assert.Zero(t, id3)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
// Insert FeatureVersion and ensure we can find it.
|
||||
featureVersion := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestInsertFeatureNamespace1",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
Name: "TestInsertFeature1",
|
||||
},
|
||||
Version: "2:3.0-imba",
|
||||
nf := []database.NamespacedFeature{}
|
||||
for rows.Next() {
|
||||
f := database.NamespacedFeature{}
|
||||
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat, &f.Namespace.Name, &f.Namespace.VersionFormat)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
t.FailNow()
|
||||
}
|
||||
nf = append(nf, f)
|
||||
}
|
||||
id4, err := datastore.insertFeatureVersion(featureVersion)
|
||||
assert.Nil(t, err)
|
||||
id5, err := datastore.insertFeatureVersion(featureVersion)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, id4, id5)
|
||||
|
||||
return nf
|
||||
}
|
||||
|
||||
func listFeatures(t *testing.T, tx *pgSession) []database.Feature {
|
||||
rows, err := tx.Query("SELECT name, version, version_format FROM feature")
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
fs := []database.Feature{}
|
||||
for rows.Next() {
|
||||
f := database.Feature{}
|
||||
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
fs = append(fs, f)
|
||||
}
|
||||
return fs
|
||||
}
|
||||
|
||||
func assertFeaturesEqual(t *testing.T, expected []database.Feature, actual []database.Feature) bool {
|
||||
if assert.Len(t, actual, len(expected)) {
|
||||
has := map[database.Feature]bool{}
|
||||
for _, nf := range expected {
|
||||
has[nf] = false
|
||||
}
|
||||
|
||||
for _, nf := range actual {
|
||||
has[nf] = true
|
||||
}
|
||||
|
||||
for nf, visited := range has {
|
||||
if !assert.True(t, visited, nf.Name+" is expected") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func assertNamespacedFeatureEqual(t *testing.T, expected []database.NamespacedFeature, actual []database.NamespacedFeature) bool {
|
||||
if assert.Len(t, actual, len(expected)) {
|
||||
has := map[database.NamespacedFeature]bool{}
|
||||
for _, nf := range expected {
|
||||
has[nf] = false
|
||||
}
|
||||
|
||||
for _, nf := range actual {
|
||||
has[nf] = true
|
||||
}
|
||||
|
||||
for nf, visited := range has {
|
||||
if !assert.True(t, visited, nf.Namespace.Name+":"+nf.Name+" is expected") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@ -23,63 +23,35 @@ import (
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
// InsertKeyValue stores (or updates) a single key / value tuple.
|
||||
func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) {
|
||||
func (tx *pgSession) UpdateKeyValue(key, value string) (err error) {
|
||||
if key == "" || value == "" {
|
||||
log.Warning("could not insert a flag which has an empty name or value")
|
||||
return commonerr.NewBadRequestError("could not insert a flag which has an empty name or value")
|
||||
}
|
||||
|
||||
defer observeQueryTime("InsertKeyValue", "all", time.Now())
|
||||
defer observeQueryTime("PersistKeyValue", "all", time.Now())
|
||||
|
||||
// Upsert.
|
||||
//
|
||||
// Note: UPSERT works only on >= PostgreSQL 9.5 which is not yet supported by AWS RDS.
|
||||
// The best solution is currently the use of http://dba.stackexchange.com/a/13477
|
||||
// but the key/value storage doesn't need to be super-efficient and super-safe at the
|
||||
// moment so we can just use a client-side solution with transactions, based on
|
||||
// http://postgresql.org/docs/current/static/plpgsql-control-structures.html.
|
||||
// TODO(Quentin-M): Enable Upsert as soon as 9.5 is stable.
|
||||
|
||||
for {
|
||||
// First, try to update.
|
||||
r, err := pgSQL.Exec(updateKeyValue, value, key)
|
||||
if err != nil {
|
||||
return handleError("updateKeyValue", err)
|
||||
}
|
||||
if n, _ := r.RowsAffected(); n > 0 {
|
||||
// Updated successfully.
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to insert the key.
|
||||
// If someone else inserts the same key concurrently, we could get a unique-key violation error.
|
||||
_, err = pgSQL.Exec(insertKeyValue, key, value)
|
||||
if err != nil {
|
||||
if isErrUniqueViolation(err) {
|
||||
// Got unique constraint violation, retry.
|
||||
continue
|
||||
}
|
||||
return handleError("insertKeyValue", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
_, err = tx.Exec(upsertKeyValue, key, value)
|
||||
if err != nil {
|
||||
return handleError("insertKeyValue", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetValue reads a single key / value tuple and returns an empty string if the key doesn't exist.
|
||||
func (pgSQL *pgSQL) GetKeyValue(key string) (string, error) {
|
||||
defer observeQueryTime("GetKeyValue", "all", time.Now())
|
||||
func (tx *pgSession) FindKeyValue(key string) (string, bool, error) {
|
||||
defer observeQueryTime("FindKeyValue", "all", time.Now())
|
||||
|
||||
var value string
|
||||
err := pgSQL.QueryRow(searchKeyValue, key).Scan(&value)
|
||||
err := tx.QueryRow(searchKeyValue, key).Scan(&value)
|
||||
|
||||
if err == sql.ErrNoRows {
|
||||
return "", nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", handleError("searchKeyValue", err)
|
||||
return "", false, nil
|
||||
}
|
||||
|
||||
return value, nil
|
||||
if err != nil {
|
||||
return "", false, handleError("searchKeyValue", err)
|
||||
}
|
||||
|
||||
return value, true, nil
|
||||
}
|
||||
|
@ -21,32 +21,30 @@ import (
|
||||
)
|
||||
|
||||
func TestKeyValue(t *testing.T) {
|
||||
datastore, err := openDatabaseForTest("KeyValue", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer datastore.Close()
|
||||
datastore, tx := openSessionForTest(t, "KeyValue", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
// Get non-existing key/value
|
||||
f, err := datastore.GetKeyValue("test")
|
||||
f, ok, err := tx.FindKeyValue("test")
|
||||
assert.Nil(t, err)
|
||||
assert.Empty(t, "", f)
|
||||
assert.False(t, ok)
|
||||
|
||||
// Try to insert invalid key/value.
|
||||
assert.Error(t, datastore.InsertKeyValue("test", ""))
|
||||
assert.Error(t, datastore.InsertKeyValue("", "test"))
|
||||
assert.Error(t, datastore.InsertKeyValue("", ""))
|
||||
assert.Error(t, tx.UpdateKeyValue("test", ""))
|
||||
assert.Error(t, tx.UpdateKeyValue("", "test"))
|
||||
assert.Error(t, tx.UpdateKeyValue("", ""))
|
||||
|
||||
// Insert and verify.
|
||||
assert.Nil(t, datastore.InsertKeyValue("test", "test1"))
|
||||
f, err = datastore.GetKeyValue("test")
|
||||
assert.Nil(t, tx.UpdateKeyValue("test", "test1"))
|
||||
f, ok, err = tx.FindKeyValue("test")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test1", f)
|
||||
|
||||
// Update and verify.
|
||||
assert.Nil(t, datastore.InsertKeyValue("test", "test2"))
|
||||
f, err = datastore.GetKeyValue("test")
|
||||
assert.Nil(t, tx.UpdateKeyValue("test", "test2"))
|
||||
f, ok, err = tx.FindKeyValue("test")
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, "test2", f)
|
||||
}
|
||||
|
@ -16,464 +16,293 @@ package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/guregu/null/zero"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"sort"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities bool) (database.Layer, error) {
|
||||
subquery := "all"
|
||||
if withFeatures {
|
||||
subquery += "/features"
|
||||
} else if withVulnerabilities {
|
||||
subquery += "/features+vulnerabilities"
|
||||
}
|
||||
defer observeQueryTime("FindLayer", subquery, time.Now())
|
||||
func (tx *pgSession) FindLayer(hash string) (database.Layer, database.Processors, bool, error) {
|
||||
l, p, _, ok, err := tx.findLayer(hash)
|
||||
return l, p, ok, err
|
||||
}
|
||||
|
||||
// Find the layer
|
||||
func (tx *pgSession) FindLayerWithContent(hash string) (database.LayerWithContent, bool, error) {
|
||||
var (
|
||||
layer database.Layer
|
||||
parentID zero.Int
|
||||
parentName zero.String
|
||||
nsID zero.Int
|
||||
nsName sql.NullString
|
||||
nsVersionFormat sql.NullString
|
||||
layer database.LayerWithContent
|
||||
layerID int64
|
||||
ok bool
|
||||
err error
|
||||
)
|
||||
|
||||
t := time.Now()
|
||||
err := pgSQL.QueryRow(searchLayer, name).Scan(
|
||||
&layer.ID,
|
||||
&layer.Name,
|
||||
&layer.EngineVersion,
|
||||
&parentID,
|
||||
&parentName,
|
||||
)
|
||||
observeQueryTime("FindLayer", "searchLayer", t)
|
||||
|
||||
layer.Layer, layer.ProcessedBy, layerID, ok, err = tx.findLayer(hash)
|
||||
if err != nil {
|
||||
return layer, handleError("searchLayer", err)
|
||||
return layer, false, err
|
||||
}
|
||||
|
||||
if !parentID.IsZero() {
|
||||
layer.Parent = &database.Layer{
|
||||
Model: database.Model{ID: int(parentID.Int64)},
|
||||
Name: parentName.String,
|
||||
}
|
||||
if !ok {
|
||||
return layer, false, nil
|
||||
}
|
||||
|
||||
rows, err := pgSQL.Query(searchLayerNamespace, layer.ID)
|
||||
defer rows.Close()
|
||||
if err != nil {
|
||||
return layer, handleError("searchLayerNamespace", err)
|
||||
}
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&nsID, &nsName, &nsVersionFormat)
|
||||
if err != nil {
|
||||
return layer, handleError("searchLayerNamespace", err)
|
||||
}
|
||||
if !nsID.IsZero() {
|
||||
layer.Namespaces = append(layer.Namespaces, database.Namespace{
|
||||
Model: database.Model{ID: int(nsID.Int64)},
|
||||
Name: nsName.String,
|
||||
VersionFormat: nsVersionFormat.String,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Find its features
|
||||
if withFeatures || withVulnerabilities {
|
||||
// Create a transaction to disable hash/merge joins as our experiments have shown that
|
||||
// PostgreSQL 9.4 makes bad planning decisions about:
|
||||
// - joining the layer tree to feature versions and feature
|
||||
// - joining the feature versions to affected/fixed feature version and vulnerabilities
|
||||
// It would for instance do a merge join between affected feature versions (300 rows, estimated
|
||||
// 3000 rows) and fixed in feature version (100k rows). In this case, it is much more
|
||||
// preferred to use a nested loop.
|
||||
tx, err := pgSQL.Begin()
|
||||
if err != nil {
|
||||
return layer, handleError("FindLayer.Begin()", err)
|
||||
}
|
||||
defer tx.Commit()
|
||||
|
||||
_, err = tx.Exec(disableHashJoin)
|
||||
if err != nil {
|
||||
log.WithError(err).Warningf("FindLayer: could not disable hash join")
|
||||
}
|
||||
_, err = tx.Exec(disableMergeJoin)
|
||||
if err != nil {
|
||||
log.WithError(err).Warningf("FindLayer: could not disable merge join")
|
||||
}
|
||||
|
||||
t = time.Now()
|
||||
featureVersions, err := getLayerFeatureVersions(tx, layer.ID)
|
||||
observeQueryTime("FindLayer", "getLayerFeatureVersions", t)
|
||||
|
||||
if err != nil {
|
||||
return layer, err
|
||||
}
|
||||
|
||||
layer.Features = featureVersions
|
||||
|
||||
if withVulnerabilities {
|
||||
// Load the vulnerabilities that affect the FeatureVersions.
|
||||
t = time.Now()
|
||||
err := loadAffectedBy(tx, layer.Features)
|
||||
observeQueryTime("FindLayer", "loadAffectedBy", t)
|
||||
|
||||
if err != nil {
|
||||
return layer, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return layer, nil
|
||||
layer.Features, err = tx.findLayerFeatures(layerID)
|
||||
layer.Namespaces, err = tx.findLayerNamespaces(layerID)
|
||||
return layer, true, nil
|
||||
}
|
||||
|
||||
// getLayerFeatureVersions returns list of database.FeatureVersion that a database.Layer has.
|
||||
func getLayerFeatureVersions(tx *sql.Tx, layerID int) ([]database.FeatureVersion, error) {
|
||||
var featureVersions []database.FeatureVersion
|
||||
func (tx *pgSession) PersistLayer(layer database.Layer) error {
|
||||
if layer.Hash == "" {
|
||||
return commonerr.NewBadRequestError("Empty Layer Hash is not allowed")
|
||||
}
|
||||
|
||||
// Query.
|
||||
rows, err := tx.Query(searchLayerFeatureVersion, layerID)
|
||||
_, err := tx.Exec(queryPersistLayer(1), layer.Hash)
|
||||
if err != nil {
|
||||
return featureVersions, handleError("searchLayerFeatureVersion", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Scan query.
|
||||
var modification string
|
||||
mapFeatureVersions := make(map[int]database.FeatureVersion)
|
||||
for rows.Next() {
|
||||
var fv database.FeatureVersion
|
||||
err = rows.Scan(
|
||||
&fv.ID,
|
||||
&modification,
|
||||
&fv.Feature.Namespace.ID,
|
||||
&fv.Feature.Namespace.Name,
|
||||
&fv.Feature.Namespace.VersionFormat,
|
||||
&fv.Feature.ID,
|
||||
&fv.Feature.Name,
|
||||
&fv.ID,
|
||||
&fv.Version,
|
||||
&fv.AddedBy.ID,
|
||||
&fv.AddedBy.Name,
|
||||
)
|
||||
if err != nil {
|
||||
return featureVersions, handleError("searchLayerFeatureVersion.Scan()", err)
|
||||
}
|
||||
|
||||
// Do transitive closure.
|
||||
switch modification {
|
||||
case "add":
|
||||
mapFeatureVersions[fv.ID] = fv
|
||||
case "del":
|
||||
delete(mapFeatureVersions, fv.ID)
|
||||
default:
|
||||
log.WithField("modification", modification).Warning("unknown Layer_diff_FeatureVersion's modification")
|
||||
return featureVersions, database.ErrInconsistent
|
||||
}
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return featureVersions, handleError("searchLayerFeatureVersion.Rows()", err)
|
||||
return handleError("queryPersistLayer", err)
|
||||
}
|
||||
|
||||
// Build result by converting our map to a slice.
|
||||
for _, featureVersion := range mapFeatureVersions {
|
||||
featureVersions = append(featureVersions, featureVersion)
|
||||
}
|
||||
|
||||
return featureVersions, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadAffectedBy returns the list of database.Vulnerability that affect the given
|
||||
// FeatureVersion.
|
||||
func loadAffectedBy(tx *sql.Tx, featureVersions []database.FeatureVersion) error {
|
||||
if len(featureVersions) == 0 {
|
||||
// PersistLayerContent relates layer identified by hash with namespaces,
|
||||
// features and processors provided. If the layer, namespaces, features are not
|
||||
// in database, the function returns an error.
|
||||
func (tx *pgSession) PersistLayerContent(hash string, namespaces []database.Namespace, features []database.Feature, processedBy database.Processors) error {
|
||||
if hash == "" {
|
||||
return commonerr.NewBadRequestError("Empty layer hash is not allowed")
|
||||
}
|
||||
|
||||
var layerID int64
|
||||
err := tx.QueryRow(searchLayer, hash).Scan(&layerID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.persistLayerNamespace(layerID, namespaces); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.persistLayerFeatures(layerID, features); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.persistLayerDetectors(layerID, processedBy.Detectors); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = tx.persistLayerListers(layerID, processedBy.Listers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) persistLayerDetectors(id int64, detectors []string) error {
|
||||
if len(detectors) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Construct list of FeatureVersion IDs, we will do a single query
|
||||
featureVersionIDs := make([]int, 0, len(featureVersions))
|
||||
for i := 0; i < len(featureVersions); i++ {
|
||||
featureVersionIDs = append(featureVersionIDs, featureVersions[i].ID)
|
||||
// Sorting is needed before inserting into database to prevent deadlock.
|
||||
sort.Strings(detectors)
|
||||
keys := make([]interface{}, len(detectors)*2)
|
||||
for i, d := range detectors {
|
||||
keys[i*2] = id
|
||||
keys[i*2+1] = d
|
||||
}
|
||||
_, err := tx.Exec(queryPersistLayerDetectors(len(detectors)), keys...)
|
||||
if err != nil {
|
||||
return handleError("queryPersistLayerDetectors", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) persistLayerListers(id int64, listers []string) error {
|
||||
if len(listers) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
rows, err := tx.Query(searchFeatureVersionVulnerability,
|
||||
buildInputArray(featureVersionIDs))
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return handleError("searchFeatureVersionVulnerability", err)
|
||||
sort.Strings(listers)
|
||||
keys := make([]interface{}, len(listers)*2)
|
||||
for i, d := range listers {
|
||||
keys[i*2] = id
|
||||
keys[i*2+1] = d
|
||||
}
|
||||
|
||||
_, err := tx.Exec(queryPersistLayerListers(len(listers)), keys...)
|
||||
if err != nil {
|
||||
return handleError("queryPersistLayerDetectors", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) persistLayerFeatures(id int64, features []database.Feature) error {
|
||||
if len(features) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
fIDs, err := tx.findFeatureIDs(features)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ids := make([]int, len(fIDs))
|
||||
for i, fID := range fIDs {
|
||||
if !fID.Valid {
|
||||
return errNamespaceNotFound
|
||||
}
|
||||
ids[i] = int(fID.Int64)
|
||||
}
|
||||
|
||||
sort.IntSlice(ids).Sort()
|
||||
keys := make([]interface{}, len(features)*2)
|
||||
for i, fID := range ids {
|
||||
keys[i*2] = id
|
||||
keys[i*2+1] = fID
|
||||
}
|
||||
|
||||
_, err = tx.Exec(queryPersistLayerFeature(len(features)), keys...)
|
||||
if err != nil {
|
||||
return handleError("queryPersistLayerFeature", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) persistLayerNamespace(id int64, namespaces []database.Namespace) error {
|
||||
if len(namespaces) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
nsIDs, err := tx.findNamespaceIDs(namespaces)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// for every bulk persist operation, the input data should be sorted.
|
||||
ids := make([]int, len(nsIDs))
|
||||
for i, nsID := range nsIDs {
|
||||
if !nsID.Valid {
|
||||
panic(errNamespaceNotFound)
|
||||
}
|
||||
ids[i] = int(nsID.Int64)
|
||||
}
|
||||
|
||||
sort.IntSlice(ids).Sort()
|
||||
|
||||
keys := make([]interface{}, len(namespaces)*2)
|
||||
for i, nsID := range ids {
|
||||
keys[i*2] = id
|
||||
keys[i*2+1] = nsID
|
||||
}
|
||||
|
||||
_, err = tx.Exec(queryPersistLayerNamespace(len(namespaces)), keys...)
|
||||
if err != nil {
|
||||
return handleError("queryPersistLayerNamespace", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) persistProcessors(listerQuery, listerQueryName, detectorQuery, detectorQueryName string, id int64, processors database.Processors) error {
|
||||
stmt, err := tx.Prepare(listerQuery)
|
||||
if err != nil {
|
||||
return handleError(listerQueryName, err)
|
||||
}
|
||||
|
||||
for _, l := range processors.Listers {
|
||||
_, err := stmt.Exec(id, l)
|
||||
if err != nil {
|
||||
stmt.Close()
|
||||
return handleError(listerQueryName, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := stmt.Close(); err != nil {
|
||||
return handleError(listerQueryName, err)
|
||||
}
|
||||
|
||||
stmt, err = tx.Prepare(detectorQuery)
|
||||
if err != nil {
|
||||
return handleError(detectorQueryName, err)
|
||||
}
|
||||
|
||||
for _, d := range processors.Detectors {
|
||||
_, err := stmt.Exec(id, d)
|
||||
if err != nil {
|
||||
stmt.Close()
|
||||
return handleError(detectorQueryName, err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := stmt.Close(); err != nil {
|
||||
return handleError(detectorQueryName, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) findLayerNamespaces(layerID int64) ([]database.Namespace, error) {
|
||||
var namespaces []database.Namespace
|
||||
|
||||
rows, err := tx.Query(searchLayerNamespaces, layerID)
|
||||
if err != nil {
|
||||
return nil, handleError("searchLayerFeatures", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
vulnerabilities := make(map[int][]database.Vulnerability, len(featureVersions))
|
||||
var featureversionID int
|
||||
for rows.Next() {
|
||||
var vulnerability database.Vulnerability
|
||||
err := rows.Scan(
|
||||
&featureversionID,
|
||||
&vulnerability.ID,
|
||||
&vulnerability.Name,
|
||||
&vulnerability.Description,
|
||||
&vulnerability.Link,
|
||||
&vulnerability.Severity,
|
||||
&vulnerability.Metadata,
|
||||
&vulnerability.Namespace.Name,
|
||||
&vulnerability.Namespace.VersionFormat,
|
||||
&vulnerability.FixedBy,
|
||||
)
|
||||
ns := database.Namespace{}
|
||||
err := rows.Scan(&ns.Name, &ns.VersionFormat)
|
||||
if err != nil {
|
||||
return handleError("searchFeatureVersionVulnerability.Scan()", err)
|
||||
return nil, err
|
||||
}
|
||||
vulnerabilities[featureversionID] = append(vulnerabilities[featureversionID], vulnerability)
|
||||
namespaces = append(namespaces, ns)
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return handleError("searchFeatureVersionVulnerability.Rows()", err)
|
||||
}
|
||||
|
||||
// Assign vulnerabilities to every FeatureVersions
|
||||
for i := 0; i < len(featureVersions); i++ {
|
||||
featureVersions[i].AffectedBy = vulnerabilities[featureVersions[i].ID]
|
||||
}
|
||||
|
||||
return nil
|
||||
return namespaces, nil
|
||||
}
|
||||
|
||||
// Internally, only Feature additions/removals are stored for each layer. If a layer has a parent,
|
||||
// the Feature list will be compared to the parent's Feature list and the difference will be stored.
|
||||
// Note that when the Namespace of a layer differs from its parent, it is expected that several
|
||||
// Feature that were already included a parent will have their Namespace updated as well
|
||||
// (happens when Feature detectors relies on the detected layer Namespace). However, if the listed
|
||||
// Feature has the same Name/Version as its parent, InsertLayer considers that the Feature hasn't
|
||||
// been modified.
|
||||
func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error {
|
||||
tf := time.Now()
|
||||
func (tx *pgSession) findLayerFeatures(layerID int64) ([]database.Feature, error) {
|
||||
var features []database.Feature
|
||||
|
||||
// Verify parameters
|
||||
if layer.Name == "" {
|
||||
log.Warning("could not insert a layer which has an empty Name")
|
||||
return commonerr.NewBadRequestError("could not insert a layer which has an empty Name")
|
||||
}
|
||||
|
||||
// Get a potentially existing layer.
|
||||
existingLayer, err := pgSQL.FindLayer(layer.Name, true, false)
|
||||
if err != nil && err != commonerr.ErrNotFound {
|
||||
return err
|
||||
} else if err == nil {
|
||||
if existingLayer.EngineVersion >= layer.EngineVersion {
|
||||
// The layer exists and has an equal or higher engine version, do nothing.
|
||||
return nil
|
||||
}
|
||||
|
||||
layer.ID = existingLayer.ID
|
||||
}
|
||||
|
||||
// We do `defer observeQueryTime` here because we don't want to observe existing layers.
|
||||
defer observeQueryTime("InsertLayer", "all", tf)
|
||||
|
||||
// Get parent ID.
|
||||
var parentID zero.Int
|
||||
if layer.Parent != nil {
|
||||
if layer.Parent.ID == 0 {
|
||||
log.Warning("Parent is expected to be retrieved from database when inserting a layer.")
|
||||
return commonerr.NewBadRequestError("Parent is expected to be retrieved from database when inserting a layer.")
|
||||
}
|
||||
|
||||
parentID = zero.IntFrom(int64(layer.Parent.ID))
|
||||
}
|
||||
|
||||
// namespaceIDs will contain inherited and new namespaces
|
||||
namespaceIDs := make(map[int]struct{})
|
||||
|
||||
// try to insert the new namespaces
|
||||
for _, ns := range layer.Namespaces {
|
||||
n, err := pgSQL.insertNamespace(ns)
|
||||
if err != nil {
|
||||
return handleError("pgSQL.insertNamespace", err)
|
||||
}
|
||||
namespaceIDs[n] = struct{}{}
|
||||
}
|
||||
|
||||
// inherit namespaces from parent layer
|
||||
if layer.Parent != nil {
|
||||
for _, ns := range layer.Parent.Namespaces {
|
||||
namespaceIDs[ns.ID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// Begin transaction.
|
||||
tx, err := pgSQL.Begin()
|
||||
rows, err := tx.Query(searchLayerFeatures, layerID)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("InsertLayer.Begin()", err)
|
||||
return nil, handleError("searchLayerFeatures", err)
|
||||
}
|
||||
|
||||
if layer.ID == 0 {
|
||||
// Insert a new layer.
|
||||
err = tx.QueryRow(insertLayer, layer.Name, layer.EngineVersion, parentID).
|
||||
Scan(&layer.ID)
|
||||
for rows.Next() {
|
||||
f := database.Feature{}
|
||||
err := rows.Scan(&f.Name, &f.Version, &f.VersionFormat)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
|
||||
if isErrUniqueViolation(err) {
|
||||
// Ignore this error, another process collided.
|
||||
log.Debug("Attempted to insert duplicate layer.")
|
||||
return nil
|
||||
}
|
||||
return handleError("insertLayer", err)
|
||||
}
|
||||
} else {
|
||||
// Update an existing layer.
|
||||
_, err = tx.Exec(updateLayer, layer.ID, layer.EngineVersion)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("updateLayer", err)
|
||||
}
|
||||
|
||||
// replace the old namespace in the database
|
||||
_, err := tx.Exec(removeLayerNamespace, layer.ID)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("removeLayerNamespace", err)
|
||||
}
|
||||
// Remove all existing Layer_diff_FeatureVersion.
|
||||
_, err = tx.Exec(removeLayerDiffFeatureVersion, layer.ID)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("removeLayerDiffFeatureVersion", err)
|
||||
return nil, err
|
||||
}
|
||||
features = append(features, f)
|
||||
}
|
||||
|
||||
// insert the layer's namespaces
|
||||
stmt, err := tx.Prepare(insertLayerNamespace)
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("failed to prepare statement", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err = stmt.Close()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
log.WithError(err).Error("failed to close prepared statement")
|
||||
}
|
||||
}()
|
||||
|
||||
for nsid := range namespaceIDs {
|
||||
_, err := stmt.Exec(layer.ID, nsid)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("insertLayerNamespace", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Update Layer_diff_FeatureVersion now.
|
||||
err = pgSQL.updateDiffFeatureVersions(tx, &layer, &existingLayer)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// Commit transaction.
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("InsertLayer.Commit()", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return features, nil
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) updateDiffFeatureVersions(tx *sql.Tx, layer, existingLayer *database.Layer) error {
|
||||
// add and del are the FeatureVersion diff we should insert.
|
||||
var add []database.FeatureVersion
|
||||
var del []database.FeatureVersion
|
||||
func (tx *pgSession) findLayer(hash string) (database.Layer, database.Processors, int64, bool, error) {
|
||||
var (
|
||||
layerID int64
|
||||
layer = database.Layer{Hash: hash}
|
||||
processors database.Processors
|
||||
)
|
||||
|
||||
if layer.Parent == nil {
|
||||
// There is no parent, every Features are added.
|
||||
add = append(add, layer.Features...)
|
||||
} else if layer.Parent != nil {
|
||||
// There is a parent, we need to diff the Features with it.
|
||||
|
||||
// Build name:version structures.
|
||||
layerFeaturesMapNV, layerFeaturesNV := createNV(layer.Features)
|
||||
parentLayerFeaturesMapNV, parentLayerFeaturesNV := createNV(layer.Parent.Features)
|
||||
|
||||
// Calculate the added and deleted FeatureVersions name:version.
|
||||
addNV := compareStringLists(layerFeaturesNV, parentLayerFeaturesNV)
|
||||
delNV := compareStringLists(parentLayerFeaturesNV, layerFeaturesNV)
|
||||
|
||||
// Fill the structures containing the added and deleted FeatureVersions.
|
||||
for _, nv := range addNV {
|
||||
add = append(add, *layerFeaturesMapNV[nv])
|
||||
}
|
||||
for _, nv := range delNV {
|
||||
del = append(del, *parentLayerFeaturesMapNV[nv])
|
||||
}
|
||||
if hash == "" {
|
||||
return layer, processors, layerID, false, commonerr.NewBadRequestError("Empty Layer Hash is not allowed")
|
||||
}
|
||||
|
||||
// Insert FeatureVersions in the database.
|
||||
addIDs, err := pgSQL.insertFeatureVersions(add)
|
||||
err := tx.QueryRow(searchLayer, hash).Scan(&layerID)
|
||||
if err != nil {
|
||||
return err
|
||||
if err == sql.ErrNoRows {
|
||||
return layer, processors, layerID, false, nil
|
||||
}
|
||||
return layer, processors, layerID, false, err
|
||||
}
|
||||
delIDs, err := pgSQL.insertFeatureVersions(del)
|
||||
|
||||
processors.Detectors, err = tx.findProcessors(searchLayerDetectors, "searchLayerDetectors", "detector", layerID)
|
||||
if err != nil {
|
||||
return err
|
||||
return layer, processors, layerID, false, err
|
||||
}
|
||||
|
||||
// Insert diff in the database.
|
||||
if len(addIDs) > 0 {
|
||||
_, err = tx.Exec(insertLayerDiffFeatureVersion, layer.ID, "add", buildInputArray(addIDs))
|
||||
if err != nil {
|
||||
return handleError("insertLayerDiffFeatureVersion.Add", err)
|
||||
}
|
||||
}
|
||||
if len(delIDs) > 0 {
|
||||
_, err = tx.Exec(insertLayerDiffFeatureVersion, layer.ID, "del", buildInputArray(delIDs))
|
||||
if err != nil {
|
||||
return handleError("insertLayerDiffFeatureVersion.Del", err)
|
||||
}
|
||||
processors.Listers, err = tx.findProcessors(searchLayerListers, "searchLayerListers", "lister", layerID)
|
||||
if err != nil {
|
||||
return layer, processors, layerID, false, err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func createNV(features []database.FeatureVersion) (map[string]*database.FeatureVersion, []string) {
|
||||
mapNV := make(map[string]*database.FeatureVersion, 0)
|
||||
sliceNV := make([]string, 0, len(features))
|
||||
|
||||
for i := 0; i < len(features); i++ {
|
||||
fv := &features[i]
|
||||
nv := strings.Join([]string{fv.Feature.Namespace.Name, fv.Feature.Name, fv.Version}, ":")
|
||||
mapNV[nv] = fv
|
||||
sliceNV = append(sliceNV, nv)
|
||||
}
|
||||
|
||||
return mapNV, sliceNV
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) DeleteLayer(name string) error {
|
||||
defer observeQueryTime("DeleteLayer", "all", time.Now())
|
||||
|
||||
result, err := pgSQL.Exec(removeLayer, name)
|
||||
if err != nil {
|
||||
return handleError("removeLayer", err)
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return handleError("removeLayer.RowsAffected()", err)
|
||||
}
|
||||
|
||||
if affected <= 0 {
|
||||
return commonerr.ErrNotFound
|
||||
}
|
||||
|
||||
return nil
|
||||
return layer, processors, layerID, true, nil
|
||||
}
|
||||
|
@ -15,423 +15,100 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/ext/versionfmt/dpkg"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
func TestPersistLayer(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "PersistLayer", false)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
l1 := database.Layer{}
|
||||
l2 := database.Layer{Hash: "HESOYAM"}
|
||||
|
||||
// invalid
|
||||
assert.NotNil(t, tx.PersistLayer(l1))
|
||||
// valid
|
||||
assert.Nil(t, tx.PersistLayer(l2))
|
||||
// duplicated
|
||||
assert.Nil(t, tx.PersistLayer(l2))
|
||||
}
|
||||
|
||||
func TestPersistLayerProcessors(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "PersistLayerProcessors", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
// invalid
|
||||
assert.NotNil(t, tx.PersistLayerContent("hash", []database.Namespace{}, []database.Feature{}, database.Processors{}))
|
||||
// valid
|
||||
assert.Nil(t, tx.PersistLayerContent("layer-4", []database.Namespace{}, []database.Feature{}, database.Processors{Detectors: []string{"new detector!"}}))
|
||||
}
|
||||
|
||||
func TestFindLayer(t *testing.T) {
|
||||
datastore, err := openDatabaseForTest("FindLayer", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer datastore.Close()
|
||||
datastore, tx := openSessionForTest(t, "FindLayer", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
// Layer-0: no parent, no namespace, no feature, no vulnerability
|
||||
layer, err := datastore.FindLayer("layer-0", false, false)
|
||||
if assert.Nil(t, err) && assert.NotNil(t, layer) {
|
||||
assert.Equal(t, "layer-0", layer.Name)
|
||||
assert.Len(t, layer.Namespaces, 0)
|
||||
assert.Nil(t, layer.Parent)
|
||||
assert.Equal(t, 1, layer.EngineVersion)
|
||||
assert.Len(t, layer.Features, 0)
|
||||
expected := database.Layer{Hash: "layer-4"}
|
||||
expectedProcessors := database.Processors{
|
||||
Detectors: []string{"os-release", "apt-sources"},
|
||||
Listers: []string{"dpkg", "rpm"},
|
||||
}
|
||||
|
||||
layer, err = datastore.FindLayer("layer-0", true, false)
|
||||
if assert.Nil(t, err) && assert.NotNil(t, layer) {
|
||||
assert.Len(t, layer.Features, 0)
|
||||
}
|
||||
|
||||
// Layer-1: one parent, adds two features, one vulnerability
|
||||
layer, err = datastore.FindLayer("layer-1", false, false)
|
||||
if assert.Nil(t, err) && assert.NotNil(t, layer) {
|
||||
assert.Equal(t, layer.Name, "layer-1")
|
||||
assertExpectedNamespaceName(t, &layer, []string{"debian:7"})
|
||||
if assert.NotNil(t, layer.Parent) {
|
||||
assert.Equal(t, "layer-0", layer.Parent.Name)
|
||||
}
|
||||
assert.Equal(t, 1, layer.EngineVersion)
|
||||
assert.Len(t, layer.Features, 0)
|
||||
}
|
||||
|
||||
layer, err = datastore.FindLayer("layer-1", true, false)
|
||||
if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) {
|
||||
for _, featureVersion := range layer.Features {
|
||||
assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name)
|
||||
|
||||
switch featureVersion.Feature.Name {
|
||||
case "wechat":
|
||||
assert.Equal(t, "0.5", featureVersion.Version)
|
||||
case "openssl":
|
||||
assert.Equal(t, "1.0", featureVersion.Version)
|
||||
default:
|
||||
t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
layer, err = datastore.FindLayer("layer-1", true, true)
|
||||
if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) {
|
||||
for _, featureVersion := range layer.Features {
|
||||
assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name)
|
||||
|
||||
switch featureVersion.Feature.Name {
|
||||
case "wechat":
|
||||
assert.Equal(t, "0.5", featureVersion.Version)
|
||||
case "openssl":
|
||||
assert.Equal(t, "1.0", featureVersion.Version)
|
||||
|
||||
if assert.Len(t, featureVersion.AffectedBy, 1) {
|
||||
assert.Equal(t, "debian:7", featureVersion.AffectedBy[0].Namespace.Name)
|
||||
assert.Equal(t, "CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Name)
|
||||
assert.Equal(t, database.HighSeverity, featureVersion.AffectedBy[0].Severity)
|
||||
assert.Equal(t, "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", featureVersion.AffectedBy[0].Description)
|
||||
assert.Equal(t, "http://google.com/#q=CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Link)
|
||||
assert.Equal(t, "2.0", featureVersion.AffectedBy[0].FixedBy)
|
||||
}
|
||||
default:
|
||||
t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Testing Multiple namespaces layer-3b has debian:7 and debian:8 namespaces
|
||||
layer, err = datastore.FindLayer("layer-3b", true, true)
|
||||
|
||||
if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) {
|
||||
assert.Equal(t, "layer-3b", layer.Name)
|
||||
// validate the namespace
|
||||
assertExpectedNamespaceName(t, &layer, []string{"debian:7", "debian:8"})
|
||||
for _, featureVersion := range layer.Features {
|
||||
switch featureVersion.Feature.Namespace.Name {
|
||||
case "debian:7":
|
||||
assert.Equal(t, "wechat", featureVersion.Feature.Name)
|
||||
assert.Equal(t, "0.5", featureVersion.Version)
|
||||
case "debian:8":
|
||||
assert.Equal(t, "openssl", featureVersion.Feature.Name)
|
||||
assert.Equal(t, "1.0", featureVersion.Version)
|
||||
default:
|
||||
t.Errorf("unexpected package %s for layer-3b", featureVersion.Feature.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertLayer(t *testing.T) {
|
||||
datastore, err := openDatabaseForTest("InsertLayer", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer datastore.Close()
|
||||
|
||||
// Insert invalid layer.
|
||||
testInsertLayerInvalid(t, datastore)
|
||||
|
||||
// Insert a layer tree.
|
||||
testInsertLayerTree(t, datastore)
|
||||
|
||||
// Update layer.
|
||||
testInsertLayerUpdate(t, datastore)
|
||||
|
||||
// Delete layer.
|
||||
testInsertLayerDelete(t, datastore)
|
||||
}
|
||||
|
||||
func testInsertLayerInvalid(t *testing.T, datastore database.Datastore) {
|
||||
invalidLayers := []database.Layer{
|
||||
{},
|
||||
{Name: "layer0", Parent: &database.Layer{}},
|
||||
{Name: "layer0", Parent: &database.Layer{Name: "UnknownLayer"}},
|
||||
}
|
||||
|
||||
for _, invalidLayer := range invalidLayers {
|
||||
err := datastore.InsertLayer(invalidLayer)
|
||||
assert.Error(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
func testInsertLayerTree(t *testing.T, datastore database.Datastore) {
|
||||
f1 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestInsertLayerNamespace2",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
Name: "TestInsertLayerFeature1",
|
||||
},
|
||||
Version: "1.0",
|
||||
}
|
||||
f2 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestInsertLayerNamespace2",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
Name: "TestInsertLayerFeature2",
|
||||
},
|
||||
Version: "0.34",
|
||||
}
|
||||
f3 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestInsertLayerNamespace2",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
Name: "TestInsertLayerFeature3",
|
||||
},
|
||||
Version: "0.56",
|
||||
}
|
||||
f4 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestInsertLayerNamespace3",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
Name: "TestInsertLayerFeature2",
|
||||
},
|
||||
Version: "0.34",
|
||||
}
|
||||
f5 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestInsertLayerNamespace3",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
Name: "TestInsertLayerFeature3",
|
||||
},
|
||||
Version: "0.56",
|
||||
}
|
||||
f6 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestInsertLayerNamespace3",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
Name: "TestInsertLayerFeature4",
|
||||
},
|
||||
Version: "0.666",
|
||||
}
|
||||
|
||||
layers := []database.Layer{
|
||||
{
|
||||
Name: "TestInsertLayer1",
|
||||
},
|
||||
{
|
||||
Name: "TestInsertLayer2",
|
||||
Parent: &database.Layer{Name: "TestInsertLayer1"},
|
||||
Namespaces: []database.Namespace{database.Namespace{
|
||||
Name: "TestInsertLayerNamespace1",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
}},
|
||||
},
|
||||
// This layer changes the namespace and adds Features.
|
||||
{
|
||||
Name: "TestInsertLayer3",
|
||||
Parent: &database.Layer{Name: "TestInsertLayer2"},
|
||||
Namespaces: []database.Namespace{database.Namespace{
|
||||
Name: "TestInsertLayerNamespace2",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
}},
|
||||
Features: []database.FeatureVersion{f1, f2, f3},
|
||||
},
|
||||
// This layer covers the case where the last layer doesn't provide any new Feature.
|
||||
{
|
||||
Name: "TestInsertLayer4a",
|
||||
Parent: &database.Layer{Name: "TestInsertLayer3"},
|
||||
Features: []database.FeatureVersion{f1, f2, f3},
|
||||
},
|
||||
// This layer covers the case where the last layer provides Features.
|
||||
// It also modifies the Namespace ("upgrade") but keeps some Features not upgraded, their
|
||||
// Namespaces should then remain unchanged.
|
||||
{
|
||||
Name: "TestInsertLayer4b",
|
||||
Parent: &database.Layer{Name: "TestInsertLayer3"},
|
||||
Namespaces: []database.Namespace{database.Namespace{
|
||||
Name: "TestInsertLayerNamespace3",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
}},
|
||||
Features: []database.FeatureVersion{
|
||||
// Deletes TestInsertLayerFeature1.
|
||||
// Keep TestInsertLayerFeature2 (old Namespace should be kept):
|
||||
f4,
|
||||
// Upgrades TestInsertLayerFeature3 (with new Namespace):
|
||||
f5,
|
||||
// Adds TestInsertLayerFeature4:
|
||||
f6,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
var err error
|
||||
retrievedLayers := make(map[string]database.Layer)
|
||||
for _, layer := range layers {
|
||||
if layer.Parent != nil {
|
||||
// Retrieve from database its parent and assign.
|
||||
parent := retrievedLayers[layer.Parent.Name]
|
||||
layer.Parent = &parent
|
||||
}
|
||||
|
||||
err = datastore.InsertLayer(layer)
|
||||
assert.Nil(t, err)
|
||||
|
||||
retrievedLayers[layer.Name], err = datastore.FindLayer(layer.Name, true, false)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
// layer inherits all namespaces from its ancestries
|
||||
l4a := retrievedLayers["TestInsertLayer4a"]
|
||||
assertExpectedNamespaceName(t, &l4a, []string{"TestInsertLayerNamespace2", "TestInsertLayerNamespace1"})
|
||||
assert.Len(t, l4a.Features, 3)
|
||||
for _, featureVersion := range l4a.Features {
|
||||
if cmpFV(featureVersion, f1) && cmpFV(featureVersion, f2) && cmpFV(featureVersion, f3) {
|
||||
assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f1, f2, f3))
|
||||
}
|
||||
}
|
||||
|
||||
l4b := retrievedLayers["TestInsertLayer4b"]
|
||||
assertExpectedNamespaceName(t, &l4b, []string{"TestInsertLayerNamespace1", "TestInsertLayerNamespace2", "TestInsertLayerNamespace3"})
|
||||
assert.Len(t, l4b.Features, 3)
|
||||
for _, featureVersion := range l4b.Features {
|
||||
if cmpFV(featureVersion, f2) && cmpFV(featureVersion, f5) && cmpFV(featureVersion, f6) {
|
||||
assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f2, f4, f6))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testInsertLayerUpdate(t *testing.T, datastore database.Datastore) {
|
||||
f7 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestInsertLayerNamespace3",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
Name: "TestInsertLayerFeature7",
|
||||
},
|
||||
Version: "0.01",
|
||||
}
|
||||
|
||||
l3, _ := datastore.FindLayer("TestInsertLayer3", true, false)
|
||||
l3u := database.Layer{
|
||||
Name: l3.Name,
|
||||
Parent: l3.Parent,
|
||||
Namespaces: []database.Namespace{database.Namespace{
|
||||
Name: "TestInsertLayerNamespaceUpdated1",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
}},
|
||||
Features: []database.FeatureVersion{f7},
|
||||
}
|
||||
|
||||
l4u := database.Layer{
|
||||
Name: "TestInsertLayer4",
|
||||
Parent: &database.Layer{Name: "TestInsertLayer3"},
|
||||
Features: []database.FeatureVersion{f7},
|
||||
EngineVersion: 2,
|
||||
}
|
||||
|
||||
// Try to re-insert without increasing the EngineVersion.
|
||||
err := datastore.InsertLayer(l3u)
|
||||
// invalid
|
||||
_, _, _, err := tx.FindLayer("")
|
||||
assert.NotNil(t, err)
|
||||
_, _, ok, err := tx.FindLayer("layer-non")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
|
||||
l3uf, err := datastore.FindLayer(l3u.Name, true, false)
|
||||
if assert.Nil(t, err) {
|
||||
assertSameNamespaceName(t, &l3, &l3uf)
|
||||
assert.Equal(t, l3.EngineVersion, l3uf.EngineVersion)
|
||||
assert.Len(t, l3uf.Features, len(l3.Features))
|
||||
// valid
|
||||
layer, processors, ok2, err := tx.FindLayer("layer-4")
|
||||
if assert.Nil(t, err) && assert.True(t, ok2) {
|
||||
assert.Equal(t, expected, layer)
|
||||
assertProcessorsEqual(t, expectedProcessors, processors)
|
||||
}
|
||||
}
|
||||
|
||||
// Update layer l3.
|
||||
// Verify that the Namespace, EngineVersion and FeatureVersions got updated.
|
||||
l3u.EngineVersion = 2
|
||||
err = datastore.InsertLayer(l3u)
|
||||
func TestFindLayerWithContent(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "FindLayerWithContent", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
_, _, err := tx.FindLayerWithContent("")
|
||||
assert.NotNil(t, err)
|
||||
_, ok, err := tx.FindLayerWithContent("layer-non")
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
|
||||
l3uf, err = datastore.FindLayer(l3u.Name, true, false)
|
||||
if assert.Nil(t, err) {
|
||||
assertSameNamespaceName(t, &l3u, &l3uf)
|
||||
assert.Equal(t, l3u.EngineVersion, l3uf.EngineVersion)
|
||||
if assert.Len(t, l3uf.Features, 1) {
|
||||
assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l3uf.Features[0])
|
||||
}
|
||||
expectedL := database.LayerWithContent{
|
||||
Layer: database.Layer{
|
||||
Hash: "layer-4",
|
||||
},
|
||||
Features: []database.Feature{
|
||||
{Name: "fake", Version: "2.0", VersionFormat: "rpm"},
|
||||
{Name: "openssl", Version: "2.0", VersionFormat: "dpkg"},
|
||||
},
|
||||
Namespaces: []database.Namespace{
|
||||
{Name: "debian:7", VersionFormat: "dpkg"},
|
||||
{Name: "fake:1.0", VersionFormat: "rpm"},
|
||||
},
|
||||
ProcessedBy: database.Processors{
|
||||
Detectors: []string{"os-release", "apt-sources"},
|
||||
Listers: []string{"dpkg", "rpm"},
|
||||
},
|
||||
}
|
||||
|
||||
// Update layer l4.
|
||||
// Verify that the Namespace got updated from its new Parent's, and also verify the
|
||||
// EnginVersion and FeatureVersions.
|
||||
l4u.Parent = &l3uf
|
||||
err = datastore.InsertLayer(l4u)
|
||||
assert.Nil(t, err)
|
||||
|
||||
l4uf, err := datastore.FindLayer(l3u.Name, true, false)
|
||||
if assert.Nil(t, err) {
|
||||
assertSameNamespaceName(t, &l3u, &l4uf)
|
||||
assert.Equal(t, l4u.EngineVersion, l4uf.EngineVersion)
|
||||
if assert.Len(t, l4uf.Features, 1) {
|
||||
assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l4uf.Features[0])
|
||||
}
|
||||
layer, ok2, err := tx.FindLayerWithContent("layer-4")
|
||||
if assert.Nil(t, err) && assert.True(t, ok2) {
|
||||
assertLayerWithContentEqual(t, expectedL, layer)
|
||||
}
|
||||
}
|
||||
|
||||
func assertSameNamespaceName(t *testing.T, layer1 *database.Layer, layer2 *database.Layer) {
|
||||
assert.Len(t, compareStringLists(extractNamespaceName(layer1), extractNamespaceName(layer2)), 0)
|
||||
}
|
||||
|
||||
func assertExpectedNamespaceName(t *testing.T, layer *database.Layer, expectedNames []string) {
|
||||
assert.Len(t, compareStringLists(extractNamespaceName(layer), expectedNames), 0)
|
||||
}
|
||||
|
||||
func extractNamespaceName(layer *database.Layer) []string {
|
||||
slist := make([]string, 0, len(layer.Namespaces))
|
||||
for _, ns := range layer.Namespaces {
|
||||
slist = append(slist, ns.Name)
|
||||
}
|
||||
return slist
|
||||
}
|
||||
|
||||
func testInsertLayerDelete(t *testing.T, datastore database.Datastore) {
|
||||
err := datastore.DeleteLayer("TestInsertLayerX")
|
||||
assert.Equal(t, commonerr.ErrNotFound, err)
|
||||
|
||||
// ensure layer_namespace table is cleaned up once a layer is removed
|
||||
layer3, err := datastore.FindLayer("TestInsertLayer3", false, false)
|
||||
layer4a, err := datastore.FindLayer("TestInsertLayer4a", false, false)
|
||||
layer4b, err := datastore.FindLayer("TestInsertLayer4b", false, false)
|
||||
|
||||
err = datastore.DeleteLayer("TestInsertLayer3")
|
||||
assert.Nil(t, err)
|
||||
|
||||
_, err = datastore.FindLayer("TestInsertLayer3", false, false)
|
||||
assert.Equal(t, commonerr.ErrNotFound, err)
|
||||
assertNotInLayerNamespace(t, layer3.ID, datastore)
|
||||
_, err = datastore.FindLayer("TestInsertLayer4a", false, false)
|
||||
assert.Equal(t, commonerr.ErrNotFound, err)
|
||||
assertNotInLayerNamespace(t, layer4a.ID, datastore)
|
||||
_, err = datastore.FindLayer("TestInsertLayer4b", true, false)
|
||||
assert.Equal(t, commonerr.ErrNotFound, err)
|
||||
assertNotInLayerNamespace(t, layer4b.ID, datastore)
|
||||
}
|
||||
|
||||
func assertNotInLayerNamespace(t *testing.T, layerID int, datastore database.Datastore) {
|
||||
pg, ok := datastore.(*pgSQL)
|
||||
if !assert.True(t, ok) {
|
||||
return
|
||||
}
|
||||
tx, err := pg.Begin()
|
||||
if !assert.Nil(t, err) {
|
||||
return
|
||||
}
|
||||
rows, err := tx.Query(searchLayerNamespace, layerID)
|
||||
assert.False(t, rows.Next())
|
||||
}
|
||||
|
||||
func cmpFV(a, b database.FeatureVersion) bool {
|
||||
return a.Feature.Name == b.Feature.Name &&
|
||||
a.Feature.Namespace.Name == b.Feature.Namespace.Name &&
|
||||
a.Version == b.Version
|
||||
func assertLayerWithContentEqual(t *testing.T, expected database.LayerWithContent, actual database.LayerWithContent) bool {
|
||||
return assert.Equal(t, expected.Layer, actual.Layer) &&
|
||||
assertFeaturesEqual(t, expected.Features, actual.Features) &&
|
||||
assertProcessorsEqual(t, expected.ProcessedBy, actual.ProcessedBy) &&
|
||||
assertNamespacesEqual(t, expected.Namespaces, actual.Namespaces)
|
||||
}
|
||||
|
@ -15,6 +15,7 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@ -22,86 +23,91 @@ import (
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
var (
|
||||
errLockNotFound = errors.New("lock is not in database")
|
||||
)
|
||||
|
||||
// Lock tries to set a temporary lock in the database.
|
||||
//
|
||||
// Lock does not block, instead, it returns true and its expiration time
|
||||
// is the lock has been successfully acquired or false otherwise
|
||||
func (pgSQL *pgSQL) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time) {
|
||||
// is the lock has been successfully acquired or false otherwise.
|
||||
func (tx *pgSession) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time, error) {
|
||||
if name == "" || owner == "" || duration == 0 {
|
||||
log.Warning("could not create an invalid lock")
|
||||
return false, time.Time{}
|
||||
return false, time.Time{}, commonerr.NewBadRequestError("Invalid Lock Parameters")
|
||||
}
|
||||
|
||||
defer observeQueryTime("Lock", "all", time.Now())
|
||||
|
||||
// Compute expiration.
|
||||
until := time.Now().Add(duration)
|
||||
|
||||
if renew {
|
||||
defer observeQueryTime("Lock", "update", time.Now())
|
||||
// Renew lock.
|
||||
r, err := pgSQL.Exec(updateLock, name, owner, until)
|
||||
r, err := tx.Exec(updateLock, name, owner, until)
|
||||
if err != nil {
|
||||
handleError("updateLock", err)
|
||||
return false, until
|
||||
return false, until, handleError("updateLock", err)
|
||||
}
|
||||
if n, _ := r.RowsAffected(); n > 0 {
|
||||
// Updated successfully.
|
||||
return true, until
|
||||
|
||||
if n, err := r.RowsAffected(); err == nil {
|
||||
return n > 0, until, nil
|
||||
}
|
||||
} else {
|
||||
// Prune locks.
|
||||
pgSQL.pruneLocks()
|
||||
return false, until, handleError("updateLock", err)
|
||||
} else if err := tx.pruneLocks(); err != nil {
|
||||
return false, until, err
|
||||
}
|
||||
|
||||
// Lock.
|
||||
_, err := pgSQL.Exec(insertLock, name, owner, until)
|
||||
defer observeQueryTime("Lock", "soiLock", time.Now())
|
||||
_, err := tx.Exec(soiLock, name, owner, until)
|
||||
if err != nil {
|
||||
if !isErrUniqueViolation(err) {
|
||||
handleError("insertLock", err)
|
||||
if isErrUniqueViolation(err) {
|
||||
return false, until, nil
|
||||
}
|
||||
return false, until
|
||||
return false, until, handleError("insertLock", err)
|
||||
}
|
||||
|
||||
return true, until
|
||||
return true, until, nil
|
||||
}
|
||||
|
||||
// Unlock unlocks a lock specified by its name if I own it
|
||||
func (pgSQL *pgSQL) Unlock(name, owner string) {
|
||||
func (tx *pgSession) Unlock(name, owner string) error {
|
||||
if name == "" || owner == "" {
|
||||
log.Warning("could not delete an invalid lock")
|
||||
return
|
||||
return commonerr.NewBadRequestError("Invalid Lock Parameters")
|
||||
}
|
||||
|
||||
defer observeQueryTime("Unlock", "all", time.Now())
|
||||
|
||||
pgSQL.Exec(removeLock, name, owner)
|
||||
_, err := tx.Exec(removeLock, name, owner)
|
||||
return err
|
||||
}
|
||||
|
||||
// FindLock returns the owner of a lock specified by its name and its
|
||||
// expiration time.
|
||||
func (pgSQL *pgSQL) FindLock(name string) (string, time.Time, error) {
|
||||
func (tx *pgSession) FindLock(name string) (string, time.Time, bool, error) {
|
||||
if name == "" {
|
||||
log.Warning("could not find an invalid lock")
|
||||
return "", time.Time{}, commonerr.NewBadRequestError("could not find an invalid lock")
|
||||
return "", time.Time{}, false, commonerr.NewBadRequestError("could not find an invalid lock")
|
||||
}
|
||||
|
||||
defer observeQueryTime("FindLock", "all", time.Now())
|
||||
|
||||
var owner string
|
||||
var until time.Time
|
||||
err := pgSQL.QueryRow(searchLock, name).Scan(&owner, &until)
|
||||
err := tx.QueryRow(searchLock, name).Scan(&owner, &until)
|
||||
if err != nil {
|
||||
return owner, until, handleError("searchLock", err)
|
||||
return owner, until, false, handleError("searchLock", err)
|
||||
}
|
||||
|
||||
return owner, until, nil
|
||||
return owner, until, true, nil
|
||||
}
|
||||
|
||||
// pruneLocks removes every expired locks from the database
|
||||
func (pgSQL *pgSQL) pruneLocks() {
|
||||
func (tx *pgSession) pruneLocks() error {
|
||||
defer observeQueryTime("pruneLocks", "all", time.Now())
|
||||
|
||||
if _, err := pgSQL.Exec(removeLockExpired); err != nil {
|
||||
handleError("removeLockExpired", err)
|
||||
if r, err := tx.Exec(removeLockExpired); err != nil {
|
||||
return handleError("removeLockExpired", err)
|
||||
} else if affected, err := r.RowsAffected(); err != nil {
|
||||
return handleError("removeLockExpired", err)
|
||||
} else {
|
||||
log.Debugf("Pruned %d Locks", affected)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -22,48 +22,72 @@ import (
|
||||
)
|
||||
|
||||
func TestLock(t *testing.T) {
|
||||
datastore, err := openDatabaseForTest("InsertNamespace", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
datastore, tx := openSessionForTest(t, "Lock", true)
|
||||
defer datastore.Close()
|
||||
|
||||
var l bool
|
||||
var et time.Time
|
||||
|
||||
// Create a first lock.
|
||||
l, _ = datastore.Lock("test1", "owner1", time.Minute, false)
|
||||
l, _, err := tx.Lock("test1", "owner1", time.Minute, false)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, l)
|
||||
tx = restartSession(t, datastore, tx, true)
|
||||
|
||||
// Try to lock the same lock with another owner.
|
||||
l, _ = datastore.Lock("test1", "owner2", time.Minute, true)
|
||||
// lock again by itself, the previous lock is not expired yet.
|
||||
l, _, err = tx.Lock("test1", "owner1", time.Minute, false)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, l)
|
||||
tx = restartSession(t, datastore, tx, false)
|
||||
|
||||
l, _ = datastore.Lock("test1", "owner2", time.Minute, false)
|
||||
// Try to renew the same lock with another owner.
|
||||
l, _, err = tx.Lock("test1", "owner2", time.Minute, true)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, l)
|
||||
tx = restartSession(t, datastore, tx, false)
|
||||
|
||||
l, _, err = tx.Lock("test1", "owner2", time.Minute, false)
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, l)
|
||||
tx = restartSession(t, datastore, tx, false)
|
||||
|
||||
// Renew the lock.
|
||||
l, _ = datastore.Lock("test1", "owner1", 2*time.Minute, true)
|
||||
l, _, err = tx.Lock("test1", "owner1", 2*time.Minute, true)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, l)
|
||||
tx = restartSession(t, datastore, tx, true)
|
||||
|
||||
// Unlock and then relock by someone else.
|
||||
datastore.Unlock("test1", "owner1")
|
||||
err = tx.Unlock("test1", "owner1")
|
||||
assert.Nil(t, err)
|
||||
tx = restartSession(t, datastore, tx, true)
|
||||
|
||||
l, et = datastore.Lock("test1", "owner2", time.Minute, false)
|
||||
l, et, err = tx.Lock("test1", "owner2", time.Minute, false)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, l)
|
||||
tx = restartSession(t, datastore, tx, true)
|
||||
|
||||
// LockInfo
|
||||
o, et2, err := datastore.FindLock("test1")
|
||||
o, et2, ok, err := tx.FindLock("test1")
|
||||
assert.True(t, ok)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "owner2", o)
|
||||
assert.Equal(t, et.Second(), et2.Second())
|
||||
tx = restartSession(t, datastore, tx, true)
|
||||
|
||||
// Create a second lock which is actually already expired ...
|
||||
l, _ = datastore.Lock("test2", "owner1", -time.Minute, false)
|
||||
l, _, err = tx.Lock("test2", "owner1", -time.Minute, false)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, l)
|
||||
tx = restartSession(t, datastore, tx, true)
|
||||
|
||||
// Take over the lock
|
||||
l, _ = datastore.Lock("test2", "owner2", time.Minute, false)
|
||||
l, _, err = tx.Lock("test2", "owner2", time.Minute, false)
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, l)
|
||||
tx = restartSession(t, datastore, tx, true)
|
||||
|
||||
if !assert.Nil(t, tx.Rollback()) {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
@ -1,53 +0,0 @@
|
||||
// Copyright 2016 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package migrations
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/remind101/migrate"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// This migration removes the data maintained by the previous migration tool
|
||||
// (liamstask/goose), and if it was present, mark the 00002_initial_schema
|
||||
// migration as done.
|
||||
RegisterMigration(migrate.Migration{
|
||||
ID: 1,
|
||||
Up: func(tx *sql.Tx) error {
|
||||
// Verify that goose was in use before, otherwise skip this migration.
|
||||
var e bool
|
||||
err := tx.QueryRow("SELECT true FROM pg_class WHERE relname = $1", "goose_db_version").Scan(&e)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Delete goose's data.
|
||||
_, err = tx.Exec("DROP TABLE goose_db_version CASCADE")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Mark the '00002_initial_schema' as done.
|
||||
_, err = tx.Exec("INSERT INTO schema_migrations (version) VALUES (2)")
|
||||
|
||||
return err
|
||||
},
|
||||
Down: migrate.Queries([]string{}),
|
||||
})
|
||||
}
|
192
database/pgsql/migrations/00001_initial_schema.go
Normal file
192
database/pgsql/migrations/00001_initial_schema.go
Normal file
@ -0,0 +1,192 @@
|
||||
// Copyright 2016 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package migrations
|
||||
|
||||
import "github.com/remind101/migrate"
|
||||
|
||||
func init() {
|
||||
RegisterMigration(migrate.Migration{
|
||||
ID: 1,
|
||||
Up: migrate.Queries([]string{
|
||||
// namespaces
|
||||
`CREATE TABLE IF NOT EXISTS namespace (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name TEXT NULL,
|
||||
version_format TEXT,
|
||||
UNIQUE (name, version_format));`,
|
||||
`CREATE INDEX ON namespace(name);`,
|
||||
|
||||
// features
|
||||
`CREATE TABLE IF NOT EXISTS feature (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL,
|
||||
version TEXT NOT NULL,
|
||||
version_format TEXT NOT NULL,
|
||||
UNIQUE (name, version, version_format));`,
|
||||
`CREATE INDEX ON feature(name);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS namespaced_feature (
|
||||
id SERIAL PRIMARY KEY,
|
||||
namespace_id INT REFERENCES namespace,
|
||||
feature_id INT REFERENCES feature,
|
||||
UNIQUE (namespace_id, feature_id));`,
|
||||
|
||||
// layers
|
||||
`CREATE TABLE IF NOT EXISTS layer(
|
||||
id SERIAL PRIMARY KEY,
|
||||
hash TEXT NOT NULL UNIQUE);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS layer_feature (
|
||||
id SERIAL PRIMARY KEY,
|
||||
layer_id INT REFERENCES layer ON DELETE CASCADE,
|
||||
feature_id INT REFERENCES feature ON DELETE CASCADE,
|
||||
UNIQUE (layer_id, feature_id));`,
|
||||
`CREATE INDEX ON layer_feature(layer_id);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS layer_lister (
|
||||
id SERIAL PRIMARY KEY,
|
||||
layer_id INT REFERENCES layer ON DELETE CASCADE,
|
||||
lister TEXT NOT NULL,
|
||||
UNIQUE (layer_id, lister));`,
|
||||
`CREATE INDEX ON layer_lister(layer_id);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS layer_detector (
|
||||
id SERIAL PRIMARY KEY,
|
||||
layer_id INT REFERENCES layer ON DELETE CASCADE,
|
||||
detector TEXT,
|
||||
UNIQUE (layer_id, detector));`,
|
||||
`CREATE INDEX ON layer_detector(layer_id);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS layer_namespace (
|
||||
id SERIAL PRIMARY KEY,
|
||||
layer_id INT REFERENCES layer ON DELETE CASCADE,
|
||||
namespace_id INT REFERENCES namespace ON DELETE CASCADE,
|
||||
UNIQUE (layer_id, namespace_id));`,
|
||||
`CREATE INDEX ON layer_namespace(layer_id);`,
|
||||
|
||||
// ancestry
|
||||
`CREATE TABLE IF NOT EXISTS ancestry (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name TEXT NOT NULL UNIQUE);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS ancestry_layer (
|
||||
id SERIAL PRIMARY KEY,
|
||||
ancestry_id INT REFERENCES ancestry ON DELETE CASCADE,
|
||||
ancestry_index INT NOT NULL,
|
||||
layer_id INT REFERENCES layer ON DELETE RESTRICT,
|
||||
UNIQUE (ancestry_id, ancestry_index));`,
|
||||
`CREATE INDEX ON ancestry_layer(ancestry_id);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS ancestry_feature (
|
||||
id SERIAL PRIMARY KEY,
|
||||
ancestry_id INT REFERENCES ancestry ON DELETE CASCADE,
|
||||
namespaced_feature_id INT REFERENCES namespaced_feature ON DELETE CASCADE,
|
||||
UNIQUE (ancestry_id, namespaced_feature_id));`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS ancestry_lister (
|
||||
id SERIAL PRIMARY KEY,
|
||||
ancestry_id INT REFERENCES ancestry ON DELETE CASCADE,
|
||||
lister TEXT,
|
||||
UNIQUE (ancestry_id, lister));`,
|
||||
`CREATE INDEX ON ancestry_lister(ancestry_id);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS ancestry_detector (
|
||||
id SERIAL PRIMARY KEY,
|
||||
ancestry_id INT REFERENCES ancestry ON DELETE CASCADE,
|
||||
detector TEXT,
|
||||
UNIQUE (ancestry_id, detector));`,
|
||||
`CREATE INDEX ON ancestry_detector(ancestry_id);`,
|
||||
|
||||
`CREATE TYPE severity AS ENUM ('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1');`,
|
||||
|
||||
// vulnerability
|
||||
`CREATE TABLE IF NOT EXISTS vulnerability (
|
||||
id SERIAL PRIMARY KEY,
|
||||
namespace_id INT NOT NULL REFERENCES Namespace,
|
||||
name TEXT NOT NULL,
|
||||
description TEXT NULL,
|
||||
link TEXT NULL,
|
||||
severity severity NOT NULL,
|
||||
metadata TEXT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE,
|
||||
deleted_at TIMESTAMP WITH TIME ZONE NULL);`,
|
||||
`CREATE INDEX ON vulnerability(namespace_id, name);`,
|
||||
`CREATE INDEX ON vulnerability(namespace_id);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS vulnerability_affected_feature (
|
||||
id SERIAL PRIMARY KEY,
|
||||
vulnerability_id INT NOT NULL REFERENCES vulnerability ON DELETE CASCADE,
|
||||
feature_name TEXT NOT NULL,
|
||||
affected_version TEXT,
|
||||
fixedin TEXT);`,
|
||||
`CREATE INDEX ON vulnerability_affected_feature(vulnerability_id, feature_name);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS vulnerability_affected_namespaced_feature(
|
||||
id SERIAL PRIMARY KEY,
|
||||
vulnerability_id INT NOT NULL REFERENCES vulnerability ON DELETE CASCADE,
|
||||
namespaced_feature_id INT NOT NULL REFERENCES namespaced_feature ON DELETE CASCADE,
|
||||
added_by INT NOT NULL REFERENCES vulnerability_affected_feature ON DELETE CASCADE,
|
||||
UNIQUE (vulnerability_id, namespaced_feature_id));`,
|
||||
`CREATE INDEX ON vulnerability_affected_namespaced_feature(namespaced_feature_id);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS KeyValue (
|
||||
id SERIAL PRIMARY KEY,
|
||||
key TEXT NOT NULL UNIQUE,
|
||||
value TEXT);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS Lock (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(64) NOT NULL UNIQUE,
|
||||
owner VARCHAR(64) NOT NULL,
|
||||
until TIMESTAMP WITH TIME ZONE);`,
|
||||
`CREATE INDEX ON Lock (owner);`,
|
||||
|
||||
// Notification
|
||||
`CREATE TABLE IF NOT EXISTS Vulnerability_Notification (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(64) NOT NULL UNIQUE,
|
||||
created_at TIMESTAMP WITH TIME ZONE,
|
||||
notified_at TIMESTAMP WITH TIME ZONE NULL,
|
||||
deleted_at TIMESTAMP WITH TIME ZONE NULL,
|
||||
old_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE,
|
||||
new_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE);`,
|
||||
`CREATE INDEX ON Vulnerability_Notification (notified_at);`,
|
||||
}),
|
||||
Down: migrate.Queries([]string{
|
||||
`DROP TABLE IF EXISTS
|
||||
ancestry,
|
||||
ancestry_layer,
|
||||
ancestry_feature,
|
||||
ancestry_detector,
|
||||
ancestry_lister,
|
||||
feature,
|
||||
namespaced_feature,
|
||||
keyvalue,
|
||||
layer,
|
||||
layer_detector,
|
||||
layer_feature,
|
||||
layer_lister,
|
||||
layer_namespace,
|
||||
lock,
|
||||
namespace,
|
||||
vulnerability,
|
||||
vulnerability_affected_feature,
|
||||
vulnerability_affected_namespaced_feature,
|
||||
vulnerability_notification
|
||||
CASCADE;`,
|
||||
`DROP TYPE IF EXISTS severity;`,
|
||||
}),
|
||||
})
|
||||
}
|
@ -1,128 +0,0 @@
|
||||
// Copyright 2016 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package migrations
|
||||
|
||||
import "github.com/remind101/migrate"
|
||||
|
||||
func init() {
|
||||
// This migration creates the initial Clair's schema.
|
||||
RegisterMigration(migrate.Migration{
|
||||
ID: 2,
|
||||
Up: migrate.Queries([]string{
|
||||
`CREATE TABLE IF NOT EXISTS Namespace (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(128) NULL);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS Layer (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(128) NOT NULL UNIQUE,
|
||||
engineversion SMALLINT NOT NULL,
|
||||
parent_id INT NULL REFERENCES Layer ON DELETE CASCADE,
|
||||
namespace_id INT NULL REFERENCES Namespace,
|
||||
created_at TIMESTAMP WITH TIME ZONE);`,
|
||||
`CREATE INDEX ON Layer (parent_id);`,
|
||||
`CREATE INDEX ON Layer (namespace_id);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS Feature (
|
||||
id SERIAL PRIMARY KEY,
|
||||
namespace_id INT NOT NULL REFERENCES Namespace,
|
||||
name VARCHAR(128) NOT NULL,
|
||||
UNIQUE (namespace_id, name));`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS FeatureVersion (
|
||||
id SERIAL PRIMARY KEY,
|
||||
feature_id INT NOT NULL REFERENCES Feature,
|
||||
version VARCHAR(128) NOT NULL);`,
|
||||
`CREATE INDEX ON FeatureVersion (feature_id);`,
|
||||
|
||||
`CREATE TYPE modification AS ENUM ('add', 'del');`,
|
||||
`CREATE TABLE IF NOT EXISTS Layer_diff_FeatureVersion (
|
||||
id SERIAL PRIMARY KEY,
|
||||
layer_id INT NOT NULL REFERENCES Layer ON DELETE CASCADE,
|
||||
featureversion_id INT NOT NULL REFERENCES FeatureVersion,
|
||||
modification modification NOT NULL,
|
||||
UNIQUE (layer_id, featureversion_id));`,
|
||||
`CREATE INDEX ON Layer_diff_FeatureVersion (layer_id);`,
|
||||
`CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id);`,
|
||||
`CREATE INDEX ON Layer_diff_FeatureVersion (featureversion_id, layer_id);`,
|
||||
|
||||
`CREATE TYPE severity AS ENUM ('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1');`,
|
||||
`CREATE TABLE IF NOT EXISTS Vulnerability (
|
||||
id SERIAL PRIMARY KEY,
|
||||
namespace_id INT NOT NULL REFERENCES Namespace,
|
||||
name VARCHAR(128) NOT NULL,
|
||||
description TEXT NULL,
|
||||
link VARCHAR(128) NULL,
|
||||
severity severity NOT NULL,
|
||||
metadata TEXT NULL,
|
||||
created_at TIMESTAMP WITH TIME ZONE,
|
||||
deleted_at TIMESTAMP WITH TIME ZONE NULL);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS Vulnerability_FixedIn_Feature (
|
||||
id SERIAL PRIMARY KEY,
|
||||
vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE,
|
||||
feature_id INT NOT NULL REFERENCES Feature,
|
||||
version VARCHAR(128) NOT NULL,
|
||||
UNIQUE (vulnerability_id, feature_id));`,
|
||||
`CREATE INDEX ON Vulnerability_FixedIn_Feature (feature_id, vulnerability_id);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS Vulnerability_Affects_FeatureVersion (
|
||||
id SERIAL PRIMARY KEY,
|
||||
vulnerability_id INT NOT NULL REFERENCES Vulnerability ON DELETE CASCADE,
|
||||
featureversion_id INT NOT NULL REFERENCES FeatureVersion,
|
||||
fixedin_id INT NOT NULL REFERENCES Vulnerability_FixedIn_Feature ON DELETE CASCADE,
|
||||
UNIQUE (vulnerability_id, featureversion_id));`,
|
||||
`CREATE INDEX ON Vulnerability_Affects_FeatureVersion (fixedin_id);`,
|
||||
`CREATE INDEX ON Vulnerability_Affects_FeatureVersion (featureversion_id, vulnerability_id);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS KeyValue (
|
||||
id SERIAL PRIMARY KEY,
|
||||
key VARCHAR(128) NOT NULL UNIQUE,
|
||||
value TEXT);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS Lock (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(64) NOT NULL UNIQUE,
|
||||
owner VARCHAR(64) NOT NULL,
|
||||
until TIMESTAMP WITH TIME ZONE);`,
|
||||
`CREATE INDEX ON Lock (owner);`,
|
||||
|
||||
`CREATE TABLE IF NOT EXISTS Vulnerability_Notification (
|
||||
id SERIAL PRIMARY KEY,
|
||||
name VARCHAR(64) NOT NULL UNIQUE,
|
||||
created_at TIMESTAMP WITH TIME ZONE,
|
||||
notified_at TIMESTAMP WITH TIME ZONE NULL,
|
||||
deleted_at TIMESTAMP WITH TIME ZONE NULL,
|
||||
old_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE,
|
||||
new_vulnerability_id INT NULL REFERENCES Vulnerability ON DELETE CASCADE);`,
|
||||
`CREATE INDEX ON Vulnerability_Notification (notified_at);`,
|
||||
}),
|
||||
Down: migrate.Queries([]string{
|
||||
`DROP TABLE IF EXISTS
|
||||
Namespace,
|
||||
Layer,
|
||||
Feature,
|
||||
FeatureVersion,
|
||||
Layer_diff_FeatureVersion,
|
||||
Vulnerability,
|
||||
Vulnerability_FixedIn_Feature,
|
||||
Vulnerability_Affects_FeatureVersion,
|
||||
Vulnerability_Notification,
|
||||
KeyValue,
|
||||
Lock
|
||||
CASCADE;`,
|
||||
}),
|
||||
})
|
||||
}
|
@ -1,35 +0,0 @@
|
||||
// Copyright 2016 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package migrations
|
||||
|
||||
import "github.com/remind101/migrate"
|
||||
|
||||
func init() {
|
||||
RegisterMigration(migrate.Migration{
|
||||
ID: 3,
|
||||
Up: migrate.Queries([]string{
|
||||
`CREATE UNIQUE INDEX namespace_name_key ON Namespace (name);`,
|
||||
`CREATE INDEX vulnerability_name_idx ON Vulnerability (name);`,
|
||||
`CREATE INDEX vulnerability_namespace_id_name_idx ON Vulnerability (namespace_id, name);`,
|
||||
`CREATE UNIQUE INDEX featureversion_feature_id_version_key ON FeatureVersion (feature_id, version);`,
|
||||
}),
|
||||
Down: migrate.Queries([]string{
|
||||
`DROP INDEX namespace_name_key;`,
|
||||
`DROP INDEX vulnerability_name_idx;`,
|
||||
`DROP INDEX vulnerability_namespace_id_name_idx;`,
|
||||
`DROP INDEX featureversion_feature_id_version_key;`,
|
||||
}),
|
||||
})
|
||||
}
|
@ -1,29 +0,0 @@
|
||||
// Copyright 2016 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package migrations
|
||||
|
||||
import "github.com/remind101/migrate"
|
||||
|
||||
func init() {
|
||||
RegisterMigration(migrate.Migration{
|
||||
ID: 4,
|
||||
Up: migrate.Queries([]string{
|
||||
`CREATE INDEX vulnerability_notification_deleted_at_idx ON Vulnerability_Notification (deleted_at);`,
|
||||
}),
|
||||
Down: migrate.Queries([]string{
|
||||
`DROP INDEX vulnerability_notification_deleted_at_idx;`,
|
||||
}),
|
||||
})
|
||||
}
|
@ -1,29 +0,0 @@
|
||||
// Copyright 2016 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package migrations
|
||||
|
||||
import "github.com/remind101/migrate"
|
||||
|
||||
func init() {
|
||||
RegisterMigration(migrate.Migration{
|
||||
ID: 5,
|
||||
Up: migrate.Queries([]string{
|
||||
`CREATE INDEX layer_diff_featureversion_layer_id_modification_idx ON Layer_diff_FeatureVersion (layer_id, modification);`,
|
||||
}),
|
||||
Down: migrate.Queries([]string{
|
||||
`DROP INDEX layer_diff_featureversion_layer_id_modification_idx;`,
|
||||
}),
|
||||
})
|
||||
}
|
@ -1,31 +0,0 @@
|
||||
// Copyright 2016 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package migrations
|
||||
|
||||
import "github.com/remind101/migrate"
|
||||
|
||||
func init() {
|
||||
RegisterMigration(migrate.Migration{
|
||||
ID: 6,
|
||||
Up: migrate.Queries([]string{
|
||||
`ALTER TABLE Namespace ADD COLUMN version_format varchar(128);`,
|
||||
`UPDATE Namespace SET version_format = 'rpm' WHERE name LIKE 'rhel%' OR name LIKE 'centos%' OR name LIKE 'fedora%' OR name LIKE 'amzn%' OR name LIKE 'scientific%' OR name LIKE 'ol%' OR name LIKE 'oracle%';`,
|
||||
`UPDATE Namespace SET version_format = 'dpkg' WHERE version_format is NULL;`,
|
||||
}),
|
||||
Down: migrate.Queries([]string{
|
||||
`ALTER TABLE Namespace DROP COLUMN version_format;`,
|
||||
}),
|
||||
})
|
||||
}
|
@ -1,31 +0,0 @@
|
||||
// Copyright 2017 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package migrations
|
||||
|
||||
import "github.com/remind101/migrate"
|
||||
|
||||
func init() {
|
||||
RegisterMigration(migrate.Migration{
|
||||
ID: 7,
|
||||
Up: migrate.Queries([]string{
|
||||
`ALTER TABLE Namespace ALTER COLUMN version_format SET DATA TYPE varchar(256);`,
|
||||
`ALTER TABLE Layer ALTER COLUMN name SET DATA TYPE varchar(256);`,
|
||||
}),
|
||||
Down: migrate.Queries([]string{
|
||||
`ALTER TABLE Namespace ALTER COLUMN version_format SET DATA TYPE varchar(128);`,
|
||||
`ALTER TABLE Layer ALTER COLUMN name SET DATA TYPE varchar(128);`,
|
||||
}),
|
||||
})
|
||||
}
|
@ -1,44 +0,0 @@
|
||||
// Copyright 2016 clair authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package migrations
|
||||
|
||||
import "github.com/remind101/migrate"
|
||||
|
||||
func init() {
|
||||
RegisterMigration(migrate.Migration{
|
||||
ID: 8,
|
||||
Up: migrate.Queries([]string{
|
||||
// set on deletion, remove the corresponding rows in database
|
||||
`CREATE TABLE IF NOT EXISTS Layer_Namespace(
|
||||
id SERIAL PRIMARY KEY,
|
||||
layer_id INT REFERENCES Layer(id) ON DELETE CASCADE,
|
||||
namespace_id INT REFERENCES Namespace(id) ON DELETE CASCADE,
|
||||
unique(layer_id, namespace_id)
|
||||
);`,
|
||||
`CREATE INDEX ON Layer_Namespace (namespace_id);`,
|
||||
`CREATE INDEX ON Layer_Namespace (layer_id);`,
|
||||
// move the namespace_id to the table
|
||||
`INSERT INTO Layer_Namespace (layer_id, namespace_id) SELECT id, namespace_id FROM Layer;`,
|
||||
// alter the Layer table to remove the column
|
||||
`ALTER TABLE IF EXISTS Layer DROP namespace_id;`,
|
||||
}),
|
||||
Down: migrate.Queries([]string{
|
||||
`ALTER TABLE IF EXISTS Layer ADD namespace_id INT NULL REFERENCES Namespace;`,
|
||||
`CREATE INDEX ON Layer (namespace_id);`,
|
||||
`UPDATE IF EXISTS Layer SET namespace_id = (SELECT lns.namespace_id FROM Layer_Namespace lns WHERE Layer.id = lns.layer_id LIMIT 1);`,
|
||||
`DROP TABLE IF EXISTS Layer_Namespace;`,
|
||||
}),
|
||||
})
|
||||
}
|
@ -15,61 +15,82 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"time"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"sort"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (int, error) {
|
||||
if namespace.Name == "" {
|
||||
return 0, commonerr.NewBadRequestError("could not find/insert invalid Namespace")
|
||||
var (
|
||||
errNamespaceNotFound = errors.New("Requested Namespace is not in database")
|
||||
)
|
||||
|
||||
// PersistNamespaces soi namespaces into database.
|
||||
func (tx *pgSession) PersistNamespaces(namespaces []database.Namespace) error {
|
||||
if len(namespaces) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if pgSQL.cache != nil {
|
||||
promCacheQueriesTotal.WithLabelValues("namespace").Inc()
|
||||
if id, found := pgSQL.cache.Get("namespace:" + namespace.Name); found {
|
||||
promCacheHitsTotal.WithLabelValues("namespace").Inc()
|
||||
return id.(int), nil
|
||||
// Sorting is needed before inserting into database to prevent deadlock.
|
||||
sort.Slice(namespaces, func(i, j int) bool {
|
||||
return namespaces[i].Name < namespaces[j].Name &&
|
||||
namespaces[i].VersionFormat < namespaces[j].VersionFormat
|
||||
})
|
||||
|
||||
keys := make([]interface{}, len(namespaces)*2)
|
||||
for i, ns := range namespaces {
|
||||
if ns.Name == "" || ns.VersionFormat == "" {
|
||||
return commonerr.NewBadRequestError("Empty namespace name or version format is not allowed")
|
||||
}
|
||||
keys[i*2] = ns.Name
|
||||
keys[i*2+1] = ns.VersionFormat
|
||||
}
|
||||
|
||||
// We do `defer observeQueryTime` here because we don't want to observe cached namespaces.
|
||||
defer observeQueryTime("insertNamespace", "all", time.Now())
|
||||
|
||||
var id int
|
||||
err := pgSQL.QueryRow(soiNamespace, namespace.Name, namespace.VersionFormat).Scan(&id)
|
||||
_, err := tx.Exec(queryPersistNamespace(len(namespaces)), keys...)
|
||||
if err != nil {
|
||||
return 0, handleError("soiNamespace", err)
|
||||
return handleError("queryPersistNamespace", err)
|
||||
}
|
||||
|
||||
if pgSQL.cache != nil {
|
||||
pgSQL.cache.Add("namespace:"+namespace.Name, id)
|
||||
}
|
||||
|
||||
return id, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) ListNamespaces() (namespaces []database.Namespace, err error) {
|
||||
rows, err := pgSQL.Query(listNamespace)
|
||||
if err != nil {
|
||||
return namespaces, handleError("listNamespace", err)
|
||||
func (tx *pgSession) findNamespaceIDs(namespaces []database.Namespace) ([]sql.NullInt64, error) {
|
||||
if len(namespaces) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
keys := make([]interface{}, len(namespaces)*2)
|
||||
nsMap := map[database.Namespace]sql.NullInt64{}
|
||||
for i, n := range namespaces {
|
||||
keys[i*2] = n.Name
|
||||
keys[i*2+1] = n.VersionFormat
|
||||
nsMap[n] = sql.NullInt64{}
|
||||
}
|
||||
|
||||
rows, err := tx.Query(querySearchNamespace(len(namespaces)), keys...)
|
||||
if err != nil {
|
||||
return nil, handleError("searchNamespace", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
var (
|
||||
id sql.NullInt64
|
||||
ns database.Namespace
|
||||
)
|
||||
for rows.Next() {
|
||||
var ns database.Namespace
|
||||
|
||||
err = rows.Scan(&ns.ID, &ns.Name, &ns.VersionFormat)
|
||||
err := rows.Scan(&id, &ns.Name, &ns.VersionFormat)
|
||||
if err != nil {
|
||||
return namespaces, handleError("listNamespace.Scan()", err)
|
||||
return nil, handleError("searchNamespace", err)
|
||||
}
|
||||
|
||||
namespaces = append(namespaces, ns)
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return namespaces, handleError("listNamespace.Rows()", err)
|
||||
nsMap[ns] = id
|
||||
}
|
||||
|
||||
return namespaces, err
|
||||
ids := make([]sql.NullInt64, len(namespaces))
|
||||
for i, ns := range namespaces {
|
||||
ids[i] = nsMap[ns]
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
@ -15,60 +15,69 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/ext/versionfmt/dpkg"
|
||||
)
|
||||
|
||||
func TestInsertNamespace(t *testing.T) {
|
||||
datastore, err := openDatabaseForTest("InsertNamespace", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer datastore.Close()
|
||||
func TestPersistNamespaces(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "PersistNamespaces", false)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
// Invalid Namespace.
|
||||
id0, err := datastore.insertNamespace(database.Namespace{})
|
||||
assert.NotNil(t, err)
|
||||
assert.Zero(t, id0)
|
||||
ns1 := database.Namespace{}
|
||||
ns2 := database.Namespace{Name: "t", VersionFormat: "b"}
|
||||
|
||||
// Insert Namespace and ensure we can find it.
|
||||
id1, err := datastore.insertNamespace(database.Namespace{
|
||||
Name: "TestInsertNamespace1",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
id2, err := datastore.insertNamespace(database.Namespace{
|
||||
Name: "TestInsertNamespace1",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
})
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, id1, id2)
|
||||
// Empty Case
|
||||
assert.Nil(t, tx.PersistNamespaces([]database.Namespace{}))
|
||||
// Invalid Case
|
||||
assert.NotNil(t, tx.PersistNamespaces([]database.Namespace{ns1}))
|
||||
// Duplicated Case
|
||||
assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2, ns2}))
|
||||
// Existing Case
|
||||
assert.Nil(t, tx.PersistNamespaces([]database.Namespace{ns2}))
|
||||
|
||||
nsList := listNamespaces(t, tx)
|
||||
assert.Len(t, nsList, 1)
|
||||
assert.Equal(t, ns2, nsList[0])
|
||||
}
|
||||
|
||||
func TestListNamespace(t *testing.T) {
|
||||
datastore, err := openDatabaseForTest("ListNamespaces", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer datastore.Close()
|
||||
|
||||
namespaces, err := datastore.ListNamespaces()
|
||||
assert.Nil(t, err)
|
||||
if assert.Len(t, namespaces, 2) {
|
||||
for _, namespace := range namespaces {
|
||||
switch namespace.Name {
|
||||
case "debian:7", "debian:8":
|
||||
continue
|
||||
default:
|
||||
assert.Error(t, fmt.Errorf("ListNamespaces should not have returned '%s'", namespace.Name))
|
||||
func assertNamespacesEqual(t *testing.T, expected []database.Namespace, actual []database.Namespace) bool {
|
||||
if assert.Len(t, actual, len(expected)) {
|
||||
has := map[database.Namespace]bool{}
|
||||
for _, i := range expected {
|
||||
has[i] = false
|
||||
}
|
||||
for _, i := range actual {
|
||||
has[i] = true
|
||||
}
|
||||
for key, v := range has {
|
||||
if !assert.True(t, v, key.Name+"is expected") {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func listNamespaces(t *testing.T, tx *pgSession) []database.Namespace {
|
||||
rows, err := tx.Query("SELECT name, version_format FROM namespace")
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
namespaces := []database.Namespace{}
|
||||
for rows.Next() {
|
||||
var ns database.Namespace
|
||||
err := rows.Scan(&ns.Name, &ns.VersionFormat)
|
||||
if err != nil {
|
||||
t.FailNow()
|
||||
}
|
||||
namespaces = append(namespaces, ns)
|
||||
}
|
||||
|
||||
return namespaces
|
||||
}
|
||||
|
@ -16,235 +16,320 @@ package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/guregu/null/zero"
|
||||
"github.com/pborman/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
// do it in tx so we won't insert/update a vuln without notification and vice-versa.
|
||||
// name and created doesn't matter.
|
||||
func createNotification(tx *sql.Tx, oldVulnerabilityID, newVulnerabilityID int) error {
|
||||
defer observeQueryTime("createNotification", "all", time.Now())
|
||||
var (
|
||||
errNotificationNotFound = errors.New("requested notification is not found")
|
||||
)
|
||||
|
||||
// Insert Notification.
|
||||
oldVulnerabilityNullableID := sql.NullInt64{Int64: int64(oldVulnerabilityID), Valid: oldVulnerabilityID != 0}
|
||||
newVulnerabilityNullableID := sql.NullInt64{Int64: int64(newVulnerabilityID), Valid: newVulnerabilityID != 0}
|
||||
_, err := tx.Exec(insertNotification, uuid.New(), oldVulnerabilityNullableID, newVulnerabilityNullableID)
|
||||
func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.VulnerabilityNotification) error {
|
||||
if len(notifications) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
newVulnIDMap = make(map[database.VulnerabilityID]sql.NullInt64)
|
||||
oldVulnIDMap = make(map[database.VulnerabilityID]sql.NullInt64)
|
||||
)
|
||||
|
||||
invalidCreationTime := time.Time{}
|
||||
for _, noti := range notifications {
|
||||
if noti.Name == "" {
|
||||
return commonerr.NewBadRequestError("notification should not have empty name")
|
||||
}
|
||||
if noti.Created == invalidCreationTime {
|
||||
return commonerr.NewBadRequestError("notification should not have empty created time")
|
||||
}
|
||||
|
||||
if noti.New != nil {
|
||||
key := database.VulnerabilityID{
|
||||
Name: noti.New.Name,
|
||||
Namespace: noti.New.Namespace.Name,
|
||||
}
|
||||
newVulnIDMap[key] = sql.NullInt64{}
|
||||
}
|
||||
|
||||
if noti.Old != nil {
|
||||
key := database.VulnerabilityID{
|
||||
Name: noti.Old.Name,
|
||||
Namespace: noti.Old.Namespace.Name,
|
||||
}
|
||||
oldVulnIDMap[key] = sql.NullInt64{}
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
newVulnIDs = make([]database.VulnerabilityID, 0, len(newVulnIDMap))
|
||||
oldVulnIDs = make([]database.VulnerabilityID, 0, len(oldVulnIDMap))
|
||||
)
|
||||
|
||||
for vulnID := range newVulnIDMap {
|
||||
newVulnIDs = append(newVulnIDs, vulnID)
|
||||
}
|
||||
|
||||
for vulnID := range oldVulnIDMap {
|
||||
oldVulnIDs = append(oldVulnIDs, vulnID)
|
||||
}
|
||||
|
||||
ids, err := tx.findNotDeletedVulnerabilityIDs(newVulnIDs)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("insertNotification", err)
|
||||
return err
|
||||
}
|
||||
|
||||
for i, id := range ids {
|
||||
if !id.Valid {
|
||||
return handleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound)
|
||||
}
|
||||
newVulnIDMap[newVulnIDs[i]] = id
|
||||
}
|
||||
|
||||
ids, err = tx.findLatestDeletedVulnerabilityIDs(oldVulnIDs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for i, id := range ids {
|
||||
if !id.Valid {
|
||||
return handleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound)
|
||||
}
|
||||
oldVulnIDMap[oldVulnIDs[i]] = id
|
||||
}
|
||||
|
||||
var (
|
||||
newVulnID sql.NullInt64
|
||||
oldVulnID sql.NullInt64
|
||||
)
|
||||
|
||||
keys := make([]interface{}, len(notifications)*4)
|
||||
for i, noti := range notifications {
|
||||
if noti.New != nil {
|
||||
newVulnID = newVulnIDMap[database.VulnerabilityID{
|
||||
Name: noti.New.Name,
|
||||
Namespace: noti.New.Namespace.Name,
|
||||
}]
|
||||
}
|
||||
|
||||
if noti.Old != nil {
|
||||
oldVulnID = oldVulnIDMap[database.VulnerabilityID{
|
||||
Name: noti.Old.Name,
|
||||
Namespace: noti.Old.Namespace.Name,
|
||||
}]
|
||||
}
|
||||
|
||||
keys[4*i] = noti.Name
|
||||
keys[4*i+1] = noti.Created
|
||||
keys[4*i+2] = oldVulnID
|
||||
keys[4*i+3] = newVulnID
|
||||
}
|
||||
|
||||
// NOTE(Sida): The data is not sorted before inserting into database under
|
||||
// the fact that there's only one updater running at a time. If there are
|
||||
// multiple updaters, deadlock may happen.
|
||||
_, err = tx.Exec(queryInsertNotifications(len(notifications)), keys...)
|
||||
if err != nil {
|
||||
return handleError("queryInsertNotifications", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get one available notification name (!locked && !deleted && (!notified || notified_but_timed-out)).
|
||||
// Does not fill new/old vuln.
|
||||
func (pgSQL *pgSQL) GetAvailableNotification(renotifyInterval time.Duration) (database.VulnerabilityNotification, error) {
|
||||
defer observeQueryTime("GetAvailableNotification", "all", time.Now())
|
||||
|
||||
before := time.Now().Add(-renotifyInterval)
|
||||
row := pgSQL.QueryRow(searchNotificationAvailable, before)
|
||||
notification, err := pgSQL.scanNotification(row, false)
|
||||
|
||||
return notification, handleError("searchNotificationAvailable", err)
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) GetNotification(name string, limit int, page database.VulnerabilityNotificationPageNumber) (database.VulnerabilityNotification, database.VulnerabilityNotificationPageNumber, error) {
|
||||
defer observeQueryTime("GetNotification", "all", time.Now())
|
||||
|
||||
// Get Notification.
|
||||
notification, err := pgSQL.scanNotification(pgSQL.QueryRow(searchNotification, name), true)
|
||||
if err != nil {
|
||||
return notification, page, handleError("searchNotification", err)
|
||||
}
|
||||
|
||||
// Load vulnerabilities' LayersIntroducingVulnerability.
|
||||
page.OldVulnerability, err = pgSQL.loadLayerIntroducingVulnerability(
|
||||
notification.OldVulnerability,
|
||||
limit,
|
||||
page.OldVulnerability,
|
||||
func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.NotificationHook, bool, error) {
|
||||
var (
|
||||
notification database.NotificationHook
|
||||
created zero.Time
|
||||
notified zero.Time
|
||||
deleted zero.Time
|
||||
)
|
||||
|
||||
err := tx.QueryRow(searchNotificationAvailable, notifiedBefore).Scan(¬ification.Name, &created, ¬ified, &deleted)
|
||||
if err != nil {
|
||||
return notification, page, err
|
||||
}
|
||||
|
||||
page.NewVulnerability, err = pgSQL.loadLayerIntroducingVulnerability(
|
||||
notification.NewVulnerability,
|
||||
limit,
|
||||
page.NewVulnerability,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return notification, page, err
|
||||
}
|
||||
|
||||
return notification, page, nil
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) scanNotification(row *sql.Row, hasVulns bool) (database.VulnerabilityNotification, error) {
|
||||
var notification database.VulnerabilityNotification
|
||||
var created zero.Time
|
||||
var notified zero.Time
|
||||
var deleted zero.Time
|
||||
var oldVulnerabilityNullableID sql.NullInt64
|
||||
var newVulnerabilityNullableID sql.NullInt64
|
||||
|
||||
// Scan notification.
|
||||
if hasVulns {
|
||||
err := row.Scan(
|
||||
¬ification.ID,
|
||||
¬ification.Name,
|
||||
&created,
|
||||
¬ified,
|
||||
&deleted,
|
||||
&oldVulnerabilityNullableID,
|
||||
&newVulnerabilityNullableID,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return notification, err
|
||||
}
|
||||
} else {
|
||||
err := row.Scan(¬ification.ID, ¬ification.Name, &created, ¬ified, &deleted)
|
||||
|
||||
if err != nil {
|
||||
return notification, err
|
||||
if err == sql.ErrNoRows {
|
||||
return notification, false, nil
|
||||
}
|
||||
return notification, false, handleError("searchNotificationAvailable", err)
|
||||
}
|
||||
|
||||
notification.Created = created.Time
|
||||
notification.Notified = notified.Time
|
||||
notification.Deleted = deleted.Time
|
||||
|
||||
if hasVulns {
|
||||
if oldVulnerabilityNullableID.Valid {
|
||||
vulnerability, err := pgSQL.findVulnerabilityByIDWithDeleted(int(oldVulnerabilityNullableID.Int64))
|
||||
if err != nil {
|
||||
return notification, err
|
||||
}
|
||||
|
||||
notification.OldVulnerability = &vulnerability
|
||||
}
|
||||
|
||||
if newVulnerabilityNullableID.Valid {
|
||||
vulnerability, err := pgSQL.findVulnerabilityByIDWithDeleted(int(newVulnerabilityNullableID.Int64))
|
||||
if err != nil {
|
||||
return notification, err
|
||||
}
|
||||
|
||||
notification.NewVulnerability = &vulnerability
|
||||
}
|
||||
}
|
||||
|
||||
return notification, nil
|
||||
return notification, true, nil
|
||||
}
|
||||
|
||||
// Fills Vulnerability.LayersIntroducingVulnerability.
|
||||
// limit -1: won't do anything
|
||||
// limit 0: will just get the startID of the second page
|
||||
func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vulnerability, limit, startID int) (int, error) {
|
||||
tf := time.Now()
|
||||
|
||||
if vulnerability == nil {
|
||||
return -1, nil
|
||||
func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentPage database.PageNumber) (database.PagedVulnerableAncestries, error) {
|
||||
vulnPage := database.PagedVulnerableAncestries{Limit: limit}
|
||||
current := idPageNumber{0}
|
||||
if currentPage != "" {
|
||||
var err error
|
||||
current, err = decryptPage(currentPage, tx.paginationKey)
|
||||
if err != nil {
|
||||
return vulnPage, err
|
||||
}
|
||||
}
|
||||
|
||||
// A startID equals to -1 means that we reached the end already.
|
||||
if startID == -1 || limit == -1 {
|
||||
return -1, nil
|
||||
}
|
||||
|
||||
// Create a transaction to disable hash joins as our experience shows that
|
||||
// PostgreSQL plans in certain cases a sequential scan and a hash on
|
||||
// Layer_diff_FeatureVersion for the condition `ldfv.layer_id >= $2 AND
|
||||
// ldfv.modification = 'add'` before realizing a hash inner join with
|
||||
// Vulnerability_Affects_FeatureVersion. By disabling explictly hash joins,
|
||||
// we force PostgreSQL to perform a bitmap index scan with
|
||||
// `ldfv.featureversion_id = fv.id` on Layer_diff_FeatureVersion, followed by
|
||||
// a bitmap heap scan on `ldfv.layer_id >= $2 AND ldfv.modification = 'add'`,
|
||||
// thus avoiding a sequential scan on the biggest database table and
|
||||
// allowing a small nested loop join instead.
|
||||
tx, err := pgSQL.Begin()
|
||||
err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan(
|
||||
&vulnPage.Name,
|
||||
&vulnPage.Description,
|
||||
&vulnPage.Link,
|
||||
&vulnPage.Severity,
|
||||
&vulnPage.Metadata,
|
||||
&vulnPage.Namespace.Name,
|
||||
&vulnPage.Namespace.VersionFormat,
|
||||
)
|
||||
if err != nil {
|
||||
return -1, handleError("searchNotificationLayerIntroducingVulnerability.Begin()", err)
|
||||
}
|
||||
defer tx.Commit()
|
||||
|
||||
_, err = tx.Exec(disableHashJoin)
|
||||
if err != nil {
|
||||
log.WithError(err).Warning("searchNotificationLayerIntroducingVulnerability: could not disable hash join")
|
||||
return vulnPage, handleError("searchVulnerabilityByID", err)
|
||||
}
|
||||
|
||||
// We do `defer observeQueryTime` here because we don't want to observe invalid calls.
|
||||
defer observeQueryTime("loadLayerIntroducingVulnerability", "all", tf)
|
||||
|
||||
// Query with limit + 1, the last item will be used to know the next starting ID.
|
||||
rows, err := tx.Query(searchNotificationLayerIntroducingVulnerability,
|
||||
vulnerability.ID, startID, limit+1)
|
||||
// the last result is used for the next page's startID
|
||||
rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, current.StartID, limit+1)
|
||||
if err != nil {
|
||||
return 0, handleError("searchNotificationLayerIntroducingVulnerability", err)
|
||||
return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var layers []database.Layer
|
||||
ancestries := []affectedAncestry{}
|
||||
for rows.Next() {
|
||||
var layer database.Layer
|
||||
|
||||
if err := rows.Scan(&layer.ID, &layer.Name); err != nil {
|
||||
return -1, handleError("searchNotificationLayerIntroducingVulnerability.Scan()", err)
|
||||
var ancestry affectedAncestry
|
||||
err := rows.Scan(&ancestry.id, &ancestry.name)
|
||||
if err != nil {
|
||||
return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return -1, handleError("searchNotificationLayerIntroducingVulnerability.Rows()", err)
|
||||
ancestries = append(ancestries, ancestry)
|
||||
}
|
||||
|
||||
size := limit
|
||||
if len(layers) < limit {
|
||||
size = len(layers)
|
||||
}
|
||||
vulnerability.LayersIntroducingVulnerability = layers[:size]
|
||||
lastIndex := 0
|
||||
if len(ancestries)-1 < limit {
|
||||
lastIndex = len(ancestries)
|
||||
vulnPage.End = true
|
||||
} else {
|
||||
// Use the last ancestry's ID as the next PageNumber.
|
||||
lastIndex = len(ancestries) - 1
|
||||
vulnPage.Next, err = encryptPage(
|
||||
idPageNumber{
|
||||
ancestries[len(ancestries)-1].id,
|
||||
}, tx.paginationKey)
|
||||
|
||||
nextID := -1
|
||||
if len(layers) > limit {
|
||||
nextID = layers[limit].ID
|
||||
if err != nil {
|
||||
return vulnPage, err
|
||||
}
|
||||
}
|
||||
|
||||
return nextID, nil
|
||||
vulnPage.Affected = map[int]string{}
|
||||
for _, ancestry := range ancestries[0:lastIndex] {
|
||||
vulnPage.Affected[int(ancestry.id)] = ancestry.name
|
||||
}
|
||||
|
||||
vulnPage.Current, err = encryptPage(current, tx.paginationKey)
|
||||
if err != nil {
|
||||
return vulnPage, err
|
||||
}
|
||||
|
||||
return vulnPage, nil
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) SetNotificationNotified(name string) error {
|
||||
defer observeQueryTime("SetNotificationNotified", "all", time.Now())
|
||||
func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPage database.PageNumber, newPage database.PageNumber) (
|
||||
database.VulnerabilityNotificationWithVulnerable, bool, error) {
|
||||
var (
|
||||
noti database.VulnerabilityNotificationWithVulnerable
|
||||
oldVulnID sql.NullInt64
|
||||
newVulnID sql.NullInt64
|
||||
created zero.Time
|
||||
notified zero.Time
|
||||
deleted zero.Time
|
||||
)
|
||||
|
||||
if _, err := pgSQL.Exec(updatedNotificationNotified, name); err != nil {
|
||||
if name == "" {
|
||||
return noti, false, commonerr.NewBadRequestError("Empty notification name is not allowed")
|
||||
}
|
||||
|
||||
noti.Name = name
|
||||
|
||||
err := tx.QueryRow(searchNotification, name).Scan(&created, ¬ified,
|
||||
&deleted, &oldVulnID, &newVulnID)
|
||||
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return noti, false, nil
|
||||
}
|
||||
return noti, false, handleError("searchNotification", err)
|
||||
}
|
||||
|
||||
if created.Valid {
|
||||
noti.Created = created.Time
|
||||
}
|
||||
|
||||
if notified.Valid {
|
||||
noti.Notified = notified.Time
|
||||
}
|
||||
|
||||
if deleted.Valid {
|
||||
noti.Deleted = deleted.Time
|
||||
}
|
||||
|
||||
if oldVulnID.Valid {
|
||||
page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPage)
|
||||
if err != nil {
|
||||
return noti, false, err
|
||||
}
|
||||
noti.Old = &page
|
||||
}
|
||||
|
||||
if newVulnID.Valid {
|
||||
page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPage)
|
||||
if err != nil {
|
||||
return noti, false, err
|
||||
}
|
||||
noti.New = &page
|
||||
}
|
||||
|
||||
return noti, true, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) MarkNotificationNotified(name string) error {
|
||||
if name == "" {
|
||||
return commonerr.NewBadRequestError("Empty notification name is not allowed")
|
||||
}
|
||||
|
||||
r, err := tx.Exec(updatedNotificationNotified, name)
|
||||
if err != nil {
|
||||
return handleError("updatedNotificationNotified", err)
|
||||
}
|
||||
|
||||
affected, err := r.RowsAffected()
|
||||
if err != nil {
|
||||
return handleError("updatedNotificationNotified", err)
|
||||
}
|
||||
|
||||
if affected <= 0 {
|
||||
return handleError("updatedNotificationNotified", errNotificationNotFound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) DeleteNotification(name string) error {
|
||||
defer observeQueryTime("DeleteNotification", "all", time.Now())
|
||||
func (tx *pgSession) DeleteNotification(name string) error {
|
||||
if name == "" {
|
||||
return commonerr.NewBadRequestError("Empty notification name is not allowed")
|
||||
}
|
||||
|
||||
result, err := pgSQL.Exec(removeNotification, name)
|
||||
result, err := tx.Exec(removeNotification, name)
|
||||
if err != nil {
|
||||
return handleError("removeNotification", err)
|
||||
}
|
||||
|
||||
affected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return handleError("removeNotification.RowsAffected()", err)
|
||||
return handleError("removeNotification", err)
|
||||
}
|
||||
|
||||
if affected <= 0 {
|
||||
return commonerr.ErrNotFound
|
||||
return handleError("removeNotification", commonerr.ErrNotFound)
|
||||
}
|
||||
|
||||
return nil
|
||||
|
@ -21,211 +21,225 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/ext/versionfmt"
|
||||
"github.com/coreos/clair/ext/versionfmt/dpkg"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
func TestNotification(t *testing.T) {
|
||||
datastore, err := openDatabaseForTest("Notification", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer datastore.Close()
|
||||
func TestPagination(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "Pagination", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
// Try to get a notification when there is none.
|
||||
_, err = datastore.GetAvailableNotification(time.Second)
|
||||
assert.Equal(t, commonerr.ErrNotFound, err)
|
||||
|
||||
// Create some data.
|
||||
f1 := database.Feature{
|
||||
Name: "TestNotificationFeature1",
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestNotificationNamespace1",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
ns := database.Namespace{
|
||||
Name: "debian:7",
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
|
||||
f2 := database.Feature{
|
||||
Name: "TestNotificationFeature2",
|
||||
Namespace: database.Namespace{
|
||||
Name: "TestNotificationNamespace1",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
vNew := database.Vulnerability{
|
||||
Namespace: ns,
|
||||
Name: "CVE-OPENSSL-1-DEB7",
|
||||
Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
|
||||
Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7",
|
||||
Severity: database.HighSeverity,
|
||||
}
|
||||
|
||||
l1 := database.Layer{
|
||||
Name: "TestNotificationLayer1",
|
||||
Features: []database.FeatureVersion{
|
||||
{
|
||||
Feature: f1,
|
||||
Version: "0.1",
|
||||
},
|
||||
},
|
||||
vOld := database.Vulnerability{
|
||||
Namespace: ns,
|
||||
Name: "CVE-NOPE",
|
||||
Description: "A vulnerability affecting nothing",
|
||||
Severity: database.UnknownSeverity,
|
||||
}
|
||||
|
||||
l2 := database.Layer{
|
||||
Name: "TestNotificationLayer2",
|
||||
Features: []database.FeatureVersion{
|
||||
{
|
||||
Feature: f1,
|
||||
Version: "0.2",
|
||||
},
|
||||
},
|
||||
noti, ok, err := tx.FindVulnerabilityNotification("test", 1, "", "")
|
||||
oldPage := database.PagedVulnerableAncestries{
|
||||
Vulnerability: vOld,
|
||||
Limit: 1,
|
||||
Affected: make(map[int]string),
|
||||
End: true,
|
||||
}
|
||||
|
||||
l3 := database.Layer{
|
||||
Name: "TestNotificationLayer3",
|
||||
Features: []database.FeatureVersion{
|
||||
{
|
||||
Feature: f1,
|
||||
Version: "0.3",
|
||||
},
|
||||
},
|
||||
newPage1 := database.PagedVulnerableAncestries{
|
||||
Vulnerability: vNew,
|
||||
Limit: 1,
|
||||
Affected: map[int]string{3: "ancestry-3"},
|
||||
End: false,
|
||||
}
|
||||
|
||||
l4 := database.Layer{
|
||||
Name: "TestNotificationLayer4",
|
||||
Features: []database.FeatureVersion{
|
||||
{
|
||||
Feature: f2,
|
||||
Version: "0.1",
|
||||
},
|
||||
},
|
||||
newPage2 := database.PagedVulnerableAncestries{
|
||||
Vulnerability: vNew,
|
||||
Limit: 1,
|
||||
Affected: map[int]string{4: "ancestry-4"},
|
||||
Next: "",
|
||||
End: true,
|
||||
}
|
||||
|
||||
if !assert.Nil(t, datastore.InsertLayer(l1)) ||
|
||||
!assert.Nil(t, datastore.InsertLayer(l2)) ||
|
||||
!assert.Nil(t, datastore.InsertLayer(l3)) ||
|
||||
!assert.Nil(t, datastore.InsertLayer(l4)) {
|
||||
return
|
||||
}
|
||||
|
||||
// Insert a new vulnerability that is introduced by three layers.
|
||||
v1 := database.Vulnerability{
|
||||
Name: "TestNotificationVulnerability1",
|
||||
Namespace: f1.Namespace,
|
||||
Description: "TestNotificationDescription1",
|
||||
Link: "TestNotificationLink1",
|
||||
Severity: "Unknown",
|
||||
FixedIn: []database.FeatureVersion{
|
||||
{
|
||||
Feature: f1,
|
||||
Version: "1.0",
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.Nil(t, datastore.insertVulnerability(v1, false, true))
|
||||
|
||||
// Get the notification associated to the previously inserted vulnerability.
|
||||
notification, err := datastore.GetAvailableNotification(time.Second)
|
||||
|
||||
if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) {
|
||||
// Verify the renotify behaviour.
|
||||
if assert.Nil(t, datastore.SetNotificationNotified(notification.Name)) {
|
||||
_, err := datastore.GetAvailableNotification(time.Second)
|
||||
assert.Equal(t, commonerr.ErrNotFound, err)
|
||||
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
notificationB, err := datastore.GetAvailableNotification(20 * time.Millisecond)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, notification.Name, notificationB.Name)
|
||||
|
||||
datastore.SetNotificationNotified(notification.Name)
|
||||
}
|
||||
|
||||
// Get notification.
|
||||
filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage)
|
||||
if assert.Nil(t, err) {
|
||||
assert.NotEqual(t, database.NoVulnerabilityNotificationPage, nextPage)
|
||||
assert.Nil(t, filledNotification.OldVulnerability)
|
||||
|
||||
if assert.NotNil(t, filledNotification.NewVulnerability) {
|
||||
assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name)
|
||||
assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 2)
|
||||
}
|
||||
}
|
||||
|
||||
// Get second page.
|
||||
filledNotification, nextPage, err = datastore.GetNotification(notification.Name, 2, nextPage)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Equal(t, database.NoVulnerabilityNotificationPage, nextPage)
|
||||
assert.Nil(t, filledNotification.OldVulnerability)
|
||||
|
||||
if assert.NotNil(t, filledNotification.NewVulnerability) {
|
||||
assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name)
|
||||
assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete notification.
|
||||
assert.Nil(t, datastore.DeleteNotification(notification.Name))
|
||||
|
||||
_, err = datastore.GetAvailableNotification(time.Millisecond)
|
||||
assert.Equal(t, commonerr.ErrNotFound, err)
|
||||
}
|
||||
|
||||
// Update a vulnerability and ensure that the old/new vulnerabilities are correct.
|
||||
v1b := v1
|
||||
v1b.Severity = database.HighSeverity
|
||||
v1b.FixedIn = []database.FeatureVersion{
|
||||
{
|
||||
Feature: f1,
|
||||
Version: versionfmt.MinVersion,
|
||||
},
|
||||
{
|
||||
Feature: f2,
|
||||
Version: versionfmt.MaxVersion,
|
||||
},
|
||||
}
|
||||
|
||||
if assert.Nil(t, datastore.insertVulnerability(v1b, false, true)) {
|
||||
notification, err = datastore.GetAvailableNotification(time.Second)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEmpty(t, notification.Name)
|
||||
|
||||
if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) {
|
||||
filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage)
|
||||
if assert.Nil(t, err) {
|
||||
if assert.NotNil(t, filledNotification.OldVulnerability) {
|
||||
assert.Equal(t, v1.Name, filledNotification.OldVulnerability.Name)
|
||||
assert.Equal(t, v1.Severity, filledNotification.OldVulnerability.Severity)
|
||||
assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 2)
|
||||
}
|
||||
|
||||
if assert.NotNil(t, filledNotification.NewVulnerability) {
|
||||
assert.Equal(t, v1b.Name, filledNotification.NewVulnerability.Name)
|
||||
assert.Equal(t, v1b.Severity, filledNotification.NewVulnerability.Severity)
|
||||
assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1)
|
||||
}
|
||||
|
||||
assert.Equal(t, -1, nextPage.NewVulnerability)
|
||||
if assert.Nil(t, err) && assert.True(t, ok) {
|
||||
assert.Equal(t, "test", noti.Name)
|
||||
if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
|
||||
oldPageNum, err := decryptPage(noti.Old.Current, tx.paginationKey)
|
||||
if !assert.Nil(t, err) {
|
||||
assert.FailNow(t, "")
|
||||
}
|
||||
|
||||
assert.Nil(t, datastore.DeleteNotification(notification.Name))
|
||||
assert.Equal(t, int64(0), oldPageNum.StartID)
|
||||
newPageNum, err := decryptPage(noti.New.Current, tx.paginationKey)
|
||||
if !assert.Nil(t, err) {
|
||||
assert.FailNow(t, "")
|
||||
}
|
||||
newPageNextNum, err := decryptPage(noti.New.Next, tx.paginationKey)
|
||||
if !assert.Nil(t, err) {
|
||||
assert.FailNow(t, "")
|
||||
}
|
||||
assert.Equal(t, int64(0), newPageNum.StartID)
|
||||
assert.Equal(t, int64(4), newPageNextNum.StartID)
|
||||
|
||||
noti.Old.Current = ""
|
||||
noti.New.Current = ""
|
||||
noti.New.Next = ""
|
||||
assert.Equal(t, oldPage, *noti.Old)
|
||||
assert.Equal(t, newPage1, *noti.New)
|
||||
}
|
||||
}
|
||||
|
||||
// Delete a vulnerability and verify the notification.
|
||||
if assert.Nil(t, datastore.DeleteVulnerability(v1b.Namespace.Name, v1b.Name)) {
|
||||
notification, err = datastore.GetAvailableNotification(time.Second)
|
||||
assert.Nil(t, err)
|
||||
assert.NotEmpty(t, notification.Name)
|
||||
page1, err := encryptPage(idPageNumber{0}, tx.paginationKey)
|
||||
if !assert.Nil(t, err) {
|
||||
assert.FailNow(t, "")
|
||||
}
|
||||
|
||||
if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) {
|
||||
filledNotification, _, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage)
|
||||
if assert.Nil(t, err) {
|
||||
assert.Nil(t, filledNotification.NewVulnerability)
|
||||
page2, err := encryptPage(idPageNumber{4}, tx.paginationKey)
|
||||
if !assert.Nil(t, err) {
|
||||
assert.FailNow(t, "")
|
||||
}
|
||||
|
||||
if assert.NotNil(t, filledNotification.OldVulnerability) {
|
||||
assert.Equal(t, v1b.Name, filledNotification.OldVulnerability.Name)
|
||||
assert.Equal(t, v1b.Severity, filledNotification.OldVulnerability.Severity)
|
||||
assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 1)
|
||||
}
|
||||
noti, ok, err = tx.FindVulnerabilityNotification("test", 1, page1, page2)
|
||||
if assert.Nil(t, err) && assert.True(t, ok) {
|
||||
assert.Equal(t, "test", noti.Name)
|
||||
if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
|
||||
oldCurrentPage, err := decryptPage(noti.Old.Current, tx.paginationKey)
|
||||
if !assert.Nil(t, err) {
|
||||
assert.FailNow(t, "")
|
||||
}
|
||||
|
||||
assert.Nil(t, datastore.DeleteNotification(notification.Name))
|
||||
newCurrentPage, err := decryptPage(noti.New.Current, tx.paginationKey)
|
||||
if !assert.Nil(t, err) {
|
||||
assert.FailNow(t, "")
|
||||
}
|
||||
|
||||
assert.Equal(t, int64(0), oldCurrentPage.StartID)
|
||||
assert.Equal(t, int64(4), newCurrentPage.StartID)
|
||||
noti.Old.Current = ""
|
||||
noti.New.Current = ""
|
||||
assert.Equal(t, oldPage, *noti.Old)
|
||||
assert.Equal(t, newPage2, *noti.New)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertVulnerabilityNotifications(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "InsertVulnerabilityNotifications", true)
|
||||
|
||||
n1 := database.VulnerabilityNotification{}
|
||||
n3 := database.VulnerabilityNotification{
|
||||
NotificationHook: database.NotificationHook{
|
||||
Name: "random name",
|
||||
Created: time.Now(),
|
||||
},
|
||||
Old: nil,
|
||||
New: &database.Vulnerability{},
|
||||
}
|
||||
n4 := database.VulnerabilityNotification{
|
||||
NotificationHook: database.NotificationHook{
|
||||
Name: "random name",
|
||||
Created: time.Now(),
|
||||
},
|
||||
Old: nil,
|
||||
New: &database.Vulnerability{
|
||||
Name: "CVE-OPENSSL-1-DEB7",
|
||||
Namespace: database.Namespace{
|
||||
Name: "debian:7",
|
||||
VersionFormat: "dpkg",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// invalid case
|
||||
err := tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n1})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// invalid case: unknown vulnerability
|
||||
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n3})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
// invalid case: duplicated input notification
|
||||
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4, n4})
|
||||
assert.NotNil(t, err)
|
||||
tx = restartSession(t, datastore, tx, false)
|
||||
|
||||
// valid case
|
||||
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4})
|
||||
assert.Nil(t, err)
|
||||
// invalid case: notification is already in database
|
||||
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
closeTest(t, datastore, tx)
|
||||
}
|
||||
|
||||
func TestFindNewNotification(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "FindNewNotification", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
noti, ok, err := tx.FindNewNotification(time.Now())
|
||||
if assert.Nil(t, err) && assert.True(t, ok) {
|
||||
assert.Equal(t, "test", noti.Name)
|
||||
assert.Equal(t, time.Time{}, noti.Notified)
|
||||
assert.Equal(t, time.Time{}, noti.Created)
|
||||
assert.Equal(t, time.Time{}, noti.Deleted)
|
||||
}
|
||||
|
||||
// can't find the notified
|
||||
assert.Nil(t, tx.MarkNotificationNotified("test"))
|
||||
// if the notified time is before
|
||||
noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(10 * time.Second)))
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
// can find the notified after a period of time
|
||||
noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000)))
|
||||
if assert.Nil(t, err) && assert.True(t, ok) {
|
||||
assert.Equal(t, "test", noti.Name)
|
||||
assert.NotEqual(t, time.Time{}, noti.Notified)
|
||||
assert.Equal(t, time.Time{}, noti.Created)
|
||||
assert.Equal(t, time.Time{}, noti.Deleted)
|
||||
}
|
||||
|
||||
assert.Nil(t, tx.DeleteNotification("test"))
|
||||
// can't find in any time
|
||||
noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(1000)))
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
|
||||
noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000)))
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestMarkNotificationNotified(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "MarkNotificationNotified", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
// invalid case: notification doesn't exist
|
||||
assert.NotNil(t, tx.MarkNotificationNotified("non-existing"))
|
||||
// valid case
|
||||
assert.Nil(t, tx.MarkNotificationNotified("test"))
|
||||
// valid case
|
||||
assert.Nil(t, tx.MarkNotificationNotified("test"))
|
||||
}
|
||||
|
||||
func TestDeleteNotification(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "DeleteNotification", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
// invalid case: notification doesn't exist
|
||||
assert.NotNil(t, tx.DeleteNotification("non-existing"))
|
||||
// valid case
|
||||
assert.Nil(t, tx.DeleteNotification("test"))
|
||||
// invalid case: notification is already deleted
|
||||
assert.NotNil(t, tx.DeleteNotification("test"))
|
||||
}
|
||||
|
@ -31,6 +31,7 @@ import (
|
||||
"github.com/remind101/migrate"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/coreos/clair/api/token"
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/database/pgsql/migrations"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
@ -59,7 +60,7 @@ var (
|
||||
|
||||
promConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{
|
||||
Name: "clair_pgsql_concurrent_lock_vafv_total",
|
||||
Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_FeatureVersion lock.",
|
||||
Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_Feature lock.",
|
||||
})
|
||||
)
|
||||
|
||||
@ -73,17 +74,65 @@ func init() {
|
||||
database.Register("pgsql", openDatabase)
|
||||
}
|
||||
|
||||
type Queryer interface {
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
// pgSessionCache is the session's cache, which holds the pgSQL's cache and the
|
||||
// individual session's cache. Only when session.Commit is called, all the
|
||||
// changes to pgSQL cache will be applied.
|
||||
type pgSessionCache struct {
|
||||
c *lru.ARCCache
|
||||
}
|
||||
|
||||
type pgSQL struct {
|
||||
*sql.DB
|
||||
|
||||
cache *lru.ARCCache
|
||||
config Config
|
||||
}
|
||||
|
||||
type pgSession struct {
|
||||
*sql.Tx
|
||||
|
||||
paginationKey string
|
||||
}
|
||||
|
||||
type idPageNumber struct {
|
||||
// StartID is an implementation detail for paginating by an ID required to
|
||||
// be unique to every ancestry and always increasing.
|
||||
//
|
||||
// StartID is used to search for ancestry with ID >= StartID
|
||||
StartID int64
|
||||
}
|
||||
|
||||
func encryptPage(page idPageNumber, paginationKey string) (result database.PageNumber, err error) {
|
||||
resultBytes, err := token.Marshal(page, paginationKey)
|
||||
if err != nil {
|
||||
return result, err
|
||||
}
|
||||
result = database.PageNumber(resultBytes)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func decryptPage(page database.PageNumber, paginationKey string) (result idPageNumber, err error) {
|
||||
err = token.Unmarshal(string(page), paginationKey, &result)
|
||||
return
|
||||
}
|
||||
|
||||
// Begin initiates a transaction to database. The expected transaction isolation
|
||||
// level in this implementation is "Read Committed".
|
||||
func (pgSQL *pgSQL) Begin() (database.Session, error) {
|
||||
tx, err := pgSQL.DB.Begin()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgSession{
|
||||
Tx: tx,
|
||||
paginationKey: pgSQL.config.PaginationKey,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) Commit() error {
|
||||
return tx.Tx.Commit()
|
||||
}
|
||||
|
||||
// Close closes the database and destroys if ManageDatabaseLifecycle has been specified in
|
||||
// the configuration.
|
||||
func (pgSQL *pgSQL) Close() {
|
||||
@ -109,6 +158,7 @@ type Config struct {
|
||||
|
||||
ManageDatabaseLifecycle bool
|
||||
FixturePath string
|
||||
PaginationKey string
|
||||
}
|
||||
|
||||
// openDatabase opens a PostgresSQL-backed Datastore using the given
|
||||
@ -134,6 +184,10 @@ func openDatabase(registrableComponentConfig database.RegistrableComponentConfig
|
||||
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
|
||||
}
|
||||
|
||||
if pg.config.PaginationKey == "" {
|
||||
panic("pagination key should be given")
|
||||
}
|
||||
|
||||
dbName, pgSourceURL, err := parseConnectionString(pg.config.Source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -179,7 +233,7 @@ func openDatabase(registrableComponentConfig database.RegistrableComponentConfig
|
||||
_, err = pg.DB.Exec(string(d))
|
||||
if err != nil {
|
||||
pg.Close()
|
||||
return nil, fmt.Errorf("pgsql: an error occured while importing fixtures: %v", err)
|
||||
return nil, fmt.Errorf("pgsql: an error occurred while importing fixtures: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -217,7 +271,7 @@ func migrateDatabase(db *sql.DB) error {
|
||||
|
||||
err := migrate.NewPostgresMigrator(db).Exec(migrate.Up, migrations.Migrations...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pgsql: an error occured while running migrations: %v", err)
|
||||
return fmt.Errorf("pgsql: an error occurred while running migrations: %v", err)
|
||||
}
|
||||
|
||||
log.Info("database migration ran successfully")
|
||||
@ -271,7 +325,8 @@ func dropDatabase(source, dbName string) error {
|
||||
}
|
||||
|
||||
// handleError logs an error with an extra description and masks the error if it's an SQL one.
|
||||
// This ensures we never return plain SQL errors and leak anything.
|
||||
// The function ensures we never return plain SQL errors and leak anything.
|
||||
// The function should be used for every database query error.
|
||||
func handleError(desc string, err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
@ -297,6 +352,11 @@ func isErrUniqueViolation(err error) bool {
|
||||
return ok && pqErr.Code == "23505"
|
||||
}
|
||||
|
||||
// observeQueryTime computes the time elapsed since `start` to represent the
|
||||
// query time.
|
||||
// 1. `query` is a pgSession function name.
|
||||
// 2. `subquery` is a specific query or a batched query.
|
||||
// 3. `start` is the time right before query is executed.
|
||||
func observeQueryTime(query, subquery string, start time.Time) {
|
||||
promQueryDurationMilliseconds.
|
||||
WithLabelValues(query, subquery).
|
||||
|
@ -15,27 +15,193 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
fernet "github.com/fernet/fernet-go"
|
||||
"github.com/pborman/uuid"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
yaml "gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
)
|
||||
|
||||
func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) {
|
||||
ds, err := openDatabase(generateTestConfig(testName, loadFixture))
|
||||
var (
|
||||
withFixtureName, withoutFixtureName string
|
||||
)
|
||||
|
||||
func genTemplateDatabase(name string, loadFixture bool) (sourceURL string, dbName string) {
|
||||
config := generateTestConfig(name, loadFixture, false)
|
||||
source := config.Options["source"].(string)
|
||||
name, url, err := parseConnectionString(source)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
fixturePath := config.Options["fixturepath"].(string)
|
||||
|
||||
if err := createDatabase(url, name); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// migration and fixture
|
||||
db, err := sql.Open("postgres", source)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Verify database state.
|
||||
if err := db.Ping(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// Run migrations.
|
||||
if err := migrateDatabase(db); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if loadFixture {
|
||||
log.Info("pgsql: loading fixtures")
|
||||
|
||||
d, err := ioutil.ReadFile(fixturePath)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
_, err = db.Exec(string(d))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
db.Exec("UPDATE pg_database SET datistemplate=True WHERE datname=$1", name)
|
||||
db.Close()
|
||||
|
||||
log.Info("Generated Template database ", name)
|
||||
return url, name
|
||||
}
|
||||
|
||||
func dropTemplateDatabase(url string, name string) {
|
||||
db, err := sql.Open("postgres", url)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if _, err := db.Exec("UPDATE pg_database SET datistemplate=False WHERE datname=$1", name); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := db.Close(); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if err := dropDatabase(url, name); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
}
|
||||
func TestMain(m *testing.M) {
|
||||
fURL, fName := genTemplateDatabase("fixture", true)
|
||||
nfURL, nfName := genTemplateDatabase("nonfixture", false)
|
||||
|
||||
withFixtureName = fName
|
||||
withoutFixtureName = nfName
|
||||
|
||||
m.Run()
|
||||
|
||||
dropTemplateDatabase(fURL, fName)
|
||||
dropTemplateDatabase(nfURL, nfName)
|
||||
}
|
||||
|
||||
func openCopiedDatabase(testConfig database.RegistrableComponentConfig, fixture bool) (database.Datastore, error) {
|
||||
var fixtureName string
|
||||
if fixture {
|
||||
fixtureName = withFixtureName
|
||||
} else {
|
||||
fixtureName = withoutFixtureName
|
||||
}
|
||||
|
||||
// copy the database into new database
|
||||
var pg pgSQL
|
||||
// Parse configuration.
|
||||
pg.config = Config{
|
||||
CacheSize: 16384,
|
||||
}
|
||||
|
||||
bytes, err := yaml.Marshal(testConfig.Options)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
|
||||
}
|
||||
err = yaml.Unmarshal(bytes, &pg.config)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pgsql: could not load configuration: %v", err)
|
||||
}
|
||||
|
||||
dbName, pgSourceURL, err := parseConnectionString(pg.config.Source)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
datastore := ds.(*pgSQL)
|
||||
|
||||
// Create database.
|
||||
if pg.config.ManageDatabaseLifecycle {
|
||||
if err = copyDatabase(pgSourceURL, dbName, fixtureName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Open database.
|
||||
pg.DB, err = sql.Open("postgres", pg.config.Source)
|
||||
fmt.Println("database", pg.config.Source)
|
||||
if err != nil {
|
||||
pg.Close()
|
||||
return nil, fmt.Errorf("pgsql: could not open database: %v", err)
|
||||
}
|
||||
|
||||
return &pg, nil
|
||||
}
|
||||
|
||||
// copyDatabase creates a new database with
|
||||
func copyDatabase(url, name string, templateName string) error {
|
||||
// Open database.
|
||||
db, err := sql.Open("postgres", url)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pgsql: could not open 'postgres' database for creation: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
// Create database with copy
|
||||
_, err = db.Exec("CREATE DATABASE " + name + " WITH TEMPLATE " + templateName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("pgsql: could not create database: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func openDatabaseForTest(testName string, loadFixture bool) (*pgSQL, error) {
|
||||
var (
|
||||
db database.Datastore
|
||||
err error
|
||||
testConfig = generateTestConfig(testName, loadFixture, true)
|
||||
)
|
||||
|
||||
db, err = openCopiedDatabase(testConfig, loadFixture)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
datastore := db.(*pgSQL)
|
||||
return datastore, nil
|
||||
}
|
||||
|
||||
func generateTestConfig(testName string, loadFixture bool) database.RegistrableComponentConfig {
|
||||
func generateTestConfig(testName string, loadFixture bool, manageLife bool) database.RegistrableComponentConfig {
|
||||
dbName := "test_" + strings.ToLower(testName) + "_" + strings.Replace(uuid.New(), "-", "_", -1)
|
||||
|
||||
var fixturePath string
|
||||
@ -49,12 +215,60 @@ func generateTestConfig(testName string, loadFixture bool) database.RegistrableC
|
||||
source = fmt.Sprintf(sourceEnv, dbName)
|
||||
}
|
||||
|
||||
var key fernet.Key
|
||||
if err := key.Generate(); err != nil {
|
||||
panic("failed to generate pagination key" + err.Error())
|
||||
}
|
||||
|
||||
return database.RegistrableComponentConfig{
|
||||
Options: map[string]interface{}{
|
||||
"source": source,
|
||||
"cachesize": 0,
|
||||
"managedatabaselifecycle": true,
|
||||
"managedatabaselifecycle": manageLife,
|
||||
"fixturepath": fixturePath,
|
||||
"paginationkey": key.Encode(),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func closeTest(t *testing.T, store database.Datastore, session database.Session) {
|
||||
err := session.Rollback()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
store.Close()
|
||||
}
|
||||
|
||||
func openSessionForTest(t *testing.T, name string, loadFixture bool) (*pgSQL, *pgSession) {
|
||||
store, err := openDatabaseForTest(name, loadFixture)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
t.FailNow()
|
||||
}
|
||||
tx, err := store.Begin()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
t.FailNow()
|
||||
}
|
||||
return store, tx.(*pgSession)
|
||||
}
|
||||
|
||||
func restartSession(t *testing.T, datastore *pgSQL, tx *pgSession, commit bool) *pgSession {
|
||||
var err error
|
||||
if !commit {
|
||||
err = tx.Rollback()
|
||||
} else {
|
||||
err = tx.Commit()
|
||||
}
|
||||
|
||||
if assert.Nil(t, err) {
|
||||
session, err := datastore.Begin()
|
||||
if assert.Nil(t, err) {
|
||||
return session.(*pgSession)
|
||||
}
|
||||
}
|
||||
t.FailNow()
|
||||
return nil
|
||||
}
|
||||
|
@ -14,185 +14,159 @@
|
||||
|
||||
package pgsql
|
||||
|
||||
import "strconv"
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
const (
|
||||
lockVulnerabilityAffects = `LOCK Vulnerability_Affects_FeatureVersion IN SHARE ROW EXCLUSIVE MODE`
|
||||
disableHashJoin = `SET LOCAL enable_hashjoin = off`
|
||||
disableMergeJoin = `SET LOCAL enable_mergejoin = off`
|
||||
lockVulnerabilityAffects = `LOCK vulnerability_affected_namespaced_feature IN SHARE ROW EXCLUSIVE MODE`
|
||||
|
||||
// keyvalue.go
|
||||
updateKeyValue = `UPDATE KeyValue SET value = $1 WHERE key = $2`
|
||||
insertKeyValue = `INSERT INTO KeyValue(key, value) VALUES($1, $2)`
|
||||
searchKeyValue = `SELECT value FROM KeyValue WHERE key = $1`
|
||||
upsertKeyValue = `
|
||||
INSERT INTO KeyValue(key, value)
|
||||
VALUES ($1, $2)
|
||||
ON CONFLICT ON CONSTRAINT keyvalue_key_key
|
||||
DO UPDATE SET key=$1, value=$2`
|
||||
|
||||
// namespace.go
|
||||
soiNamespace = `
|
||||
WITH new_namespace AS (
|
||||
INSERT INTO Namespace(name, version_format)
|
||||
SELECT CAST($1 AS VARCHAR), CAST($2 AS VARCHAR)
|
||||
WHERE NOT EXISTS (SELECT name FROM Namespace WHERE name = $1)
|
||||
RETURNING id
|
||||
)
|
||||
SELECT id FROM Namespace WHERE name = $1
|
||||
UNION
|
||||
SELECT id FROM new_namespace`
|
||||
|
||||
searchNamespace = `SELECT id FROM Namespace WHERE name = $1`
|
||||
listNamespace = `SELECT id, name, version_format FROM Namespace`
|
||||
searchNamespaceID = `SELECT id FROM Namespace WHERE name = $1 AND version_format = $2`
|
||||
|
||||
// feature.go
|
||||
soiFeature = `
|
||||
WITH new_feature AS (
|
||||
INSERT INTO Feature(name, namespace_id)
|
||||
SELECT CAST($1 AS VARCHAR), CAST($2 AS INTEGER)
|
||||
WHERE NOT EXISTS (SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2)
|
||||
soiNamespacedFeature = `
|
||||
WITH new_feature_ns AS (
|
||||
INSERT INTO namespaced_feature(feature_id, namespace_id)
|
||||
SELECT CAST ($1 AS INTEGER), CAST ($2 AS INTEGER)
|
||||
WHERE NOT EXISTS ( SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2)
|
||||
RETURNING id
|
||||
)
|
||||
SELECT id FROM Feature WHERE name = $1 AND namespace_id = $2
|
||||
SELECT id FROM namespaced_feature WHERE namespaced_feature.feature_id = $1 AND namespaced_feature.namespace_id = $2
|
||||
UNION
|
||||
SELECT id FROM new_feature`
|
||||
SELECT id FROM new_feature_ns`
|
||||
|
||||
searchFeatureVersion = `
|
||||
SELECT id FROM FeatureVersion WHERE feature_id = $1 AND version = $2`
|
||||
searchPotentialAffectingVulneraibilities = `
|
||||
SELECT nf.id, v.id, vaf.affected_version, vaf.id
|
||||
FROM vulnerability_affected_feature AS vaf, vulnerability AS v,
|
||||
namespaced_feature AS nf, feature AS f
|
||||
WHERE nf.id = ANY($1)
|
||||
AND nf.feature_id = f.id
|
||||
AND nf.namespace_id = v.namespace_id
|
||||
AND vaf.feature_name = f.name
|
||||
AND vaf.vulnerability_id = v.id
|
||||
AND v.deleted_at IS NULL`
|
||||
|
||||
soiFeatureVersion = `
|
||||
WITH new_featureversion AS (
|
||||
INSERT INTO FeatureVersion(feature_id, version)
|
||||
SELECT CAST($1 AS INTEGER), CAST($2 AS VARCHAR)
|
||||
WHERE NOT EXISTS (SELECT id FROM FeatureVersion WHERE feature_id = $1 AND version = $2)
|
||||
RETURNING id
|
||||
)
|
||||
SELECT false, id FROM FeatureVersion WHERE feature_id = $1 AND version = $2
|
||||
UNION
|
||||
SELECT true, id FROM new_featureversion`
|
||||
|
||||
searchVulnerabilityFixedInFeature = `
|
||||
SELECT id, vulnerability_id, version FROM Vulnerability_FixedIn_Feature
|
||||
WHERE feature_id = $1`
|
||||
|
||||
insertVulnerabilityAffectsFeatureVersion = `
|
||||
INSERT INTO Vulnerability_Affects_FeatureVersion(vulnerability_id, featureversion_id, fixedin_id)
|
||||
VALUES($1, $2, $3)`
|
||||
searchNamespacedFeaturesVulnerabilities = `
|
||||
SELECT vanf.namespaced_feature_id, v.name, v.description, v.link,
|
||||
v.severity, v.metadata, vaf.fixedin, n.name, n.version_format
|
||||
FROM vulnerability_affected_namespaced_feature AS vanf,
|
||||
Vulnerability AS v,
|
||||
vulnerability_affected_feature AS vaf,
|
||||
namespace AS n
|
||||
WHERE vanf.namespaced_feature_id = ANY($1)
|
||||
AND vaf.id = vanf.added_by
|
||||
AND v.id = vanf.vulnerability_id
|
||||
AND n.id = v.namespace_id
|
||||
AND v.deleted_at IS NULL`
|
||||
|
||||
// layer.go
|
||||
searchLayer = `
|
||||
SELECT l.id, l.name, l.engineversion, p.id, p.name
|
||||
FROM Layer l
|
||||
LEFT JOIN Layer p ON l.parent_id = p.id
|
||||
WHERE l.name = $1;`
|
||||
searchLayerIDs = `SELECT id, hash FROM layer WHERE hash = ANY($1);`
|
||||
|
||||
searchLayerNamespace = `
|
||||
SELECT n.id, n.name, n.version_format
|
||||
FROM Namespace n
|
||||
JOIN Layer_Namespace lns ON lns.namespace_id = n.id
|
||||
WHERE lns.layer_id = $1`
|
||||
searchLayerFeatures = `
|
||||
SELECT feature.Name, feature.Version, feature.version_format
|
||||
FROM feature, layer_feature
|
||||
WHERE layer_feature.layer_id = $1
|
||||
AND layer_feature.feature_id = feature.id`
|
||||
|
||||
searchLayerFeatureVersion = `
|
||||
WITH RECURSIVE layer_tree(id, name, parent_id, depth, path, cycle) AS(
|
||||
SELECT l.id, l.name, l.parent_id, 1, ARRAY[l.id], false
|
||||
FROM Layer l
|
||||
WHERE l.id = $1
|
||||
UNION ALL
|
||||
SELECT l.id, l.name, l.parent_id, lt.depth + 1, path || l.id, l.id = ANY(path)
|
||||
FROM Layer l, layer_tree lt
|
||||
WHERE l.id = lt.parent_id
|
||||
)
|
||||
SELECT ldf.featureversion_id, ldf.modification, fn.id, fn.name, fn.version_format, f.id, f.name, fv.id, fv.version, ltree.id, ltree.name
|
||||
FROM Layer_diff_FeatureVersion ldf
|
||||
JOIN (
|
||||
SELECT row_number() over (ORDER BY depth DESC), id, name FROM layer_tree
|
||||
) AS ltree (ordering, id, name) ON ldf.layer_id = ltree.id, FeatureVersion fv, Feature f, Namespace fn
|
||||
WHERE ldf.featureversion_id = fv.id AND fv.feature_id = f.id AND f.namespace_id = fn.id
|
||||
ORDER BY ltree.ordering`
|
||||
searchLayerNamespaces = `
|
||||
SELECT namespace.Name, namespace.version_format
|
||||
FROM namespace, layer_namespace
|
||||
WHERE layer_namespace.layer_id = $1
|
||||
AND layer_namespace.namespace_id = namespace.id`
|
||||
|
||||
searchFeatureVersionVulnerability = `
|
||||
SELECT vafv.featureversion_id, v.id, v.name, v.description, v.link, v.severity, v.metadata,
|
||||
vn.name, vn.version_format, vfif.version
|
||||
FROM Vulnerability_Affects_FeatureVersion vafv, Vulnerability v,
|
||||
Namespace vn, Vulnerability_FixedIn_Feature vfif
|
||||
WHERE vafv.featureversion_id = ANY($1::integer[])
|
||||
AND vfif.vulnerability_id = v.id
|
||||
AND vafv.fixedin_id = vfif.id
|
||||
AND v.namespace_id = vn.id
|
||||
AND v.deleted_at IS NULL`
|
||||
|
||||
insertLayer = `
|
||||
INSERT INTO Layer(name, engineversion, parent_id, created_at)
|
||||
VALUES($1, $2, $3, CURRENT_TIMESTAMP)
|
||||
RETURNING id`
|
||||
|
||||
insertLayerNamespace = `INSERT INTO Layer_Namespace(layer_id, namespace_id) VALUES($1, $2)`
|
||||
removeLayerNamespace = `DELETE FROM Layer_Namespace WHERE layer_id = $1`
|
||||
|
||||
updateLayer = `UPDATE LAYER SET engineversion = $2 WHERE id = $1`
|
||||
|
||||
removeLayerDiffFeatureVersion = `
|
||||
DELETE FROM Layer_diff_FeatureVersion
|
||||
WHERE layer_id = $1`
|
||||
|
||||
insertLayerDiffFeatureVersion = `
|
||||
INSERT INTO Layer_diff_FeatureVersion(layer_id, featureversion_id, modification)
|
||||
SELECT $1, fv.id, $2
|
||||
FROM FeatureVersion fv
|
||||
WHERE fv.id = ANY($3::integer[])`
|
||||
|
||||
removeLayer = `DELETE FROM Layer WHERE name = $1`
|
||||
searchLayer = `SELECT id FROM layer WHERE hash = $1`
|
||||
searchLayerDetectors = `SELECT detector FROM layer_detector WHERE layer_id = $1`
|
||||
searchLayerListers = `SELECT lister FROM layer_lister WHERE layer_id = $1`
|
||||
|
||||
// lock.go
|
||||
insertLock = `INSERT INTO Lock(name, owner, until) VALUES($1, $2, $3)`
|
||||
soiLock = `INSERT INTO lock(name, owner, until) VALUES ($1, $2, $3)`
|
||||
|
||||
searchLock = `SELECT owner, until FROM Lock WHERE name = $1`
|
||||
updateLock = `UPDATE Lock SET until = $3 WHERE name = $1 AND owner = $2`
|
||||
removeLock = `DELETE FROM Lock WHERE name = $1 AND owner = $2`
|
||||
removeLockExpired = `DELETE FROM LOCK WHERE until < CURRENT_TIMESTAMP`
|
||||
|
||||
// vulnerability.go
|
||||
searchVulnerabilityBase = `
|
||||
SELECT v.id, v.name, n.id, n.name, n.version_format, v.description, v.link, v.severity, v.metadata
|
||||
FROM Vulnerability v JOIN Namespace n ON v.namespace_id = n.id`
|
||||
searchVulnerabilityForUpdate = ` FOR UPDATE OF v`
|
||||
searchVulnerabilityByNamespaceAndName = ` WHERE n.name = $1 AND v.name = $2 AND v.deleted_at IS NULL`
|
||||
searchVulnerabilityByID = ` WHERE v.id = $1`
|
||||
searchVulnerabilityByNamespace = ` WHERE n.name = $1 AND v.deleted_at IS NULL
|
||||
AND v.id >= $2
|
||||
ORDER BY v.id
|
||||
LIMIT $3`
|
||||
searchVulnerability = `
|
||||
SELECT v.id, v.description, v.link, v.severity, v.metadata, n.version_format
|
||||
FROM vulnerability AS v, namespace AS n
|
||||
WHERE v.namespace_id = n.id
|
||||
AND v.name = $1
|
||||
AND n.name = $2
|
||||
AND v.deleted_at IS NULL
|
||||
`
|
||||
|
||||
searchVulnerabilityFixedIn = `
|
||||
SELECT vfif.version, f.id, f.Name
|
||||
FROM Vulnerability_FixedIn_Feature vfif JOIN Feature f ON vfif.feature_id = f.id
|
||||
WHERE vfif.vulnerability_id = $1`
|
||||
insertVulnerabilityAffected = `
|
||||
INSERT INTO vulnerability_affected_feature(vulnerability_id, feature_name, affected_version, fixedin)
|
||||
VALUES ($1, $2, $3, $4)
|
||||
RETURNING ID
|
||||
`
|
||||
|
||||
searchVulnerabilityAffected = `
|
||||
SELECT vulnerability_id, feature_name, affected_version, fixedin
|
||||
FROM vulnerability_affected_feature
|
||||
WHERE vulnerability_id = ANY($1)
|
||||
`
|
||||
|
||||
searchVulnerabilityByID = `
|
||||
SELECT v.name, v.description, v.link, v.severity, v.metadata, n.name, n.version_format
|
||||
FROM vulnerability AS v, namespace AS n
|
||||
WHERE v.namespace_id = n.id
|
||||
AND v.id = $1`
|
||||
|
||||
searchVulnerabilityPotentialAffected = `
|
||||
WITH req AS (
|
||||
SELECT vaf.id AS vaf_id, n.id AS n_id, vaf.feature_name AS name, v.id AS vulnerability_id
|
||||
FROM vulnerability_affected_feature AS vaf,
|
||||
vulnerability AS v,
|
||||
namespace AS n
|
||||
WHERE vaf.vulnerability_id = ANY($1)
|
||||
AND v.id = vaf.vulnerability_id
|
||||
AND n.id = v.namespace_id
|
||||
)
|
||||
SELECT req.vulnerability_id, nf.id, f.version, req.vaf_id AS added_by
|
||||
FROM feature AS f, namespaced_feature AS nf, req
|
||||
WHERE f.name = req.name
|
||||
AND nf.namespace_id = req.n_id
|
||||
AND nf.feature_id = f.id`
|
||||
|
||||
insertVulnerabilityAffectedNamespacedFeature = `
|
||||
INSERT INTO vulnerability_affected_namespaced_feature(vulnerability_id, namespaced_feature_id, added_by)
|
||||
VALUES ($1, $2, $3)`
|
||||
|
||||
insertVulnerability = `
|
||||
INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at)
|
||||
VALUES($1, $2, $3, $4, $5, $6, CURRENT_TIMESTAMP)
|
||||
RETURNING id`
|
||||
|
||||
soiVulnerabilityFixedInFeature = `
|
||||
WITH new_fixedinfeature AS (
|
||||
INSERT INTO Vulnerability_FixedIn_Feature(vulnerability_id, feature_id, version)
|
||||
SELECT CAST($1 AS INTEGER), CAST($2 AS INTEGER), CAST($3 AS VARCHAR)
|
||||
WHERE NOT EXISTS (SELECT id FROM Vulnerability_FixedIn_Feature WHERE vulnerability_id = $1 AND feature_id = $2)
|
||||
RETURNING id
|
||||
WITH ns AS (
|
||||
SELECT id FROM namespace WHERE name = $6 AND version_format = $7
|
||||
)
|
||||
SELECT false, id FROM Vulnerability_FixedIn_Feature WHERE vulnerability_id = $1 AND feature_id = $2
|
||||
UNION
|
||||
SELECT true, id FROM new_fixedinfeature`
|
||||
|
||||
searchFeatureVersionByFeature = `SELECT id, version FROM FeatureVersion WHERE feature_id = $1`
|
||||
INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at)
|
||||
VALUES((SELECT id FROM ns), $1, $2, $3, $4, $5, CURRENT_TIMESTAMP)
|
||||
RETURNING id`
|
||||
|
||||
removeVulnerability = `
|
||||
UPDATE Vulnerability
|
||||
SET deleted_at = CURRENT_TIMESTAMP
|
||||
WHERE namespace_id = (SELECT id FROM Namespace WHERE name = $1)
|
||||
AND name = $2
|
||||
AND deleted_at IS NULL
|
||||
RETURNING id`
|
||||
SET deleted_at = CURRENT_TIMESTAMP
|
||||
WHERE namespace_id = (SELECT id FROM Namespace WHERE name = $1)
|
||||
AND name = $2
|
||||
AND deleted_at IS NULL
|
||||
RETURNING id`
|
||||
|
||||
// notification.go
|
||||
insertNotification = `
|
||||
INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id)
|
||||
VALUES($1, CURRENT_TIMESTAMP, $2, $3)`
|
||||
VALUES ($1, $2, $3, $4)`
|
||||
|
||||
updatedNotificationNotified = `
|
||||
UPDATE Vulnerability_Notification
|
||||
@ -202,10 +176,10 @@ const (
|
||||
removeNotification = `
|
||||
UPDATE Vulnerability_Notification
|
||||
SET deleted_at = CURRENT_TIMESTAMP
|
||||
WHERE name = $1`
|
||||
WHERE name = $1 AND deleted_at IS NULL`
|
||||
|
||||
searchNotificationAvailable = `
|
||||
SELECT id, name, created_at, notified_at, deleted_at
|
||||
SELECT name, created_at, notified_at, deleted_at
|
||||
FROM Vulnerability_Notification
|
||||
WHERE (notified_at IS NULL OR notified_at < $1)
|
||||
AND deleted_at IS NULL
|
||||
@ -214,43 +188,231 @@ const (
|
||||
LIMIT 1`
|
||||
|
||||
searchNotification = `
|
||||
SELECT id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id
|
||||
SELECT created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id
|
||||
FROM Vulnerability_Notification
|
||||
WHERE name = $1`
|
||||
|
||||
searchNotificationLayerIntroducingVulnerability = `
|
||||
WITH LDFV AS (
|
||||
SELECT DISTINCT ldfv.layer_id
|
||||
FROM Vulnerability_Affects_FeatureVersion vafv, FeatureVersion fv, Layer_diff_FeatureVersion ldfv
|
||||
WHERE ldfv.layer_id >= $2
|
||||
AND vafv.vulnerability_id = $1
|
||||
AND vafv.featureversion_id = fv.id
|
||||
AND ldfv.featureversion_id = fv.id
|
||||
AND ldfv.modification = 'add'
|
||||
ORDER BY ldfv.layer_id
|
||||
)
|
||||
SELECT l.id, l.name
|
||||
FROM LDFV, Layer l
|
||||
WHERE LDFV.layer_id = l.id
|
||||
LIMIT $3`
|
||||
searchNotificationVulnerableAncestry = `
|
||||
SELECT DISTINCT ON (a.id)
|
||||
a.id, a.name
|
||||
FROM vulnerability_affected_namespaced_feature AS vanf,
|
||||
ancestry AS a, ancestry_feature AS af
|
||||
WHERE vanf.vulnerability_id = $1
|
||||
AND a.id >= $2
|
||||
AND a.id = af.ancestry_id
|
||||
AND af.namespaced_feature_id = vanf.namespaced_feature_id
|
||||
ORDER BY a.id ASC
|
||||
LIMIT $3;`
|
||||
|
||||
// complex_test.go
|
||||
searchComplexTestFeatureVersionAffects = `
|
||||
SELECT v.name
|
||||
FROM FeatureVersion fv
|
||||
LEFT JOIN Vulnerability_Affects_FeatureVersion vaf ON fv.id = vaf.featureversion_id
|
||||
JOIN Vulnerability v ON vaf.vulnerability_id = v.id
|
||||
WHERE featureversion_id = $1`
|
||||
// ancestry.go
|
||||
persistAncestryLister = `
|
||||
INSERT INTO ancestry_lister (ancestry_id, lister)
|
||||
SELECT CAST ($1 AS INTEGER), CAST ($2 AS TEXT)
|
||||
WHERE NOT EXISTS (SELECT id FROM ancestry_lister WHERE ancestry_id = $1 AND lister = $2) ON CONFLICT DO NOTHING`
|
||||
|
||||
persistAncestryDetector = `
|
||||
INSERT INTO ancestry_detector (ancestry_id, detector)
|
||||
SELECT CAST ($1 AS INTEGER), CAST ($2 AS TEXT)
|
||||
WHERE NOT EXISTS (SELECT id FROM ancestry_detector WHERE ancestry_id = $1 AND detector = $2) ON CONFLICT DO NOTHING`
|
||||
|
||||
insertAncestry = `INSERT INTO ancestry (name) VALUES ($1) RETURNING id`
|
||||
|
||||
searchAncestryLayer = `
|
||||
SELECT layer.hash
|
||||
FROM layer, ancestry_layer
|
||||
WHERE ancestry_layer.ancestry_id = $1
|
||||
AND ancestry_layer.layer_id = layer.id
|
||||
ORDER BY ancestry_layer.ancestry_index ASC`
|
||||
|
||||
searchAncestryFeatures = `
|
||||
SELECT namespace.name, namespace.version_format, feature.name, feature.version
|
||||
FROM namespace, feature, ancestry, namespaced_feature, ancestry_feature
|
||||
WHERE ancestry.name = $1
|
||||
AND ancestry.id = ancestry_feature.ancestry_id
|
||||
AND ancestry_feature.namespaced_feature_id = namespaced_feature.id
|
||||
AND namespaced_feature.feature_id = feature.id
|
||||
AND namespaced_feature.namespace_id = namespace.id`
|
||||
|
||||
searchAncestry = `SELECT id FROM ancestry WHERE name = $1`
|
||||
searchAncestryDetectors = `SELECT detector FROM ancestry_detector WHERE ancestry_id = $1`
|
||||
searchAncestryListers = `SELECT lister FROM ancestry_lister WHERE ancestry_id = $1`
|
||||
removeAncestry = `DELETE FROM ancestry WHERE name = $1`
|
||||
insertAncestryLayer = `INSERT INTO ancestry_layer(ancestry_id, ancestry_index, layer_id) VALUES($1,$2,$3)`
|
||||
insertAncestryFeature = `INSERT INTO ancestry_feature(ancestry_id, namespaced_feature_id) VALUES ($1, $2)`
|
||||
)
|
||||
|
||||
// buildInputArray constructs a PostgreSQL input array from the specified integers.
|
||||
// Useful to use the `= ANY($1::integer[])` syntax that let us use a IN clause while using
|
||||
// a single placeholder.
|
||||
func buildInputArray(ints []int) string {
|
||||
str := "{"
|
||||
for i := 0; i < len(ints)-1; i++ {
|
||||
str = str + strconv.Itoa(ints[i]) + ","
|
||||
}
|
||||
str = str + strconv.Itoa(ints[len(ints)-1]) + "}"
|
||||
return str
|
||||
// NOTE(Sida): Every search query can only have count less than postgres set
|
||||
// stack depth. IN will be resolved to nested OR_s and the parser might exceed
|
||||
// stack depth. TODO(Sida): Generate different queries for different count: if
|
||||
// count < 5120, use IN; for count > 5120 and < 65536, use temporary table; for
|
||||
// count > 65535, use is expected to split data into batches.
|
||||
func querySearchLastDeletedVulnerabilityID(count int) string {
|
||||
return fmt.Sprintf(`
|
||||
SELECT vid, vname, nname FROM (
|
||||
SELECT v.id AS vid, v.name AS vname, n.name AS nname,
|
||||
row_number() OVER (
|
||||
PARTITION by (v.name, n.name)
|
||||
ORDER BY v.deleted_at DESC
|
||||
) AS rownum
|
||||
FROM vulnerability AS v, namespace AS n
|
||||
WHERE v.namespace_id = n.id
|
||||
AND (v.name, n.name) IN ( %s )
|
||||
AND v.deleted_at IS NOT NULL
|
||||
) tmp WHERE rownum <= 1`,
|
||||
queryString(2, count))
|
||||
}
|
||||
|
||||
func querySearchNotDeletedVulnerabilityID(count int) string {
|
||||
return fmt.Sprintf(`
|
||||
SELECT v.id, v.name, n.name FROM vulnerability AS v, namespace AS n
|
||||
WHERE v.namespace_id = n.id AND (v.name, n.name) IN (%s)
|
||||
AND v.deleted_at IS NULL`,
|
||||
queryString(2, count))
|
||||
}
|
||||
|
||||
func querySearchFeatureID(featureCount int) string {
|
||||
return fmt.Sprintf(`
|
||||
SELECT id, name, version, version_format
|
||||
FROM Feature WHERE (name, version, version_format) IN (%s)`,
|
||||
queryString(3, featureCount),
|
||||
)
|
||||
}
|
||||
|
||||
func querySearchNamespacedFeature(nsfCount int) string {
|
||||
return fmt.Sprintf(`
|
||||
SELECT nf.id, f.name, f.version, f.version_format, n.name
|
||||
FROM namespaced_feature AS nf, feature AS f, namespace AS n
|
||||
WHERE nf.feature_id = f.id
|
||||
AND nf.namespace_id = n.id
|
||||
AND n.version_format = f.version_format
|
||||
AND (f.name, f.version, f.version_format, n.name) IN (%s)`,
|
||||
queryString(4, nsfCount),
|
||||
)
|
||||
}
|
||||
|
||||
func querySearchNamespace(nsCount int) string {
|
||||
return fmt.Sprintf(
|
||||
`SELECT id, name, version_format
|
||||
FROM namespace WHERE (name, version_format) IN (%s)`,
|
||||
queryString(2, nsCount),
|
||||
)
|
||||
}
|
||||
|
||||
func queryInsert(count int, table string, columns ...string) string {
|
||||
base := `INSERT INTO %s (%s) VALUES %s`
|
||||
t := pq.QuoteIdentifier(table)
|
||||
cols := make([]string, len(columns))
|
||||
for i, c := range columns {
|
||||
cols[i] = pq.QuoteIdentifier(c)
|
||||
}
|
||||
colsQuoted := strings.Join(cols, ",")
|
||||
return fmt.Sprintf(base, t, colsQuoted, queryString(len(columns), count))
|
||||
}
|
||||
|
||||
func queryPersist(count int, table, constraint string, columns ...string) string {
|
||||
ct := ""
|
||||
if constraint != "" {
|
||||
ct = fmt.Sprintf("ON CONSTRAINT %s", constraint)
|
||||
}
|
||||
return fmt.Sprintf("%s ON CONFLICT %s DO NOTHING", queryInsert(count, table, columns...), ct)
|
||||
}
|
||||
|
||||
func queryInsertNotifications(count int) string {
|
||||
return queryInsert(count,
|
||||
"vulnerability_notification",
|
||||
"name",
|
||||
"created_at",
|
||||
"old_vulnerability_id",
|
||||
"new_vulnerability_id",
|
||||
)
|
||||
}
|
||||
|
||||
func queryPersistFeature(count int) string {
|
||||
return queryPersist(count,
|
||||
"feature",
|
||||
"feature_name_version_version_format_key",
|
||||
"name",
|
||||
"version",
|
||||
"version_format")
|
||||
}
|
||||
|
||||
func queryPersistLayerFeature(count int) string {
|
||||
return queryPersist(count,
|
||||
"layer_feature",
|
||||
"layer_feature_layer_id_feature_id_key",
|
||||
"layer_id",
|
||||
"feature_id")
|
||||
}
|
||||
|
||||
func queryPersistNamespace(count int) string {
|
||||
return queryPersist(count,
|
||||
"namespace",
|
||||
"namespace_name_version_format_key",
|
||||
"name",
|
||||
"version_format")
|
||||
}
|
||||
|
||||
func queryPersistLayerListers(count int) string {
|
||||
return queryPersist(count,
|
||||
"layer_lister",
|
||||
"layer_lister_layer_id_lister_key",
|
||||
"layer_id",
|
||||
"lister")
|
||||
}
|
||||
|
||||
func queryPersistLayerDetectors(count int) string {
|
||||
return queryPersist(count,
|
||||
"layer_detector",
|
||||
"layer_detector_layer_id_detector_key",
|
||||
"layer_id",
|
||||
"detector")
|
||||
}
|
||||
|
||||
func queryPersistLayerNamespace(count int) string {
|
||||
return queryPersist(count,
|
||||
"layer_namespace",
|
||||
"layer_namespace_layer_id_namespace_id_key",
|
||||
"layer_id",
|
||||
"namespace_id")
|
||||
}
|
||||
|
||||
// size of key and array should be both greater than 0
|
||||
func queryString(keySize, arraySize int) string {
|
||||
if arraySize <= 0 || keySize <= 0 {
|
||||
panic("Bulk Query requires size of element tuple and number of elements to be greater than 0")
|
||||
}
|
||||
keys := make([]string, 0, arraySize)
|
||||
for i := 0; i < arraySize; i++ {
|
||||
key := make([]string, keySize)
|
||||
for j := 0; j < keySize; j++ {
|
||||
key[j] = fmt.Sprintf("$%d", i*keySize+j+1)
|
||||
}
|
||||
keys = append(keys, fmt.Sprintf("(%s)", strings.Join(key, ",")))
|
||||
}
|
||||
return strings.Join(keys, ",")
|
||||
}
|
||||
|
||||
func queryPersistNamespacedFeature(count int) string {
|
||||
return queryPersist(count, "namespaced_feature",
|
||||
"namespaced_feature_namespace_id_feature_id_key",
|
||||
"feature_id",
|
||||
"namespace_id")
|
||||
}
|
||||
|
||||
func queryPersistVulnerabilityAffectedNamespacedFeature(count int) string {
|
||||
return queryPersist(count, "vulnerability_affected_namespaced_feature",
|
||||
"vulnerability_affected_namesp_vulnerability_id_namespaced_f_key",
|
||||
"vulnerability_id",
|
||||
"namespaced_feature_id",
|
||||
"added_by")
|
||||
}
|
||||
|
||||
func queryPersistLayer(count int) string {
|
||||
return queryPersist(count, "layer", "", "hash")
|
||||
}
|
||||
|
||||
func queryInvalidateVulnerabilityCache(count int) string {
|
||||
return fmt.Sprintf(`DELETE FROM vulnerability_affected_feature
|
||||
WHERE vulnerability_id = (%s)`,
|
||||
queryString(1, count))
|
||||
}
|
||||
|
148
database/pgsql/testdata/data.sql
vendored
148
database/pgsql/testdata/data.sql
vendored
@ -1,73 +1,117 @@
|
||||
-- Copyright 2015 clair authors
|
||||
--
|
||||
-- Licensed under the Apache License, Version 2.0 (the "License");
|
||||
-- you may not use this file except in compliance with the License.
|
||||
-- You may obtain a copy of the License at
|
||||
--
|
||||
-- http://www.apache.org/licenses/LICENSE-2.0
|
||||
--
|
||||
-- Unless required by applicable law or agreed to in writing, software
|
||||
-- distributed under the License is distributed on an "AS IS" BASIS,
|
||||
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
-- See the License for the specific language governing permissions and
|
||||
-- limitations under the License.
|
||||
|
||||
INSERT INTO namespace (id, name, version_format) VALUES
|
||||
(1, 'debian:7', 'dpkg'),
|
||||
(2, 'debian:8', 'dpkg');
|
||||
(1, 'debian:7', 'dpkg'),
|
||||
(2, 'debian:8', 'dpkg'),
|
||||
(3, 'fake:1.0', 'rpm');
|
||||
|
||||
INSERT INTO feature (id, namespace_id, name) VALUES
|
||||
(1, 1, 'wechat'),
|
||||
(2, 1, 'openssl'),
|
||||
(4, 1, 'libssl'),
|
||||
(3, 2, 'openssl');
|
||||
INSERT INTO feature (id, name, version, version_format) VALUES
|
||||
(1, 'wechat', '0.5', 'dpkg'),
|
||||
(2, 'openssl', '1.0', 'dpkg'),
|
||||
(3, 'openssl', '2.0', 'dpkg'),
|
||||
(4, 'fake', '2.0', 'rpm');
|
||||
|
||||
INSERT INTO featureversion (id, feature_id, version) VALUES
|
||||
(1, 1, '0.5'),
|
||||
(2, 2, '1.0'),
|
||||
(3, 2, '2.0'),
|
||||
(4, 3, '1.0');
|
||||
INSERT INTO layer (id, hash) VALUES
|
||||
(1, 'layer-0'), -- blank
|
||||
(2, 'layer-1'), -- debian:7; wechat 0.5, openssl 1.0
|
||||
(3, 'layer-2'), -- debian:7; wechat 0.5, openssl 2.0
|
||||
(4, 'layer-3a'),-- debian:7;
|
||||
(5, 'layer-3b'),-- debian:8; wechat 0.5, openssl 1.0
|
||||
(6, 'layer-4'); -- debian:7, fake:1.0; openssl 2.0 (debian), fake 2.0 (fake)
|
||||
|
||||
INSERT INTO layer (id, name, engineversion, parent_id) VALUES
|
||||
(1, 'layer-0', 1, NULL),
|
||||
(2, 'layer-1', 1, 1),
|
||||
(3, 'layer-2', 1, 2),
|
||||
(4, 'layer-3a', 1, 3),
|
||||
(5, 'layer-3b', 1, 3);
|
||||
|
||||
INSERT INTO layer_namespace (id, layer_id, namespace_id) VALUES
|
||||
INSERT INTO layer_namespace(id, layer_id, namespace_id) VALUES
|
||||
(1, 2, 1),
|
||||
(2, 3, 1),
|
||||
(3, 4, 1),
|
||||
(4, 5, 2),
|
||||
(5, 5, 1);
|
||||
(5, 6, 1),
|
||||
(6, 6, 3);
|
||||
|
||||
INSERT INTO layer_diff_featureversion (id, layer_id, featureversion_id, modification) VALUES
|
||||
(1, 2, 1, 'add'),
|
||||
(2, 2, 2, 'add'),
|
||||
(3, 3, 2, 'del'), -- layer-2: Update Debian:7 OpenSSL 1.0 -> 2.0
|
||||
(4, 3, 3, 'add'), -- ^
|
||||
(5, 5, 3, 'del'), -- layer-3b: Delete Debian:7 OpenSSL 2.0
|
||||
(6, 5, 4, 'add'); -- layer-3b: Add Debian:8 OpenSSL 1.0
|
||||
INSERT INTO layer_feature(id, layer_id, feature_id) VALUES
|
||||
(1, 2, 1),
|
||||
(2, 2, 2),
|
||||
(3, 3, 1),
|
||||
(4, 3, 3),
|
||||
(5, 5, 1),
|
||||
(6, 5, 2),
|
||||
(7, 6, 4),
|
||||
(8, 6, 3);
|
||||
|
||||
INSERT INTO layer_lister(id, layer_id, lister) VALUES
|
||||
(1, 1, 'dpkg'),
|
||||
(2, 2, 'dpkg'),
|
||||
(3, 3, 'dpkg'),
|
||||
(4, 4, 'dpkg'),
|
||||
(5, 5, 'dpkg'),
|
||||
(6, 6, 'dpkg'),
|
||||
(7, 6, 'rpm');
|
||||
|
||||
INSERT INTO layer_detector(id, layer_id, detector) VALUES
|
||||
(1, 1, 'os-release'),
|
||||
(2, 2, 'os-release'),
|
||||
(3, 3, 'os-release'),
|
||||
(4, 4, 'os-release'),
|
||||
(5, 5, 'os-release'),
|
||||
(6, 6, 'os-release'),
|
||||
(7, 6, 'apt-sources');
|
||||
|
||||
INSERT INTO ancestry (id, name) VALUES
|
||||
(1, 'ancestry-1'), -- layer-0, layer-1, layer-2, layer-3a
|
||||
(2, 'ancestry-2'), -- layer-0, layer-1, layer-2, layer-3b
|
||||
(3, 'ancestry-3'), -- empty; just for testing the vulnerable ancestry
|
||||
(4, 'ancestry-4'); -- empty; just for testing the vulnerable ancestry
|
||||
|
||||
INSERT INTO ancestry_lister (id, ancestry_id, lister) VALUES
|
||||
(1, 1, 'dpkg'),
|
||||
(2, 2, 'dpkg');
|
||||
|
||||
INSERT INTO ancestry_detector (id, ancestry_id, detector) VALUES
|
||||
(1, 1, 'os-release'),
|
||||
(2, 2, 'os-release');
|
||||
|
||||
INSERT INTO ancestry_layer (id, ancestry_id, layer_id, ancestry_index) VALUES
|
||||
(1, 1, 1, 0),(2, 1, 2, 1),(3, 1, 3, 2),(4, 1, 4, 3),
|
||||
(5, 2, 1, 0),(6, 2, 2, 1),(7, 2, 3, 2),(8, 2, 5, 3);
|
||||
|
||||
INSERT INTO namespaced_feature(id, feature_id, namespace_id) VALUES
|
||||
(1, 1, 1), -- wechat 0.5, debian:7
|
||||
(2, 2, 1), -- openssl 1.0, debian:7
|
||||
(3, 2, 2), -- openssl 1.0, debian:8
|
||||
(4, 3, 1); -- openssl 2.0, debian:7
|
||||
|
||||
INSERT INTO ancestry_feature (id, ancestry_id, namespaced_feature_id) VALUES
|
||||
(1, 1, 1), (2, 1, 4),
|
||||
(3, 2, 1), (4, 2, 3),
|
||||
(5, 3, 2), (6, 4, 2); -- assume that ancestry-3 and ancestry-4 are vulnerable.
|
||||
|
||||
INSERT INTO vulnerability (id, namespace_id, name, description, link, severity) VALUES
|
||||
(1, 1, 'CVE-OPENSSL-1-DEB7', 'A vulnerability affecting OpenSSL < 2.0 on Debian 7.0', 'http://google.com/#q=CVE-OPENSSL-1-DEB7', 'High'),
|
||||
(2, 1, 'CVE-NOPE', 'A vulnerability affecting nothing', '', 'Unknown');
|
||||
|
||||
INSERT INTO vulnerability_fixedin_feature (id, vulnerability_id, feature_id, version) VALUES
|
||||
(1, 1, 2, '2.0'),
|
||||
(2, 1, 4, '1.9-abc');
|
||||
INSERT INTO vulnerability (id, namespace_id, name, description, link, severity, deleted_at) VALUES
|
||||
(3, 1, 'CVE-DELETED', '', '', 'Unknown', '2017-08-08 17:49:31.668483');
|
||||
|
||||
INSERT INTO vulnerability_affected_feature(id, vulnerability_id, feature_name, affected_version, fixedin) VALUES
|
||||
(1, 1, 'openssl', '2.0', '2.0'),
|
||||
(2, 1, 'libssl', '1.9-abc', '1.9-abc');
|
||||
|
||||
INSERT INTO vulnerability_affects_featureversion (id, vulnerability_id, featureversion_id, fixedin_id) VALUES
|
||||
(1, 1, 2, 1); -- CVE-OPENSSL-1-DEB7 affects Debian:7 OpenSSL 1.0
|
||||
INSERT INTO vulnerability_affected_namespaced_feature(id, vulnerability_id, namespaced_feature_id, added_by) VALUES
|
||||
(1, 1, 2, 1);
|
||||
|
||||
INSERT INTO vulnerability_notification(id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id) VALUES
|
||||
(1, 'test', NULL, NULL, NULL, 2, 1); -- 'CVE-NOPE' -> 'CVE-OPENSSL-1-DEB7'
|
||||
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('namespace', 'id'), (SELECT MAX(id) FROM namespace)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry', 'id'), (SELECT MAX(id) FROM ancestry)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_layer', 'id'), (SELECT MAX(id) FROM ancestry_layer)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_feature', 'id'), (SELECT MAX(id) FROM ancestry_feature)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_detector', 'id'), (SELECT MAX(id) FROM ancestry_detector)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('ancestry_lister', 'id'), (SELECT MAX(id) FROM ancestry_lister)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('feature', 'id'), (SELECT MAX(id) FROM feature)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('featureversion', 'id'), (SELECT MAX(id) FROM featureversion)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('namespaced_feature', 'id'), (SELECT MAX(id) FROM namespaced_feature)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('layer', 'id'), (SELECT MAX(id) FROM layer)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('layer_namespace', 'id'), (SELECT MAX(id) FROM layer_namespace)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('layer_diff_featureversion', 'id'), (SELECT MAX(id) FROM layer_diff_featureversion)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability', 'id'), (SELECT MAX(id) FROM vulnerability)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_fixedin_feature', 'id'), (SELECT MAX(id) FROM vulnerability_fixedin_feature)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affects_featureversion', 'id'), (SELECT MAX(id) FROM vulnerability_affects_featureversion)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('layer_detector', 'id'), (SELECT MAX(id) FROM layer_detector)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('layer_lister', 'id'), (SELECT MAX(id) FROM layer_lister)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability', 'id'), (SELECT MAX(id) FROM vulnerability)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_feature)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_affected_namespaced_feature', 'id'), (SELECT MAX(id) FROM vulnerability_affected_namespaced_feature)+1);
|
||||
SELECT pg_catalog.setval(pg_get_serial_sequence('vulnerability_notification', 'id'), (SELECT MAX(id) FROM vulnerability_notification)+1);
|
||||
|
@ -17,352 +17,207 @@ package pgsql
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/guregu/null/zero"
|
||||
"github.com/lib/pq"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/ext/versionfmt"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
// compareStringLists returns the strings that are present in X but not in Y.
|
||||
func compareStringLists(X, Y []string) []string {
|
||||
m := make(map[string]bool)
|
||||
var (
|
||||
errVulnerabilityNotFound = errors.New("vulnerability is not in database")
|
||||
)
|
||||
|
||||
for _, y := range Y {
|
||||
m[y] = true
|
||||
}
|
||||
|
||||
diff := []string{}
|
||||
for _, x := range X {
|
||||
if m[x] {
|
||||
continue
|
||||
}
|
||||
|
||||
diff = append(diff, x)
|
||||
m[x] = true
|
||||
}
|
||||
|
||||
return diff
|
||||
type affectedAncestry struct {
|
||||
name string
|
||||
id int64
|
||||
}
|
||||
|
||||
func compareStringListsInBoth(X, Y []string) []string {
|
||||
m := make(map[string]struct{})
|
||||
|
||||
for _, y := range Y {
|
||||
m[y] = struct{}{}
|
||||
}
|
||||
|
||||
diff := []string{}
|
||||
for _, x := range X {
|
||||
if _, e := m[x]; e {
|
||||
diff = append(diff, x)
|
||||
delete(m, x)
|
||||
}
|
||||
}
|
||||
|
||||
return diff
|
||||
type affectRelation struct {
|
||||
vulnerabilityID int64
|
||||
namespacedFeatureID int64
|
||||
addedBy int64
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) ListVulnerabilities(namespaceName string, limit int, startID int) ([]database.Vulnerability, int, error) {
|
||||
defer observeQueryTime("listVulnerabilities", "all", time.Now())
|
||||
type affectedFeatureRows struct {
|
||||
rows map[int64]database.AffectedFeature
|
||||
}
|
||||
|
||||
// Query Namespace.
|
||||
var id int
|
||||
err := pgSQL.QueryRow(searchNamespace, namespaceName).Scan(&id)
|
||||
func (tx *pgSession) FindVulnerabilities(vulnerabilities []database.VulnerabilityID) ([]database.NullableVulnerability, error) {
|
||||
defer observeQueryTime("findVulnerabilities", "", time.Now())
|
||||
resultVuln := make([]database.NullableVulnerability, len(vulnerabilities))
|
||||
vulnIDMap := map[int64][]*database.NullableVulnerability{}
|
||||
|
||||
//TODO(Sida): Change to bulk search.
|
||||
stmt, err := tx.Prepare(searchVulnerability)
|
||||
if err != nil {
|
||||
return nil, -1, handleError("searchNamespace", err)
|
||||
} else if id == 0 {
|
||||
return nil, -1, commonerr.ErrNotFound
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Query.
|
||||
query := searchVulnerabilityBase + searchVulnerabilityByNamespace
|
||||
rows, err := pgSQL.Query(query, namespaceName, startID, limit+1)
|
||||
if err != nil {
|
||||
return nil, -1, handleError("searchVulnerabilityByNamespace", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var vulns []database.Vulnerability
|
||||
nextID := -1
|
||||
size := 0
|
||||
// Scan query.
|
||||
for rows.Next() {
|
||||
var vulnerability database.Vulnerability
|
||||
|
||||
err := rows.Scan(
|
||||
&vulnerability.ID,
|
||||
&vulnerability.Name,
|
||||
&vulnerability.Namespace.ID,
|
||||
&vulnerability.Namespace.Name,
|
||||
&vulnerability.Namespace.VersionFormat,
|
||||
&vulnerability.Description,
|
||||
&vulnerability.Link,
|
||||
&vulnerability.Severity,
|
||||
&vulnerability.Metadata,
|
||||
// load vulnerabilities
|
||||
for i, key := range vulnerabilities {
|
||||
var (
|
||||
id sql.NullInt64
|
||||
vuln = database.NullableVulnerability{
|
||||
VulnerabilityWithAffected: database.VulnerabilityWithAffected{
|
||||
Vulnerability: database.Vulnerability{
|
||||
Name: key.Name,
|
||||
Namespace: database.Namespace{
|
||||
Name: key.Namespace,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
err := stmt.QueryRow(key.Name, key.Namespace).Scan(
|
||||
&id,
|
||||
&vuln.Description,
|
||||
&vuln.Link,
|
||||
&vuln.Severity,
|
||||
&vuln.Metadata,
|
||||
&vuln.Namespace.VersionFormat,
|
||||
)
|
||||
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
stmt.Close()
|
||||
return nil, handleError("searchVulnerability", err)
|
||||
}
|
||||
vuln.Valid = id.Valid
|
||||
resultVuln[i] = vuln
|
||||
if id.Valid {
|
||||
vulnIDMap[id.Int64] = append(vulnIDMap[id.Int64], &resultVuln[i])
|
||||
}
|
||||
}
|
||||
|
||||
if err := stmt.Close(); err != nil {
|
||||
return nil, handleError("searchVulnerability", err)
|
||||
}
|
||||
|
||||
toQuery := make([]int64, 0, len(vulnIDMap))
|
||||
for id := range vulnIDMap {
|
||||
toQuery = append(toQuery, id)
|
||||
}
|
||||
|
||||
// load vulnerability affected features
|
||||
rows, err := tx.Query(searchVulnerabilityAffected, pq.Array(toQuery))
|
||||
if err != nil {
|
||||
return nil, handleError("searchVulnerabilityAffected", err)
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
id int64
|
||||
f database.AffectedFeature
|
||||
)
|
||||
|
||||
err := rows.Scan(&id, &f.FeatureName, &f.AffectedVersion, &f.FixedInVersion)
|
||||
if err != nil {
|
||||
return nil, -1, handleError("searchVulnerabilityByNamespace.Scan()", err)
|
||||
return nil, handleError("searchVulnerabilityAffected", err)
|
||||
}
|
||||
size++
|
||||
if size > limit {
|
||||
nextID = vulnerability.ID
|
||||
} else {
|
||||
vulns = append(vulns, vulnerability)
|
||||
|
||||
for _, vuln := range vulnIDMap[id] {
|
||||
f.Namespace = vuln.Namespace
|
||||
vuln.Affected = append(vuln.Affected, f)
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, -1, handleError("searchVulnerabilityByNamespace.Rows()", err)
|
||||
return resultVuln, nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) InsertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) error {
|
||||
defer observeQueryTime("insertVulnerabilities", "all", time.Now())
|
||||
// bulk insert vulnerabilities
|
||||
vulnIDs, err := tx.insertVulnerabilities(vulnerabilities)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return vulns, nextID, nil
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vulnerability, error) {
|
||||
return findVulnerability(pgSQL, namespaceName, name, false)
|
||||
}
|
||||
|
||||
func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bool) (database.Vulnerability, error) {
|
||||
defer observeQueryTime("findVulnerability", "all", time.Now())
|
||||
|
||||
queryName := "searchVulnerabilityBase+searchVulnerabilityByNamespaceAndName"
|
||||
query := searchVulnerabilityBase + searchVulnerabilityByNamespaceAndName
|
||||
if forUpdate {
|
||||
queryName = queryName + "+searchVulnerabilityForUpdate"
|
||||
query = query + searchVulnerabilityForUpdate
|
||||
// bulk insert vulnerability affected features
|
||||
vulnFeatureMap, err := tx.insertVulnerabilityAffected(vulnIDs, vulnerabilities)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return scanVulnerability(queryer, queryName, queryer.QueryRow(query, namespaceName, name))
|
||||
return tx.cacheVulnerabiltyAffectedNamespacedFeature(vulnFeatureMap)
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) findVulnerabilityByIDWithDeleted(id int) (database.Vulnerability, error) {
|
||||
defer observeQueryTime("findVulnerabilityByIDWithDeleted", "all", time.Now())
|
||||
|
||||
queryName := "searchVulnerabilityBase+searchVulnerabilityByID"
|
||||
query := searchVulnerabilityBase + searchVulnerabilityByID
|
||||
|
||||
return scanVulnerability(pgSQL, queryName, pgSQL.QueryRow(query, id))
|
||||
}
|
||||
|
||||
func scanVulnerability(queryer Queryer, queryName string, vulnerabilityRow *sql.Row) (database.Vulnerability, error) {
|
||||
var vulnerability database.Vulnerability
|
||||
|
||||
err := vulnerabilityRow.Scan(
|
||||
&vulnerability.ID,
|
||||
&vulnerability.Name,
|
||||
&vulnerability.Namespace.ID,
|
||||
&vulnerability.Namespace.Name,
|
||||
&vulnerability.Namespace.VersionFormat,
|
||||
&vulnerability.Description,
|
||||
&vulnerability.Link,
|
||||
&vulnerability.Severity,
|
||||
&vulnerability.Metadata,
|
||||
// insertVulnerabilityAffected inserts a set of vulnerability affected features for each vulnerability provided.
|
||||
//
|
||||
// i_th vulnerabilityIDs corresponds to i_th vulnerabilities provided.
|
||||
func (tx *pgSession) insertVulnerabilityAffected(vulnerabilityIDs []int64, vulnerabilities []database.VulnerabilityWithAffected) (map[int64]affectedFeatureRows, error) {
|
||||
var (
|
||||
vulnFeature = map[int64]affectedFeatureRows{}
|
||||
affectedID int64
|
||||
)
|
||||
|
||||
//TODO(Sida): Change to bulk insert.
|
||||
stmt, err := tx.Prepare(insertVulnerabilityAffected)
|
||||
if err != nil {
|
||||
return vulnerability, handleError(queryName+".Scan()", err)
|
||||
return nil, handleError("insertVulnerabilityAffected", err)
|
||||
}
|
||||
|
||||
if vulnerability.ID == 0 {
|
||||
return vulnerability, commonerr.ErrNotFound
|
||||
}
|
||||
|
||||
// Query the FixedIn FeatureVersion now.
|
||||
rows, err := queryer.Query(searchVulnerabilityFixedIn, vulnerability.ID)
|
||||
if err != nil {
|
||||
return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var featureVersionID zero.Int
|
||||
var featureVersionVersion zero.String
|
||||
var featureVersionFeatureName zero.String
|
||||
|
||||
err := rows.Scan(
|
||||
&featureVersionVersion,
|
||||
&featureVersionID,
|
||||
&featureVersionFeatureName,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err)
|
||||
}
|
||||
|
||||
if !featureVersionID.IsZero() {
|
||||
// Note that the ID we fill in featureVersion is actually a Feature ID, and not
|
||||
// a FeatureVersion ID.
|
||||
featureVersion := database.FeatureVersion{
|
||||
Model: database.Model{ID: int(featureVersionID.Int64)},
|
||||
Feature: database.Feature{
|
||||
Model: database.Model{ID: int(featureVersionID.Int64)},
|
||||
Namespace: vulnerability.Namespace,
|
||||
Name: featureVersionFeatureName.String,
|
||||
},
|
||||
Version: featureVersionVersion.String,
|
||||
defer stmt.Close()
|
||||
for i, vuln := range vulnerabilities {
|
||||
// affected feature row ID -> affected feature
|
||||
affectedFeatures := map[int64]database.AffectedFeature{}
|
||||
for _, f := range vuln.Affected {
|
||||
err := stmt.QueryRow(vulnerabilityIDs[i], f.FeatureName, f.AffectedVersion, f.FixedInVersion).Scan(&affectedID)
|
||||
if err != nil {
|
||||
return nil, handleError("insertVulnerabilityAffected", err)
|
||||
}
|
||||
vulnerability.FixedIn = append(vulnerability.FixedIn, featureVersion)
|
||||
affectedFeatures[affectedID] = f
|
||||
}
|
||||
vulnFeature[vulnerabilityIDs[i]] = affectedFeatureRows{rows: affectedFeatures}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return vulnerability, handleError("searchVulnerabilityFixedIn.Rows()", err)
|
||||
}
|
||||
|
||||
return vulnerability, nil
|
||||
return vulnFeature, nil
|
||||
}
|
||||
|
||||
// FixedIn.Namespace are not necessary, they are overwritten by the vuln.
|
||||
// By setting the fixed version to minVersion, we can say that the vuln does'nt affect anymore.
|
||||
func (pgSQL *pgSQL) InsertVulnerabilities(vulnerabilities []database.Vulnerability, generateNotifications bool) error {
|
||||
for _, vulnerability := range vulnerabilities {
|
||||
err := pgSQL.insertVulnerability(vulnerability, false, generateNotifications)
|
||||
// insertVulnerabilities inserts a set of unique vulnerabilities into database,
|
||||
// under the assumption that all vulnerabilities are valid.
|
||||
func (tx *pgSession) insertVulnerabilities(vulnerabilities []database.VulnerabilityWithAffected) ([]int64, error) {
|
||||
var (
|
||||
vulnID int64
|
||||
vulnIDs = make([]int64, 0, len(vulnerabilities))
|
||||
vulnMap = map[database.VulnerabilityID]struct{}{}
|
||||
)
|
||||
|
||||
for _, v := range vulnerabilities {
|
||||
key := database.VulnerabilityID{
|
||||
Name: v.Name,
|
||||
Namespace: v.Namespace.Name,
|
||||
}
|
||||
|
||||
// Ensure uniqueness of vulnerability IDs
|
||||
if _, ok := vulnMap[key]; ok {
|
||||
return nil, errors.New("inserting duplicated vulnerabilities is not allowed")
|
||||
}
|
||||
vulnMap[key] = struct{}{}
|
||||
}
|
||||
|
||||
//TODO(Sida): Change to bulk insert.
|
||||
stmt, err := tx.Prepare(insertVulnerability)
|
||||
if err != nil {
|
||||
return nil, handleError("insertVulnerability", err)
|
||||
}
|
||||
|
||||
defer stmt.Close()
|
||||
for _, vuln := range vulnerabilities {
|
||||
err := stmt.QueryRow(vuln.Name, vuln.Description,
|
||||
vuln.Link, &vuln.Severity, &vuln.Metadata,
|
||||
vuln.Namespace.Name, vuln.Namespace.VersionFormat).Scan(&vulnID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability, onlyFixedIn, generateNotification bool) error {
|
||||
tf := time.Now()
|
||||
|
||||
// Verify parameters
|
||||
if vulnerability.Name == "" || vulnerability.Namespace.Name == "" {
|
||||
return commonerr.NewBadRequestError("insertVulnerability needs at least the Name and the Namespace")
|
||||
}
|
||||
|
||||
for i := 0; i < len(vulnerability.FixedIn); i++ {
|
||||
fifv := &vulnerability.FixedIn[i]
|
||||
|
||||
if fifv.Feature.Namespace.Name == "" {
|
||||
// As there is no Namespace on that FixedIn FeatureVersion, set it to the Vulnerability's
|
||||
// Namespace.
|
||||
fifv.Feature.Namespace = vulnerability.Namespace
|
||||
} else if fifv.Feature.Namespace.Name != vulnerability.Namespace.Name {
|
||||
msg := "could not insert an invalid vulnerability that contains FixedIn FeatureVersion that are not in the same namespace as the Vulnerability"
|
||||
log.Warning(msg)
|
||||
return commonerr.NewBadRequestError(msg)
|
||||
}
|
||||
}
|
||||
|
||||
// We do `defer observeQueryTime` here because we don't want to observe invalid vulnerabilities.
|
||||
defer observeQueryTime("insertVulnerability", "all", tf)
|
||||
|
||||
// Begin transaction.
|
||||
tx, err := pgSQL.Begin()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("insertVulnerability.Begin()", err)
|
||||
}
|
||||
|
||||
// Find existing vulnerability and its Vulnerability_FixedIn_Features (for update).
|
||||
existingVulnerability, err := findVulnerability(tx, vulnerability.Namespace.Name, vulnerability.Name, true)
|
||||
if err != nil && err != commonerr.ErrNotFound {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
if onlyFixedIn {
|
||||
// Because this call tries to update FixedIn FeatureVersion, import all other data from the
|
||||
// existing one.
|
||||
if existingVulnerability.ID == 0 {
|
||||
return commonerr.ErrNotFound
|
||||
return nil, handleError("insertVulnerability", err)
|
||||
}
|
||||
|
||||
fixedIn := vulnerability.FixedIn
|
||||
vulnerability = existingVulnerability
|
||||
vulnerability.FixedIn = fixedIn
|
||||
vulnIDs = append(vulnIDs, vulnID)
|
||||
}
|
||||
|
||||
if existingVulnerability.ID != 0 {
|
||||
updateMetadata := vulnerability.Description != existingVulnerability.Description ||
|
||||
vulnerability.Link != existingVulnerability.Link ||
|
||||
vulnerability.Severity != existingVulnerability.Severity ||
|
||||
!reflect.DeepEqual(castMetadata(vulnerability.Metadata), existingVulnerability.Metadata)
|
||||
|
||||
// Construct the entire list of FixedIn FeatureVersion, by using the
|
||||
// the FixedIn list of the old vulnerability.
|
||||
//
|
||||
// TODO(Quentin-M): We could use !updateFixedIn to just copy FixedIn/Affects rows from the
|
||||
// existing vulnerability in order to make metadata updates much faster.
|
||||
var updateFixedIn bool
|
||||
vulnerability.FixedIn, updateFixedIn = applyFixedInDiff(existingVulnerability.FixedIn, vulnerability.FixedIn)
|
||||
|
||||
if !updateMetadata && !updateFixedIn {
|
||||
tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Mark the old vulnerability as non latest.
|
||||
_, err = tx.Exec(removeVulnerability, vulnerability.Namespace.Name, vulnerability.Name)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("removeVulnerability", err)
|
||||
}
|
||||
} else {
|
||||
// The vulnerability is new, we don't want to have any
|
||||
// versionfmt.MinVersion as they are only used for diffing existing
|
||||
// vulnerabilities.
|
||||
var fixedIn []database.FeatureVersion
|
||||
for _, fv := range vulnerability.FixedIn {
|
||||
if fv.Version != versionfmt.MinVersion {
|
||||
fixedIn = append(fixedIn, fv)
|
||||
}
|
||||
}
|
||||
vulnerability.FixedIn = fixedIn
|
||||
}
|
||||
|
||||
// Find or insert Vulnerability's Namespace.
|
||||
namespaceID, err := pgSQL.insertNamespace(vulnerability.Namespace)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Insert vulnerability.
|
||||
err = tx.QueryRow(
|
||||
insertVulnerability,
|
||||
namespaceID,
|
||||
vulnerability.Name,
|
||||
vulnerability.Description,
|
||||
vulnerability.Link,
|
||||
&vulnerability.Severity,
|
||||
&vulnerability.Metadata,
|
||||
).Scan(&vulnerability.ID)
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("insertVulnerability", err)
|
||||
}
|
||||
|
||||
// Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now.
|
||||
err = pgSQL.insertVulnerabilityFixedInFeatureVersions(tx, vulnerability.ID, vulnerability.FixedIn)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
// Create a notification.
|
||||
if generateNotification {
|
||||
err = createNotification(tx, existingVulnerability.ID, vulnerability.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Commit transaction.
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("insertVulnerability.Commit()", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return vulnIDs, nil
|
||||
}
|
||||
|
||||
// castMetadata marshals the given database.MetadataMap and unmarshals it again to make sure that
|
||||
@ -376,241 +231,208 @@ func castMetadata(m database.MetadataMap) database.MetadataMap {
|
||||
return c
|
||||
}
|
||||
|
||||
// applyFixedInDiff applies a FeatureVersion diff on a FeatureVersion list and returns the result.
|
||||
func applyFixedInDiff(currentList, diff []database.FeatureVersion) ([]database.FeatureVersion, bool) {
|
||||
currentMap, currentNames := createFeatureVersionNameMap(currentList)
|
||||
diffMap, diffNames := createFeatureVersionNameMap(diff)
|
||||
|
||||
addedNames := compareStringLists(diffNames, currentNames)
|
||||
inBothNames := compareStringListsInBoth(diffNames, currentNames)
|
||||
|
||||
different := false
|
||||
|
||||
for _, name := range addedNames {
|
||||
if diffMap[name].Version == versionfmt.MinVersion {
|
||||
// MinVersion only makes sense when a Feature is already fixed in some version,
|
||||
// in which case we would be in the "inBothNames".
|
||||
continue
|
||||
}
|
||||
|
||||
currentMap[name] = diffMap[name]
|
||||
different = true
|
||||
}
|
||||
|
||||
for _, name := range inBothNames {
|
||||
fv := diffMap[name]
|
||||
|
||||
if fv.Version == versionfmt.MinVersion {
|
||||
// MinVersion means that the Feature doesn't affect the Vulnerability anymore.
|
||||
delete(currentMap, name)
|
||||
different = true
|
||||
} else if fv.Version != currentMap[name].Version {
|
||||
// The version got updated.
|
||||
currentMap[name] = diffMap[name]
|
||||
different = true
|
||||
}
|
||||
}
|
||||
|
||||
// Convert currentMap to a slice and return it.
|
||||
var newList []database.FeatureVersion
|
||||
for _, fv := range currentMap {
|
||||
newList = append(newList, fv)
|
||||
}
|
||||
|
||||
return newList, different
|
||||
}
|
||||
|
||||
func createFeatureVersionNameMap(features []database.FeatureVersion) (map[string]database.FeatureVersion, []string) {
|
||||
m := make(map[string]database.FeatureVersion, 0)
|
||||
s := make([]string, 0, len(features))
|
||||
|
||||
for i := 0; i < len(features); i++ {
|
||||
featureVersion := features[i]
|
||||
m[featureVersion.Feature.Name] = featureVersion
|
||||
s = append(s, featureVersion.Feature.Name)
|
||||
}
|
||||
|
||||
return m, s
|
||||
}
|
||||
|
||||
// insertVulnerabilityFixedInFeatureVersions populates Vulnerability_FixedIn_Feature for the given
|
||||
// vulnerability with the specified database.FeatureVersion list and uses
|
||||
// linkVulnerabilityToFeatureVersions to propagate the changes on Vulnerability_FixedIn_Feature to
|
||||
// Vulnerability_Affects_FeatureVersion.
|
||||
func (pgSQL *pgSQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulnerabilityID int, fixedIn []database.FeatureVersion) error {
|
||||
defer observeQueryTime("insertVulnerabilityFixedInFeatureVersions", "all", time.Now())
|
||||
|
||||
// Insert or find the Features.
|
||||
// TODO(Quentin-M): Batch me.
|
||||
var err error
|
||||
var features []*database.Feature
|
||||
for i := 0; i < len(fixedIn); i++ {
|
||||
features = append(features, &fixedIn[i].Feature)
|
||||
}
|
||||
for _, feature := range features {
|
||||
if feature.ID == 0 {
|
||||
if feature.ID, err = pgSQL.insertFeature(*feature); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Lock Vulnerability_Affects_FeatureVersion exclusively.
|
||||
// We want to prevent InsertFeatureVersion to modify it.
|
||||
promConcurrentLockVAFV.Inc()
|
||||
defer promConcurrentLockVAFV.Dec()
|
||||
t := time.Now()
|
||||
_, err = tx.Exec(lockVulnerabilityAffects)
|
||||
observeQueryTime("insertVulnerability", "lock", t)
|
||||
|
||||
func (tx *pgSession) lockFeatureVulnerabilityCache() error {
|
||||
_, err := tx.Exec(lockVulnerabilityAffects)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("insertVulnerability.lockVulnerabilityAffects", err)
|
||||
return handleError("lockVulnerabilityAffects", err)
|
||||
}
|
||||
|
||||
for _, fv := range fixedIn {
|
||||
var fixedInID int
|
||||
var created bool
|
||||
|
||||
// Find or create entry in Vulnerability_FixedIn_Feature.
|
||||
err = tx.QueryRow(
|
||||
soiVulnerabilityFixedInFeature,
|
||||
vulnerabilityID, fv.Feature.ID,
|
||||
&fv.Version,
|
||||
).Scan(&created, &fixedInID)
|
||||
|
||||
if err != nil {
|
||||
return handleError("insertVulnerabilityFixedInFeature", err)
|
||||
}
|
||||
|
||||
if !created {
|
||||
// The relationship between the feature and the vulnerability already
|
||||
// existed, no need to update Vulnerability_Affects_FeatureVersion.
|
||||
continue
|
||||
}
|
||||
|
||||
// Insert Vulnerability_Affects_FeatureVersion.
|
||||
err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerabilityID, fv.Feature.ID, fv.Feature.Namespace.VersionFormat, fv.Version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, versionFormat, fixedInVersion string) error {
|
||||
// Find every FeatureVersions of the Feature that the vulnerability affects.
|
||||
// TODO(Quentin-M): LIMIT
|
||||
rows, err := tx.Query(searchFeatureVersionByFeature, featureID)
|
||||
if err != nil {
|
||||
return handleError("searchFeatureVersionByFeature", err)
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var affecteds []database.FeatureVersion
|
||||
for rows.Next() {
|
||||
var affected database.FeatureVersion
|
||||
|
||||
err := rows.Scan(&affected.ID, &affected.Version)
|
||||
if err != nil {
|
||||
return handleError("searchFeatureVersionByFeature.Scan()", err)
|
||||
}
|
||||
|
||||
cmp, err := versionfmt.Compare(versionFormat, affected.Version, fixedInVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if cmp < 0 {
|
||||
// The version of the FeatureVersion is lower than the fixed version of this vulnerability,
|
||||
// thus, this FeatureVersion is affected by it.
|
||||
affecteds = append(affecteds, affected)
|
||||
}
|
||||
}
|
||||
if err = rows.Err(); err != nil {
|
||||
return handleError("searchFeatureVersionByFeature.Rows()", err)
|
||||
}
|
||||
rows.Close()
|
||||
|
||||
// Insert into Vulnerability_Affects_FeatureVersion.
|
||||
for _, affected := range affecteds {
|
||||
// TODO(Quentin-M): Batch me.
|
||||
_, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, vulnerabilityID, affected.ID, fixedInID)
|
||||
if err != nil {
|
||||
return handleError("insertVulnerabilityAffectsFeatureVersion", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []database.FeatureVersion) error {
|
||||
defer observeQueryTime("InsertVulnerabilityFixes", "all", time.Now())
|
||||
|
||||
v := database.Vulnerability{
|
||||
Name: vulnerabilityName,
|
||||
Namespace: database.Namespace{
|
||||
Name: vulnerabilityNamespace,
|
||||
},
|
||||
FixedIn: fixes,
|
||||
}
|
||||
|
||||
return pgSQL.insertVulnerability(v, true, true)
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error {
|
||||
defer observeQueryTime("DeleteVulnerabilityFix", "all", time.Now())
|
||||
|
||||
v := database.Vulnerability{
|
||||
Name: vulnerabilityName,
|
||||
Namespace: database.Namespace{
|
||||
Name: vulnerabilityNamespace,
|
||||
},
|
||||
FixedIn: []database.FeatureVersion{
|
||||
{
|
||||
Feature: database.Feature{
|
||||
Name: featureName,
|
||||
Namespace: database.Namespace{
|
||||
Name: vulnerabilityNamespace,
|
||||
},
|
||||
},
|
||||
Version: versionfmt.MinVersion,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
return pgSQL.insertVulnerability(v, true, true)
|
||||
}
|
||||
|
||||
func (pgSQL *pgSQL) DeleteVulnerability(namespaceName, name string) error {
|
||||
defer observeQueryTime("DeleteVulnerability", "all", time.Now())
|
||||
|
||||
// Begin transaction.
|
||||
tx, err := pgSQL.Begin()
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("DeleteVulnerability.Begin()", err)
|
||||
}
|
||||
|
||||
var vulnerabilityID int
|
||||
err = tx.QueryRow(removeVulnerability, namespaceName, name).Scan(&vulnerabilityID)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("removeVulnerability", err)
|
||||
}
|
||||
|
||||
// Create a notification.
|
||||
err = createNotification(tx, vulnerabilityID, 0)
|
||||
// cacheVulnerabiltyAffectedNamespacedFeature takes in a map of vulnerability ID
|
||||
// to affected feature rows and caches them.
|
||||
func (tx *pgSession) cacheVulnerabiltyAffectedNamespacedFeature(affected map[int64]affectedFeatureRows) error {
|
||||
// Prevent InsertNamespacedFeatures to modify it.
|
||||
err := tx.lockFeatureVulnerabilityCache()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Commit transaction.
|
||||
err = tx.Commit()
|
||||
vulnIDs := []int64{}
|
||||
for id := range affected {
|
||||
vulnIDs = append(vulnIDs, id)
|
||||
}
|
||||
|
||||
rows, err := tx.Query(searchVulnerabilityPotentialAffected, pq.Array(vulnIDs))
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return handleError("DeleteVulnerability.Commit()", err)
|
||||
return handleError("searchVulnerabilityPotentialAffected", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
relation := []affectRelation{}
|
||||
for rows.Next() {
|
||||
var (
|
||||
vulnID int64
|
||||
nsfID int64
|
||||
fVersion string
|
||||
addedBy int64
|
||||
)
|
||||
|
||||
err := rows.Scan(&vulnID, &nsfID, &fVersion, &addedBy)
|
||||
if err != nil {
|
||||
return handleError("searchVulnerabilityPotentialAffected", err)
|
||||
}
|
||||
|
||||
candidate, ok := affected[vulnID].rows[addedBy]
|
||||
|
||||
if !ok {
|
||||
return errors.New("vulnerability affected feature not found")
|
||||
}
|
||||
|
||||
if in, err := versionfmt.InRange(candidate.Namespace.VersionFormat,
|
||||
fVersion,
|
||||
candidate.AffectedVersion); err == nil {
|
||||
if in {
|
||||
relation = append(relation,
|
||||
affectRelation{
|
||||
vulnerabilityID: vulnID,
|
||||
namespacedFeatureID: nsfID,
|
||||
addedBy: addedBy,
|
||||
})
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
//TODO(Sida): Change to bulk insert.
|
||||
for _, r := range relation {
|
||||
result, err := tx.Exec(insertVulnerabilityAffectedNamespacedFeature, r.vulnerabilityID, r.namespacedFeatureID, r.addedBy)
|
||||
if err != nil {
|
||||
return handleError("insertVulnerabilityAffectedNamespacedFeature", err)
|
||||
}
|
||||
|
||||
if num, err := result.RowsAffected(); err == nil {
|
||||
if num <= 0 {
|
||||
return errors.New("Nothing cached in database")
|
||||
}
|
||||
} else {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("Cached %d features in vulnerability_affected_namespaced_feature", len(relation))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) DeleteVulnerabilities(vulnerabilities []database.VulnerabilityID) error {
|
||||
defer observeQueryTime("DeleteVulnerability", "all", time.Now())
|
||||
|
||||
vulnIDs, err := tx.markVulnerabilitiesAsDeleted(vulnerabilities)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.invalidateVulnerabilityCache(vulnIDs); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) invalidateVulnerabilityCache(vulnerabilityIDs []int64) error {
|
||||
if len(vulnerabilityIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Prevent InsertNamespacedFeatures to modify it.
|
||||
err := tx.lockFeatureVulnerabilityCache()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//TODO(Sida): Make a nicer interface for bulk inserting.
|
||||
keys := make([]interface{}, len(vulnerabilityIDs))
|
||||
for i, id := range vulnerabilityIDs {
|
||||
keys[i] = id
|
||||
}
|
||||
|
||||
_, err = tx.Exec(queryInvalidateVulnerabilityCache(len(vulnerabilityIDs)), keys...)
|
||||
if err != nil {
|
||||
return handleError("removeVulnerabilityAffectedFeature", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tx *pgSession) markVulnerabilitiesAsDeleted(vulnerabilities []database.VulnerabilityID) ([]int64, error) {
|
||||
var (
|
||||
vulnID sql.NullInt64
|
||||
vulnIDs []int64
|
||||
)
|
||||
|
||||
// mark vulnerabilities deleted
|
||||
stmt, err := tx.Prepare(removeVulnerability)
|
||||
if err != nil {
|
||||
return nil, handleError("removeVulnerability", err)
|
||||
}
|
||||
|
||||
defer stmt.Close()
|
||||
for _, vuln := range vulnerabilities {
|
||||
err := stmt.QueryRow(vuln.Namespace, vuln.Name).Scan(&vulnID)
|
||||
if err != nil {
|
||||
return nil, handleError("removeVulnerability", err)
|
||||
}
|
||||
if !vulnID.Valid {
|
||||
return nil, handleError("removeVulnerability", errors.New("Vulnerability to be removed is not in database"))
|
||||
}
|
||||
vulnIDs = append(vulnIDs, vulnID.Int64)
|
||||
}
|
||||
return vulnIDs, nil
|
||||
}
|
||||
|
||||
// findLatestDeletedVulnerabilityIDs requires all elements in vulnIDs are in
|
||||
// database and the order of output array is not guaranteed.
|
||||
func (tx *pgSession) findLatestDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
|
||||
return tx.findVulnerabilityIDs(vulnIDs, true)
|
||||
}
|
||||
|
||||
func (tx *pgSession) findNotDeletedVulnerabilityIDs(vulnIDs []database.VulnerabilityID) ([]sql.NullInt64, error) {
|
||||
return tx.findVulnerabilityIDs(vulnIDs, false)
|
||||
}
|
||||
|
||||
func (tx *pgSession) findVulnerabilityIDs(vulnIDs []database.VulnerabilityID, withLatestDeleted bool) ([]sql.NullInt64, error) {
|
||||
if len(vulnIDs) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
vulnIDMap := map[database.VulnerabilityID]sql.NullInt64{}
|
||||
keys := make([]interface{}, len(vulnIDs)*2)
|
||||
for i, vulnID := range vulnIDs {
|
||||
keys[i*2] = vulnID.Name
|
||||
keys[i*2+1] = vulnID.Namespace
|
||||
vulnIDMap[vulnID] = sql.NullInt64{}
|
||||
}
|
||||
|
||||
query := ""
|
||||
if withLatestDeleted {
|
||||
query = querySearchLastDeletedVulnerabilityID(len(vulnIDs))
|
||||
} else {
|
||||
query = querySearchNotDeletedVulnerabilityID(len(vulnIDs))
|
||||
}
|
||||
|
||||
rows, err := tx.Query(query, keys...)
|
||||
if err != nil {
|
||||
return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Query", err)
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
var (
|
||||
id sql.NullInt64
|
||||
vulnID database.VulnerabilityID
|
||||
)
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&id, &vulnID.Name, &vulnID.Namespace)
|
||||
if err != nil {
|
||||
return nil, handleError("querySearchVulnerabilityID.LatestDeleted.Scan", err)
|
||||
}
|
||||
vulnIDMap[vulnID] = id
|
||||
}
|
||||
|
||||
ids := make([]sql.NullInt64, len(vulnIDs))
|
||||
for i, v := range vulnIDs {
|
||||
ids[i] = vulnIDMap[v]
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
@ -15,282 +15,329 @@
|
||||
package pgsql
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coreos/clair/database"
|
||||
"github.com/coreos/clair/ext/versionfmt"
|
||||
"github.com/coreos/clair/ext/versionfmt/dpkg"
|
||||
"github.com/coreos/clair/pkg/commonerr"
|
||||
)
|
||||
|
||||
func TestFindVulnerability(t *testing.T) {
|
||||
datastore, err := openDatabaseForTest("FindVulnerability", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
func TestInsertVulnerabilities(t *testing.T) {
|
||||
store, tx := openSessionForTest(t, "InsertVulnerabilities", true)
|
||||
|
||||
ns1 := database.Namespace{
|
||||
Name: "name",
|
||||
VersionFormat: "random stuff",
|
||||
}
|
||||
defer datastore.Close()
|
||||
|
||||
// Find a vulnerability that does not exist.
|
||||
_, err = datastore.FindVulnerability("", "")
|
||||
assert.Equal(t, commonerr.ErrNotFound, err)
|
||||
ns2 := database.Namespace{
|
||||
Name: "debian:7",
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
|
||||
// Find a normal vulnerability.
|
||||
// invalid vulnerability
|
||||
v1 := database.Vulnerability{
|
||||
Name: "CVE-OPENSSL-1-DEB7",
|
||||
Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
|
||||
Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7",
|
||||
Severity: database.HighSeverity,
|
||||
Namespace: database.Namespace{
|
||||
Name: "debian:7",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
FixedIn: []database.FeatureVersion{
|
||||
{
|
||||
Feature: database.Feature{Name: "openssl"},
|
||||
Version: "2.0",
|
||||
},
|
||||
{
|
||||
Feature: database.Feature{Name: "libssl"},
|
||||
Version: "1.9-abc",
|
||||
},
|
||||
},
|
||||
Name: "invalid",
|
||||
Namespace: ns1,
|
||||
}
|
||||
|
||||
v1f, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7")
|
||||
if assert.Nil(t, err) {
|
||||
equalsVuln(t, &v1, &v1f)
|
||||
vwa1 := database.VulnerabilityWithAffected{
|
||||
Vulnerability: v1,
|
||||
}
|
||||
|
||||
// Find a vulnerability that has no link, no severity and no FixedIn.
|
||||
// valid vulnerability
|
||||
v2 := database.Vulnerability{
|
||||
Name: "CVE-NOPE",
|
||||
Description: "A vulnerability affecting nothing",
|
||||
Namespace: database.Namespace{
|
||||
Name: "debian:7",
|
||||
Name: "valid",
|
||||
Namespace: ns2,
|
||||
Severity: database.UnknownSeverity,
|
||||
}
|
||||
|
||||
vwa2 := database.VulnerabilityWithAffected{
|
||||
Vulnerability: v2,
|
||||
}
|
||||
|
||||
// empty
|
||||
err := tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{})
|
||||
assert.Nil(t, err)
|
||||
|
||||
// invalid content: vwa1 is invalid
|
||||
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa1, vwa2})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
tx = restartSession(t, store, tx, false)
|
||||
// invalid content: duplicated input
|
||||
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2, vwa2})
|
||||
assert.NotNil(t, err)
|
||||
|
||||
tx = restartSession(t, store, tx, false)
|
||||
// valid content
|
||||
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2})
|
||||
assert.Nil(t, err)
|
||||
|
||||
tx = restartSession(t, store, tx, true)
|
||||
// ensure the content is in database
|
||||
vulns, err := tx.FindVulnerabilities([]database.VulnerabilityID{{Name: "valid", Namespace: "debian:7"}})
|
||||
if assert.Nil(t, err) && assert.Len(t, vulns, 1) {
|
||||
assert.True(t, vulns[0].Valid)
|
||||
}
|
||||
|
||||
tx = restartSession(t, store, tx, false)
|
||||
// valid content: vwa2 removed and inserted
|
||||
err = tx.DeleteVulnerabilities([]database.VulnerabilityID{{Name: vwa2.Name, Namespace: vwa2.Namespace.Name}})
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vwa2})
|
||||
assert.Nil(t, err)
|
||||
|
||||
closeTest(t, store, tx)
|
||||
}
|
||||
|
||||
func TestCachingVulnerable(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "CachingVulnerable", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
ns := database.Namespace{
|
||||
Name: "debian:8",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
}
|
||||
|
||||
f := database.NamespacedFeature{
|
||||
Feature: database.Feature{
|
||||
Name: "openssl",
|
||||
Version: "1.0",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
},
|
||||
Severity: database.UnknownSeverity,
|
||||
Namespace: ns,
|
||||
}
|
||||
|
||||
v2f, err := datastore.FindVulnerability("debian:7", "CVE-NOPE")
|
||||
if assert.Nil(t, err) {
|
||||
equalsVuln(t, &v2, &v2f)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteVulnerability(t *testing.T) {
|
||||
datastore, err := openDatabaseForTest("InsertVulnerability", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer datastore.Close()
|
||||
|
||||
// Delete non-existing Vulnerability.
|
||||
err = datastore.DeleteVulnerability("TestDeleteVulnerabilityNamespace1", "CVE-OPENSSL-1-DEB7")
|
||||
assert.Equal(t, commonerr.ErrNotFound, err)
|
||||
err = datastore.DeleteVulnerability("debian:7", "TestDeleteVulnerabilityVulnerability1")
|
||||
assert.Equal(t, commonerr.ErrNotFound, err)
|
||||
|
||||
// Delete Vulnerability.
|
||||
err = datastore.DeleteVulnerability("debian:7", "CVE-OPENSSL-1-DEB7")
|
||||
if assert.Nil(t, err) {
|
||||
_, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7")
|
||||
assert.Equal(t, commonerr.ErrNotFound, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInsertVulnerability(t *testing.T) {
|
||||
datastore, err := openDatabaseForTest("InsertVulnerability", false)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer datastore.Close()
|
||||
|
||||
// Create some data.
|
||||
n1 := database.Namespace{
|
||||
Name: "TestInsertVulnerabilityNamespace1",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
}
|
||||
n2 := database.Namespace{
|
||||
Name: "TestInsertVulnerabilityNamespace2",
|
||||
VersionFormat: dpkg.ParserName,
|
||||
}
|
||||
|
||||
f1 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Name: "TestInsertVulnerabilityFeatureVersion1",
|
||||
Namespace: n1,
|
||||
vuln := database.VulnerabilityWithAffected{
|
||||
Vulnerability: database.Vulnerability{
|
||||
Name: "CVE-YAY",
|
||||
Namespace: ns,
|
||||
Severity: database.HighSeverity,
|
||||
},
|
||||
Version: "1.0",
|
||||
}
|
||||
f2 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Name: "TestInsertVulnerabilityFeatureVersion1",
|
||||
Namespace: n2,
|
||||
Affected: []database.AffectedFeature{
|
||||
{
|
||||
Namespace: ns,
|
||||
FeatureName: "openssl",
|
||||
AffectedVersion: "2.0",
|
||||
FixedInVersion: "2.1",
|
||||
},
|
||||
},
|
||||
Version: "1.0",
|
||||
}
|
||||
f3 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Name: "TestInsertVulnerabilityFeatureVersion2",
|
||||
},
|
||||
Version: versionfmt.MaxVersion,
|
||||
}
|
||||
f4 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Name: "TestInsertVulnerabilityFeatureVersion2",
|
||||
},
|
||||
Version: "1.4",
|
||||
}
|
||||
f5 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Name: "TestInsertVulnerabilityFeatureVersion3",
|
||||
},
|
||||
Version: "1.5",
|
||||
}
|
||||
f6 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Name: "TestInsertVulnerabilityFeatureVersion4",
|
||||
},
|
||||
Version: "0.1",
|
||||
}
|
||||
f7 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Name: "TestInsertVulnerabilityFeatureVersion5",
|
||||
},
|
||||
Version: versionfmt.MaxVersion,
|
||||
}
|
||||
f8 := database.FeatureVersion{
|
||||
Feature: database.Feature{
|
||||
Name: "TestInsertVulnerabilityFeatureVersion5",
|
||||
},
|
||||
Version: versionfmt.MinVersion,
|
||||
}
|
||||
|
||||
// Insert invalid vulnerabilities.
|
||||
for _, vulnerability := range []database.Vulnerability{
|
||||
{
|
||||
Name: "",
|
||||
Namespace: n1,
|
||||
FixedIn: []database.FeatureVersion{f1},
|
||||
Severity: database.UnknownSeverity,
|
||||
vuln2 := database.VulnerabilityWithAffected{
|
||||
Vulnerability: database.Vulnerability{
|
||||
Name: "CVE-YAY2",
|
||||
Namespace: ns,
|
||||
Severity: database.HighSeverity,
|
||||
},
|
||||
{
|
||||
Name: "TestInsertVulnerability0",
|
||||
Namespace: database.Namespace{},
|
||||
FixedIn: []database.FeatureVersion{f1},
|
||||
Severity: database.UnknownSeverity,
|
||||
Affected: []database.AffectedFeature{
|
||||
{
|
||||
Namespace: ns,
|
||||
FeatureName: "openssl",
|
||||
AffectedVersion: "2.1",
|
||||
FixedInVersion: "2.2",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "TestInsertVulnerability0-",
|
||||
Namespace: database.Namespace{},
|
||||
FixedIn: []database.FeatureVersion{f1},
|
||||
}
|
||||
|
||||
vulnFixed1 := database.VulnerabilityWithFixedIn{
|
||||
Vulnerability: database.Vulnerability{
|
||||
Name: "CVE-YAY",
|
||||
Namespace: ns,
|
||||
Severity: database.HighSeverity,
|
||||
},
|
||||
{
|
||||
Name: "TestInsertVulnerability0",
|
||||
Namespace: n1,
|
||||
FixedIn: []database.FeatureVersion{f2},
|
||||
Severity: database.UnknownSeverity,
|
||||
FixedInVersion: "2.1",
|
||||
}
|
||||
|
||||
vulnFixed2 := database.VulnerabilityWithFixedIn{
|
||||
Vulnerability: database.Vulnerability{
|
||||
Name: "CVE-YAY2",
|
||||
Namespace: ns,
|
||||
Severity: database.HighSeverity,
|
||||
},
|
||||
} {
|
||||
err := datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true)
|
||||
assert.Error(t, err)
|
||||
FixedInVersion: "2.2",
|
||||
}
|
||||
|
||||
// Insert a simple vulnerability and find it.
|
||||
v1meta := make(map[string]interface{})
|
||||
v1meta["TestInsertVulnerabilityMetadata1"] = "TestInsertVulnerabilityMetadataValue1"
|
||||
v1meta["TestInsertVulnerabilityMetadata2"] = struct {
|
||||
Test string
|
||||
}{
|
||||
Test: "TestInsertVulnerabilityMetadataValue1",
|
||||
if !assert.Nil(t, tx.InsertVulnerabilities([]database.VulnerabilityWithAffected{vuln, vuln2})) {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
v1 := database.Vulnerability{
|
||||
Name: "TestInsertVulnerability1",
|
||||
Namespace: n1,
|
||||
FixedIn: []database.FeatureVersion{f1, f3, f6, f7},
|
||||
Severity: database.LowSeverity,
|
||||
Description: "TestInsertVulnerabilityDescription1",
|
||||
Link: "TestInsertVulnerabilityLink1",
|
||||
Metadata: v1meta,
|
||||
}
|
||||
err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true)
|
||||
if assert.Nil(t, err) {
|
||||
v1f, err := datastore.FindVulnerability(n1.Name, v1.Name)
|
||||
if assert.Nil(t, err) {
|
||||
equalsVuln(t, &v1, &v1f)
|
||||
}
|
||||
}
|
||||
|
||||
// Update vulnerability.
|
||||
v1.Description = "TestInsertVulnerabilityLink2"
|
||||
v1.Link = "TestInsertVulnerabilityLink2"
|
||||
v1.Severity = database.HighSeverity
|
||||
// Update f3 in f4, add fixed in f5, add fixed in f6 which already exists,
|
||||
// removes fixed in f7 by adding f8 which is f7 but with MinVersion, and
|
||||
// add fixed by f5 a second time (duplicated).
|
||||
v1.FixedIn = []database.FeatureVersion{f4, f5, f6, f8, f5}
|
||||
|
||||
err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true)
|
||||
if assert.Nil(t, err) {
|
||||
v1f, err := datastore.FindVulnerability(n1.Name, v1.Name)
|
||||
if assert.Nil(t, err) {
|
||||
// Remove f8 from the struct for comparison as it was just here to cancel f7.
|
||||
// Remove one of the f5 too as it was twice in the struct but the database
|
||||
// implementation should have dedup'd it.
|
||||
v1.FixedIn = v1.FixedIn[:len(v1.FixedIn)-2]
|
||||
|
||||
// We already had f1 before the update.
|
||||
// Add it to the struct for comparison.
|
||||
v1.FixedIn = append(v1.FixedIn, f1)
|
||||
|
||||
equalsVuln(t, &v1, &v1f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func equalsVuln(t *testing.T, expected, actual *database.Vulnerability) {
|
||||
assert.Equal(t, expected.Name, actual.Name)
|
||||
assert.Equal(t, expected.Namespace.Name, actual.Namespace.Name)
|
||||
assert.Equal(t, expected.Description, actual.Description)
|
||||
assert.Equal(t, expected.Link, actual.Link)
|
||||
assert.Equal(t, expected.Severity, actual.Severity)
|
||||
assert.True(t, reflect.DeepEqual(castMetadata(expected.Metadata), actual.Metadata), "Got metadata %#v, expected %#v", actual.Metadata, castMetadata(expected.Metadata))
|
||||
|
||||
if assert.Len(t, actual.FixedIn, len(expected.FixedIn)) {
|
||||
for _, actualFeatureVersion := range actual.FixedIn {
|
||||
found := false
|
||||
for _, expectedFeatureVersion := range expected.FixedIn {
|
||||
if expectedFeatureVersion.Feature.Name == actualFeatureVersion.Feature.Name {
|
||||
found = true
|
||||
|
||||
assert.Equal(t, expected.Namespace.Name, actualFeatureVersion.Feature.Namespace.Name)
|
||||
assert.Equal(t, expectedFeatureVersion.Version, actualFeatureVersion.Version)
|
||||
r, err := tx.FindAffectedNamespacedFeatures([]database.NamespacedFeature{f})
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, r, 1)
|
||||
for _, anf := range r {
|
||||
if assert.True(t, anf.Valid) && assert.Len(t, anf.AffectedBy, 2) {
|
||||
for _, a := range anf.AffectedBy {
|
||||
if a.Name == "CVE-YAY" {
|
||||
assert.Equal(t, vulnFixed1, a)
|
||||
} else if a.Name == "CVE-YAY2" {
|
||||
assert.Equal(t, vulnFixed2, a)
|
||||
} else {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("unexpected package %s in %s", actualFeatureVersion.Feature.Name, expected.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindVulnerabilities(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "FindVulnerabilities", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
vuln, err := tx.FindVulnerabilities([]database.VulnerabilityID{
|
||||
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
|
||||
{Name: "CVE-NOPE", Namespace: "debian:7"},
|
||||
{Name: "CVE-NOT HERE"},
|
||||
})
|
||||
|
||||
ns := database.Namespace{
|
||||
Name: "debian:7",
|
||||
VersionFormat: "dpkg",
|
||||
}
|
||||
|
||||
expectedExisting := []database.VulnerabilityWithAffected{
|
||||
{
|
||||
Vulnerability: database.Vulnerability{
|
||||
Namespace: ns,
|
||||
Name: "CVE-OPENSSL-1-DEB7",
|
||||
Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
|
||||
Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7",
|
||||
Severity: database.HighSeverity,
|
||||
},
|
||||
Affected: []database.AffectedFeature{
|
||||
{
|
||||
FeatureName: "openssl",
|
||||
AffectedVersion: "2.0",
|
||||
FixedInVersion: "2.0",
|
||||
Namespace: ns,
|
||||
},
|
||||
{
|
||||
FeatureName: "libssl",
|
||||
AffectedVersion: "1.9-abc",
|
||||
FixedInVersion: "1.9-abc",
|
||||
Namespace: ns,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Vulnerability: database.Vulnerability{
|
||||
Namespace: ns,
|
||||
Name: "CVE-NOPE",
|
||||
Description: "A vulnerability affecting nothing",
|
||||
Severity: database.UnknownSeverity,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
expectedExistingMap := map[database.VulnerabilityID]database.VulnerabilityWithAffected{}
|
||||
for _, v := range expectedExisting {
|
||||
expectedExistingMap[database.VulnerabilityID{Name: v.Name, Namespace: v.Namespace.Name}] = v
|
||||
}
|
||||
|
||||
nonexisting := database.VulnerabilityWithAffected{
|
||||
Vulnerability: database.Vulnerability{Name: "CVE-NOT HERE"},
|
||||
}
|
||||
|
||||
if assert.Nil(t, err) {
|
||||
for _, v := range vuln {
|
||||
if v.Valid {
|
||||
key := database.VulnerabilityID{
|
||||
Name: v.Name,
|
||||
Namespace: v.Namespace.Name,
|
||||
}
|
||||
|
||||
expected, ok := expectedExistingMap[key]
|
||||
if assert.True(t, ok, "vulnerability not found: "+key.Name+":"+key.Namespace) {
|
||||
assertVulnerabilityWithAffectedEqual(t, expected, v.VulnerabilityWithAffected)
|
||||
}
|
||||
} else if !assert.Equal(t, nonexisting, v.VulnerabilityWithAffected) {
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// same vulnerability
|
||||
r, err := tx.FindVulnerabilities([]database.VulnerabilityID{
|
||||
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
|
||||
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
|
||||
})
|
||||
|
||||
if assert.Nil(t, err) {
|
||||
for _, vuln := range r {
|
||||
if assert.True(t, vuln.Valid) {
|
||||
expected, _ := expectedExistingMap[database.VulnerabilityID{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"}]
|
||||
assertVulnerabilityWithAffectedEqual(t, expected, vuln.VulnerabilityWithAffected)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStringComparison(t *testing.T) {
|
||||
cmp := compareStringLists([]string{"a", "b", "b", "a"}, []string{"a", "c"})
|
||||
assert.Len(t, cmp, 1)
|
||||
assert.NotContains(t, cmp, "a")
|
||||
assert.Contains(t, cmp, "b")
|
||||
func TestDeleteVulnerabilities(t *testing.T) {
|
||||
datastore, tx := openSessionForTest(t, "DeleteVulnerabilities", true)
|
||||
defer closeTest(t, datastore, tx)
|
||||
|
||||
cmp = compareStringListsInBoth([]string{"a", "a", "b", "c"}, []string{"a", "c", "c"})
|
||||
assert.Len(t, cmp, 2)
|
||||
assert.NotContains(t, cmp, "b")
|
||||
assert.Contains(t, cmp, "a")
|
||||
assert.Contains(t, cmp, "c")
|
||||
remove := []database.VulnerabilityID{}
|
||||
// empty case
|
||||
assert.Nil(t, tx.DeleteVulnerabilities(remove))
|
||||
// invalid case
|
||||
remove = append(remove, database.VulnerabilityID{})
|
||||
assert.NotNil(t, tx.DeleteVulnerabilities(remove))
|
||||
|
||||
// valid case
|
||||
validRemove := []database.VulnerabilityID{
|
||||
{Name: "CVE-OPENSSL-1-DEB7", Namespace: "debian:7"},
|
||||
{Name: "CVE-NOPE", Namespace: "debian:7"},
|
||||
}
|
||||
|
||||
assert.Nil(t, tx.DeleteVulnerabilities(validRemove))
|
||||
vuln, err := tx.FindVulnerabilities(validRemove)
|
||||
if assert.Nil(t, err) {
|
||||
for _, v := range vuln {
|
||||
assert.False(t, v.Valid)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindVulnerabilityIDs(t *testing.T) {
|
||||
store, tx := openSessionForTest(t, "FindVulnerabilityIDs", true)
|
||||
defer closeTest(t, store, tx)
|
||||
|
||||
ids, err := tx.findLatestDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-DELETED", Namespace: "debian:7"}})
|
||||
if assert.Nil(t, err) {
|
||||
if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 3, ids[0].Int64)) {
|
||||
assert.Fail(t, "")
|
||||
}
|
||||
}
|
||||
|
||||
ids, err = tx.findNotDeletedVulnerabilityIDs([]database.VulnerabilityID{{Name: "CVE-NOPE", Namespace: "debian:7"}})
|
||||
if assert.Nil(t, err) {
|
||||
if !(assert.Len(t, ids, 1) && assert.True(t, ids[0].Valid) && assert.Equal(t, 2, ids[0].Int64)) {
|
||||
assert.Fail(t, "")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func assertVulnerabilityWithAffectedEqual(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool {
|
||||
return assert.Equal(t, expected.Vulnerability, actual.Vulnerability) && assertAffectedFeaturesEqual(t, expected.Affected, actual.Affected)
|
||||
}
|
||||
|
||||
func assertAffectedFeaturesEqual(t *testing.T, expected []database.AffectedFeature, actual []database.AffectedFeature) bool {
|
||||
if assert.Len(t, actual, len(expected)) {
|
||||
has := map[database.AffectedFeature]bool{}
|
||||
for _, i := range expected {
|
||||
has[i] = false
|
||||
}
|
||||
for _, i := range actual {
|
||||
if visited, ok := has[i]; !ok {
|
||||
return false
|
||||
} else if visited {
|
||||
return false
|
||||
}
|
||||
has[i] = true
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user