database: add support mysql

Signed-off-by: Lei Jitang <leijitang@huawei.com>
This commit is contained in:
Lei Jitang 2016-03-22 04:30:43 -04:00
parent 2121e76f98
commit 3ba218dbef
30 changed files with 3941 additions and 98 deletions

View File

@ -20,12 +20,15 @@ import (
"math/rand"
"os"
"os/signal"
"strings"
"syscall"
"time"
"github.com/coreos/clair/api"
"github.com/coreos/clair/api/context"
"github.com/coreos/clair/config"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/mysql"
"github.com/coreos/clair/database/pgsql"
"github.com/coreos/clair/notifier"
"github.com/coreos/clair/updater"
@ -40,14 +43,22 @@ var log = capnslog.NewPackageLogger("github.com/coreos/clair", "main")
func Boot(config *config.Config) {
rand.Seed(time.Now().UnixNano())
st := utils.NewStopper()
var (
db database.Datastore
err error
)
// Open database
db, err := pgsql.Open(config.Database)
if strings.HasPrefix(config.Database.Source, "postgres") {
db, err = pgsql.Open(config.Database)
} else if strings.HasPrefix(config.Database.Source, "mysql") {
db, err = mysql.Open(config.Database)
} else {
log.Fatal("database source '%s' does not support, support 'postgres' and 'mysql' for now", config.Database.Source)
}
if err != nil {
log.Fatal(err)
}
defer db.Close()
// Start notifier
st.Begin()
go notifier.Run(config.Notifier, db, st)

View File

@ -15,9 +15,10 @@
# The values specified here are the default values that Clair uses if no configuration file is specified or if the keys are not defined.
clair:
database:
# PostgreSQL Connection string
# http://www.postgresql.org/docs/9.4/static/libpq-connect.html
source:
# database Connection string
# see http://www.postgresql.org/docs/9.4/static/libpq-connect.html for PostgreSQL
# see https://github.com/go-sql-driver/mysql#dsn-data-source-name for MySQL. e.g. "mysql://root@tcp(127.0.0.1:3306)/"
source: mysql://root@tcp(172.17.0.3:3306)/
# Number of elements kept in the cache
# Values unlikely to change (e.g. namespaces) are cached in order to save prevent needless roundtrips to the database.

View File

@ -0,0 +1,158 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import (
"fmt"
"math/rand"
"runtime"
"strconv"
"sync"
"testing"
"time"
"github.com/coreos/clair/database"
"github.com/coreos/clair/utils"
"github.com/coreos/clair/utils/types"
"github.com/pborman/uuid"
"github.com/stretchr/testify/assert"
)
const (
numVulnerabilities = 100
numFeatureVersions = 100
)
func TestRaceAffects(t *testing.T) {
datastore, err := OpenForTest("RaceAffects", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Insert the Feature on which we'll work.
feature := database.Feature{
Namespace: database.Namespace{Name: "TestRaceAffectsFeatureNamespace1"},
Name: "TestRaceAffecturesFeature1",
}
_, err = datastore.insertFeature(feature)
if err != nil {
t.Error(err)
return
}
// Initialize random generator and enforce max procs.
rand.Seed(time.Now().UnixNano())
runtime.GOMAXPROCS(runtime.NumCPU())
// Generate FeatureVersions.
featureVersions := make([]database.FeatureVersion, numFeatureVersions)
for i := 0; i < numFeatureVersions; i++ {
version := rand.Intn(numFeatureVersions)
featureVersions[i] = database.FeatureVersion{
Feature: feature,
Version: types.NewVersionUnsafe(strconv.Itoa(version)),
}
}
// Generate vulnerabilities.
// They are mapped by fixed version, which will make verification really easy afterwards.
vulnerabilities := make(map[int][]database.Vulnerability)
for i := 0; i < numVulnerabilities; i++ {
version := rand.Intn(numFeatureVersions) + 1
// if _, ok := vulnerabilities[version]; !ok {
// vulnerabilities[version] = make([]database.Vulnerability)
// }
vulnerability := database.Vulnerability{
Name: uuid.New(),
Namespace: feature.Namespace,
FixedIn: []database.FeatureVersion{
{
Feature: feature,
Version: types.NewVersionUnsafe(strconv.Itoa(version)),
},
},
Severity: types.Unknown,
}
vulnerabilities[version] = append(vulnerabilities[version], vulnerability)
}
// Insert featureversions and vulnerabilities in parallel.
var wg sync.WaitGroup
wg.Add(2)
go func() {
defer wg.Done()
for _, vulnerabilitiesM := range vulnerabilities {
for _, vulnerability := range vulnerabilitiesM {
err = datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true)
assert.Nil(t, err)
}
}
fmt.Println("finished to insert vulnerabilities")
}()
go func() {
defer wg.Done()
for i := 0; i < len(featureVersions); i++ {
featureVersions[i].ID, err = datastore.insertFeatureVersion(featureVersions[i])
assert.Nil(t, err)
}
fmt.Println("finished to insert featureVersions")
}()
wg.Wait()
// Verify consistency now.
var actualAffectedNames []string
var expectedAffectedNames []string
for _, featureVersion := range featureVersions {
featureVersionVersion, _ := strconv.Atoi(featureVersion.Version.String())
// Get actual affects.
rows, err := datastore.Query(searchComplexTestFeatureVersionAffects,
featureVersion.ID)
assert.Nil(t, err)
defer rows.Close()
var vulnName string
for rows.Next() {
err = rows.Scan(&vulnName)
if !assert.Nil(t, err) {
continue
}
actualAffectedNames = append(actualAffectedNames, vulnName)
}
if assert.Nil(t, rows.Err()) {
rows.Close()
}
// Get expected affects.
for i := numVulnerabilities; i > featureVersionVersion; i-- {
for _, vulnerability := range vulnerabilities[i] {
expectedAffectedNames = append(expectedAffectedNames, vulnerability.Name)
}
}
assert.Len(t, utils.CompareStringLists(expectedAffectedNames, actualAffectedNames), 0)
assert.Len(t, utils.CompareStringLists(actualAffectedNames, expectedAffectedNames), 0)
}
}

305
database/mysql/feature.go Normal file
View File

@ -0,0 +1,305 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import (
"database/sql"
"time"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types"
)
func (mySQL *mySQL) insertFeatureiWithTransaction(queryer Queryer, feature database.Feature) (int, error) {
if feature.Name == "" {
return 0, cerrors.NewBadRequestError("could not find/insert invalid Feature")
}
// Do cache lookup.
if mySQL.cache != nil {
database.PromCacheQueriesTotal.WithLabelValues("feature").Inc()
id, found := mySQL.cache.Get("feature:" + feature.Namespace.Name + ":" + feature.Name)
if found {
database.PromCacheHitsTotal.WithLabelValues("feature").Inc()
return id.(int), nil
}
}
// We do `defer database.ObserveQueryTime` here because we don't want to observe cached features.
defer database.ObserveQueryTime("insertFeature", "all", time.Now())
// Find or create Namespace.
namespaceID, err := mySQL.insertNamespaceWithTransaction(queryer, feature.Namespace)
if err != nil {
return 0, err
}
// Find or create Feature.
var id int
res, err := queryer.Exec(insertFeature, feature.Name, namespaceID, feature.Name, namespaceID)
if err != nil {
return 0, handleError("insertFeatureVersion", err)
}
tmpid, err := res.LastInsertId()
if err != nil {
return 0, handleError("insertFeatureVersion", err)
}
id = int(tmpid)
// if id==0 means the feature already exists, use query to get id
if id == 0 {
err = queryer.QueryRow(soiFeature, feature.Name, namespaceID).Scan(&id)
if err != nil {
return 0, handleError("soiFeature", err)
}
}
if mySQL.cache != nil {
mySQL.cache.Add("feature:"+feature.Namespace.Name+":"+feature.Name, id)
}
return id, nil
}
func (mySQL *mySQL) insertFeature(feature database.Feature) (int, error) {
if feature.Name == "" {
return 0, cerrors.NewBadRequestError("could not find/insert invalid Feature")
}
// Do cache lookup.
if mySQL.cache != nil {
database.PromCacheQueriesTotal.WithLabelValues("feature").Inc()
id, found := mySQL.cache.Get("feature:" + feature.Namespace.Name + ":" + feature.Name)
if found {
database.PromCacheHitsTotal.WithLabelValues("feature").Inc()
return id.(int), nil
}
}
// We do `defer database.ObserveQueryTime` here because we don't want to observe cached features.
defer database.ObserveQueryTime("insertFeature", "all", time.Now())
// Find or create Namespace.
namespaceID, err := mySQL.insertNamespace(feature.Namespace)
if err != nil {
return 0, err
}
// Find or create Feature.
var id int
res, err := mySQL.Exec(insertFeature, feature.Name, namespaceID, feature.Name, namespaceID)
if err != nil {
return 0, handleError("insertFeatureVersion", err)
}
tmpid, err := res.LastInsertId()
if err != nil {
return 0, handleError("insertFeatureVersion", err)
}
id = int(tmpid)
// if id==0 means the feature already exists, use query to get id
if id == 0 {
err = mySQL.QueryRow(soiFeature, feature.Name, namespaceID).Scan(&id)
if err != nil {
return 0, handleError("soiFeature", err)
}
}
if mySQL.cache != nil {
mySQL.cache.Add("feature:"+feature.Namespace.Name+":"+feature.Name, id)
}
return id, nil
}
func (mySQL *mySQL) insertFeatureVersion(featureVersion database.FeatureVersion) (id int, err error) {
if featureVersion.Version.String() == "" {
return 0, cerrors.NewBadRequestError("could not find/insert invalid FeatureVersion")
}
// Do cache lookup.
cacheIndex := "featureversion:" + featureVersion.Feature.Namespace.Name + ":" + featureVersion.Feature.Name + ":" + featureVersion.Version.String()
if mySQL.cache != nil {
database.PromCacheQueriesTotal.WithLabelValues("featureversion").Inc()
id, found := mySQL.cache.Get(cacheIndex)
if found {
database.PromCacheHitsTotal.WithLabelValues("featureversion").Inc()
return id.(int), nil
}
}
// We do `defer database.ObserveQueryTime` here because we don't want to observe cached featureversions.
defer database.ObserveQueryTime("insertFeatureVersion", "all", time.Now())
// Find or create Feature first.
t := time.Now()
featureID, err := mySQL.insertFeature(featureVersion.Feature)
database.ObserveQueryTime("insertFeatureVersion", "insertFeature", t)
if err != nil {
return 0, err
}
featureVersion.Feature.ID = featureID
// Try to find the FeatureVersion.
//
// In a populated database, the likelihood of the FeatureVersion already being there is high.
// If we can find it here, we then avoid using a transaction and locking the database.
err = mySQL.QueryRow(searchFeatureVersion, featureID, &featureVersion.Version).
Scan(&featureVersion.ID)
if err != nil && err != sql.ErrNoRows {
return 0, handleError("searchFeatureVersion", err)
}
if err == nil {
if mySQL.cache != nil {
mySQL.cache.Add(cacheIndex, featureVersion.ID)
}
return featureVersion.ID, nil
}
// Begin transaction.
tx, err := mySQL.Begin()
if err != nil {
tx.Rollback()
return 0, handleError("insertFeatureVersion.Begin()", err)
}
// Lock Vulnerability_Affects_FeatureVersion exclusively.
// We want to prevent InsertVulnerability to modify it.
database.PromConcurrentLockVAFV.Inc()
defer database.PromConcurrentLockVAFV.Dec()
t = time.Now()
var tmp int64
err = tx.QueryRow(lockVulnerabilityAffects).Scan(&tmp)
database.ObserveQueryTime("insertFeatureVersion", "lock", t)
if err != nil {
tx.Rollback()
return 0, handleError("insertFeatureVersion.lockVulnerabilityAffects", err)
}
// Find or create FeatureVersion.
var newOrExisting string
t = time.Now()
_, err = tx.Exec(insertFeatureVersion, featureID, &featureVersion.Version, featureID, &featureVersion.Version)
database.ObserveQueryTime("insertFeatureVersion", "soiFeatureVersion", t)
if err != nil {
tx.Rollback()
return 0, handleError("insertFeatureVersion", err)
}
t = time.Now()
err = tx.QueryRow(soiFeatureVersion, featureID, &featureVersion.Version).
Scan(&newOrExisting, &featureVersion.ID)
database.ObserveQueryTime("insertFeatureVersion", "soiFeatureVersion", t)
if err != nil {
tx.Rollback()
return 0, handleError("soiFeatureVersion", err)
}
if newOrExisting == "exi" {
// That featureVersion already exists, return its id.
tx.Commit()
if mySQL.cache != nil {
mySQL.cache.Add(cacheIndex, featureVersion.ID)
}
return featureVersion.ID, nil
}
// Link the new FeatureVersion with every vulnerabilities that affect it, by inserting in
// Vulnerability_Affects_FeatureVersion.
t = time.Now()
err = linkFeatureVersionToVulnerabilities(tx, featureVersion)
database.ObserveQueryTime("insertFeatureVersion", "linkFeatureVersionToVulnerabilities", t)
if err != nil {
tx.Rollback()
return 0, err
}
// Commit transaction.
err = tx.Commit()
if err != nil {
return 0, handleError("insertFeatureVersion.Commit()", err)
}
if mySQL.cache != nil {
mySQL.cache.Add(cacheIndex, featureVersion.ID)
}
return featureVersion.ID, nil
}
// TODO(Quentin-M): Batch me
func (mySQL *mySQL) insertFeatureVersions(featureVersions []database.FeatureVersion) ([]int, error) {
IDs := make([]int, 0, len(featureVersions))
for i := 0; i < len(featureVersions); i++ {
id, err := mySQL.insertFeatureVersion(featureVersions[i])
if err != nil {
return IDs, err
}
IDs = append(IDs, id)
}
return IDs, nil
}
type vulnerabilityAffectsFeatureVersion struct {
vulnerabilityID int
fixedInID int
fixedInVersion types.Version
}
func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.FeatureVersion) error {
// Select every vulnerability and the fixed version that affect this Feature.
// TODO(Quentin-M): LIMIT
rows, err := tx.Query(searchVulnerabilityFixedInFeature, featureVersion.Feature.ID)
if err != nil {
return handleError("searchVulnerabilityFixedInFeature", err)
}
defer rows.Close()
var affects []vulnerabilityAffectsFeatureVersion
for rows.Next() {
var affect vulnerabilityAffectsFeatureVersion
err := rows.Scan(&affect.fixedInID, &affect.vulnerabilityID, &affect.fixedInVersion)
if err != nil {
return handleError("searchVulnerabilityFixedInFeature.Scan()", err)
}
if featureVersion.Version.Compare(affect.fixedInVersion) < 0 {
// The version of the FeatureVersion we are inserting is lower than the fixed version on this
// Vulnerability, thus, this FeatureVersion is affected by it.
affects = append(affects, affect)
}
}
if err = rows.Err(); err != nil {
return handleError("searchVulnerabilityFixedInFeature.Rows()", err)
}
rows.Close()
// Insert into Vulnerability_Affects_FeatureVersion.
for _, affect := range affects {
// TODO(Quentin-M): Batch me.
_, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, affect.vulnerabilityID,
featureVersion.ID, affect.fixedInID, affect.vulnerabilityID, featureVersion.ID, affect.fixedInID)
if err != nil {
return handleError("insertVulnerabilityAffectsFeatureVersion", err)
}
}
return nil
}

View File

@ -0,0 +1,102 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import (
"testing"
"github.com/coreos/clair/database"
"github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
)
func TestInsertFeature(t *testing.T) {
datastore, err := OpenForTest("InsertFeature", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Invalid Feature.
id0, err := datastore.insertFeature(database.Feature{})
assert.NotNil(t, err)
assert.Zero(t, id0)
id0, err = datastore.insertFeature(database.Feature{
Namespace: database.Namespace{},
Name: "TestInsertFeature0",
})
assert.NotNil(t, err)
assert.Zero(t, id0)
// Insert Feature and ensure we can find it.
feature := database.Feature{
Namespace: database.Namespace{Name: "TestInsertFeatureNamespace1"},
Name: "TestInsertFeature1",
}
id1, err := datastore.insertFeature(feature)
assert.Nil(t, err)
id2, err := datastore.insertFeature(feature)
assert.Nil(t, err)
assert.Equal(t, id1, id2)
// Insert invalid FeatureVersion.
for _, invalidFeatureVersion := range []database.FeatureVersion{
{
Feature: database.Feature{},
Version: types.NewVersionUnsafe("1.0"),
},
{
Feature: database.Feature{
Namespace: database.Namespace{},
Name: "TestInsertFeature2",
},
Version: types.NewVersionUnsafe("1.0"),
},
{
Feature: database.Feature{
Namespace: database.Namespace{Name: "TestInsertFeatureNamespace2"},
Name: "TestInsertFeature2",
},
Version: types.NewVersionUnsafe(""),
},
{
Feature: database.Feature{
Namespace: database.Namespace{Name: "TestInsertFeatureNamespace2"},
Name: "TestInsertFeature2",
},
Version: types.NewVersionUnsafe("bad version"),
},
} {
id3, err := datastore.insertFeatureVersion(invalidFeatureVersion)
assert.Error(t, err)
assert.Zero(t, id3)
}
// Insert FeatureVersion and ensure we can find it.
featureVersion := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{Name: "TestInsertFeatureNamespace1"},
Name: "TestInsertFeature1",
},
Version: types.NewVersionUnsafe("2:3.0-imba"),
}
id4, err := datastore.insertFeatureVersion(featureVersion)
assert.Nil(t, err)
id5, err := datastore.insertFeatureVersion(featureVersion)
assert.Nil(t, err)
assert.Equal(t, id4, id5)
}

View File

@ -0,0 +1,83 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import (
"database/sql"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
"time"
)
// InsertKeyValue stores (or updates) a single key / value tuple.
func (mySQL *mySQL) InsertKeyValue(key, value string) (err error) {
if key == "" || value == "" {
log.Warning("could not insert a flag which has an empty name or value")
return cerrors.NewBadRequestError("could not insert a flag which has an empty name or value")
}
defer database.ObserveQueryTime("InsertKeyValue", "all", time.Now())
// Upsert.
//
// Note: UPSERT works only on >= PostgreSQL 9.5 which is not yet supported by AWS RDS.
// The best solution is currently the use of http://dba.stackexchange.com/a/13477
// but the key/value storage doesn't need to be super-efficient and super-safe at the
// moment so we can just use a client-side solution with transactions, based on
// http://postgresql.org/docs/current/static/plpgsql-control-structures.html.
// TODO(Quentin-M): Enable Upsert as soon as 9.5 is stable.
for {
// First, try to update.
r, err := mySQL.Exec(updateKeyValue, value, key)
if err != nil {
return handleError("updateKeyValue", err)
}
if n, _ := r.RowsAffected(); n > 0 {
// Updated successfully.
return nil
}
// Try to insert the key.
// If someone else inserts the same key concurrently, we could get a unique-key violation error.
_, err = mySQL.Exec(insertKeyValue, key, value)
if err != nil {
if isErrUniqueViolation(err) {
// Got unique constraint violation, retry.
continue
}
return handleError("insertKeyValue", err)
}
return nil
}
}
// GetValue reads a single key / value tuple and returns an empty string if the key doesn't exist.
func (mySQL *mySQL) GetKeyValue(key string) (string, error) {
defer database.ObserveQueryTime("GetKeyValue", "all", time.Now())
var value string
err := mySQL.QueryRow(searchKeyValue, key).Scan(&value)
if err == sql.ErrNoRows {
return "", nil
}
if err != nil {
return "", handleError("searchKeyValue", err)
}
return value, nil
}

View File

@ -0,0 +1,52 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestKeyValue(t *testing.T) {
datastore, err := OpenForTest("KeyValue", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Get non-existing key/value
f, err := datastore.GetKeyValue("test")
assert.Nil(t, err)
assert.Empty(t, "", f)
// Try to insert invalid key/value.
assert.Error(t, datastore.InsertKeyValue("test", ""))
assert.Error(t, datastore.InsertKeyValue("", "test"))
assert.Error(t, datastore.InsertKeyValue("", ""))
// Insert and verify.
assert.Nil(t, datastore.InsertKeyValue("test", "test1"))
f, err = datastore.GetKeyValue("test")
assert.Nil(t, err)
assert.Equal(t, "test1", f)
// Update and verify.
assert.Nil(t, datastore.InsertKeyValue("test", "test2"))
f, err = datastore.GetKeyValue("test")
assert.Nil(t, err)
assert.Equal(t, "test2", f)
}

414
database/mysql/layer.go Normal file
View File

@ -0,0 +1,414 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import (
"database/sql"
"fmt"
"time"
"github.com/coreos/clair/database"
"github.com/coreos/clair/utils"
cerrors "github.com/coreos/clair/utils/errors"
"github.com/guregu/null/zero"
)
func (mySQL *mySQL) FindLayer(name string, withFeatures, withVulnerabilities bool) (database.Layer, error) {
subquery := "all"
if withFeatures {
subquery += "/features"
} else if withVulnerabilities {
subquery += "/features+vulnerabilities"
}
defer database.ObserveQueryTime("FindLayer", subquery, time.Now())
// Find the layer
var layer database.Layer
var parentID zero.Int
var parentName zero.String
var namespaceID zero.Int
var namespaceName sql.NullString
t := time.Now()
err := mySQL.QueryRow(searchLayer, name).
Scan(&layer.ID, &layer.Name, &layer.EngineVersion, &parentID, &parentName, &namespaceID,
&namespaceName)
database.ObserveQueryTime("FindLayer", "searchLayer", t)
if err != nil {
return layer, handleError("searchLayer", err)
}
if !parentID.IsZero() {
layer.Parent = &database.Layer{
Model: database.Model{ID: int(parentID.Int64)},
Name: parentName.String,
}
}
if !namespaceID.IsZero() {
layer.Namespace = &database.Namespace{
Model: database.Model{ID: int(namespaceID.Int64)},
Name: namespaceName.String,
}
}
// Find its features
if withFeatures || withVulnerabilities {
// Create a transaction to disable hash/merge joins as our experiments have shown that
// PostgreSQL 9.4 makes bad planning decisions about:
// - joining the layer tree to feature versions and feature
// - joining the feature versions to affected/fixed feature version and vulnerabilities
// It would for instance do a merge join between affected feature versions (300 rows, estimated
// 3000 rows) and fixed in feature version (100k rows). In this case, it is much more
// preferred to use a nested loop.
tx, err := mySQL.Begin()
if err != nil {
return layer, handleError("FindLayer.Begin()", err)
}
defer tx.Commit()
t = time.Now()
featureVersions, err := getLayerFeatureVersions(tx, layer.ID)
database.ObserveQueryTime("FindLayer", "getLayerFeatureVersions", t)
if err != nil {
return layer, err
}
layer.Features = featureVersions
if withVulnerabilities {
// Load the vulnerabilities that affect the FeatureVersions.
t = time.Now()
err := loadAffectedBy(tx, layer.Features)
database.ObserveQueryTime("FindLayer", "loadAffectedBy", t)
if err != nil {
return layer, err
}
}
}
return layer, nil
}
// getLayerFeatureVersions returns list of database.FeatureVersion that a database.Layer has.
func getLayerFeatureVersions(tx *sql.Tx, layerID int) ([]database.FeatureVersion, error) {
var featureVersions []database.FeatureVersion
// Query.
rows, err := tx.Query(searchLayerFeatureVersion, layerID)
if err != nil {
return featureVersions, handleError("searchLayerFeatureVersion", err)
}
defer rows.Close()
// Scan query.
var modification string
mapFeatureVersions := make(map[int]database.FeatureVersion)
for rows.Next() {
var featureVersion database.FeatureVersion
err = rows.Scan(&featureVersion.ID, &modification, &featureVersion.Feature.Namespace.ID,
&featureVersion.Feature.Namespace.Name, &featureVersion.Feature.ID,
&featureVersion.Feature.Name, &featureVersion.ID, &featureVersion.Version,
&featureVersion.AddedBy.ID, &featureVersion.AddedBy.Name)
if err != nil {
return featureVersions, handleError("searchLayerFeatureVersion.Scan()", err)
}
// Do transitive closure.
switch modification {
case "add":
mapFeatureVersions[featureVersion.ID] = featureVersion
case "del":
delete(mapFeatureVersions, featureVersion.ID)
default:
log.Warningf("unknown Layer_diff_FeatureVersion's modification: %s", modification)
return featureVersions, database.ErrInconsistent
}
}
if err = rows.Err(); err != nil {
return featureVersions, handleError("searchLayerFeatureVersion.Rows()", err)
}
// Build result by converting our map to a slice.
for _, featureVersion := range mapFeatureVersions {
featureVersions = append(featureVersions, featureVersion)
}
return featureVersions, nil
}
// loadAffectedBy returns the list of database.Vulnerability that affect the given
// FeatureVersion.
func loadAffectedBy(tx *sql.Tx, featureVersions []database.FeatureVersion) error {
if len(featureVersions) == 0 {
return nil
}
// Construct list of FeatureVersion IDs, we will do a single query
featureVersionIDs := make([]int, 0, len(featureVersions))
for i := 0; i < len(featureVersions); i++ {
featureVersionIDs = append(featureVersionIDs, featureVersions[i].ID)
}
searchFeatureVersionVulnerabilityQuery := fmt.Sprintf(searchFeatureVersionVulnerability, buildInputArray(featureVersionIDs))
rows, err := tx.Query(searchFeatureVersionVulnerabilityQuery)
if err != nil && err != sql.ErrNoRows {
return handleError("searchFeatureVersionVulnerability", err)
}
defer rows.Close()
vulnerabilities := make(map[int][]database.Vulnerability, len(featureVersions))
var featureversionID int
for rows.Next() {
var vulnerability database.Vulnerability
err := rows.Scan(&featureversionID, &vulnerability.ID, &vulnerability.Name,
&vulnerability.Description, &vulnerability.Link, &vulnerability.Severity,
&vulnerability.Metadata, &vulnerability.Namespace.Name, &vulnerability.FixedBy)
if err != nil {
return handleError("searchFeatureVersionVulnerability.Scan()", err)
}
vulnerabilities[featureversionID] = append(vulnerabilities[featureversionID], vulnerability)
}
if err = rows.Err(); err != nil {
return handleError("searchFeatureVersionVulnerability.Rows()", err)
}
// Assign vulnerabilities to every FeatureVersions
for i := 0; i < len(featureVersions); i++ {
featureVersions[i].AffectedBy = vulnerabilities[featureVersions[i].ID]
}
return nil
}
// Internally, only Feature additions/removals are stored for each layer. If a layer has a parent,
// the Feature list will be compared to the parent's Feature list and the difference will be stored.
// Note that when the Namespace of a layer differs from its parent, it is expected that several
// Feature that were already included a parent will have their Namespace updated as well
// (happens when Feature detectors relies on the detected layer Namespace). However, if the listed
// Feature has the same Name/Version as its parent, InsertLayer considers that the Feature hasn't
// been modified.
func (mySQL *mySQL) InsertLayer(layer database.Layer) error {
tf := time.Now()
// Verify parameters
if layer.Name == "" {
log.Warning("could not insert a layer which has an empty Name")
return cerrors.NewBadRequestError("could not insert a layer which has an empty Name")
}
// Get a potentially existing layer.
existingLayer, err := mySQL.FindLayer(layer.Name, true, false)
if err != nil && err != cerrors.ErrNotFound {
return err
} else if err == nil {
if existingLayer.EngineVersion >= layer.EngineVersion {
// The layer exists and has an equal or higher engine version, do nothing.
return nil
}
layer.ID = existingLayer.ID
}
// We do `defer database.ObserveQueryTime` here because we don't want to observe existing layers.
defer database.ObserveQueryTime("InsertLayer", "all", tf)
// Get parent ID.
var parentID zero.Int
if layer.Parent != nil {
if layer.Parent.ID == 0 {
log.Warning("Parent is expected to be retrieved from database when inserting a layer.")
return cerrors.NewBadRequestError("Parent is expected to be retrieved from database when inserting a layer.")
}
parentID = zero.IntFrom(int64(layer.Parent.ID))
}
// Find or insert namespace if provided.
var namespaceID zero.Int
if layer.Namespace != nil {
n, err := mySQL.insertNamespace(*layer.Namespace)
if err != nil {
return err
}
namespaceID = zero.IntFrom(int64(n))
} else if layer.Namespace == nil && layer.Parent != nil {
// Import the Namespace from the parent if it has one and this layer doesn't specify one.
if layer.Parent.Namespace != nil {
namespaceID = zero.IntFrom(int64(layer.Parent.Namespace.ID))
}
}
// Begin transaction.
tx, err := mySQL.Begin()
if err != nil {
tx.Rollback()
return handleError("InsertLayer.Begin()", err)
}
if layer.ID == 0 {
// Insert a new layer.
res, err := tx.Exec(insertLayer, layer.Name, layer.EngineVersion, parentID, namespaceID)
if err != nil {
tx.Rollback()
if isErrUniqueViolation(err) {
// Ignore this error, another process collided.
log.Debug("Attempted to insert duplicate layer.")
return nil
}
return handleError("insertLayer", err)
}
tmpid, err := res.LastInsertId()
if err != nil {
tx.Rollback()
return handleError("inserLayer", err)
}
layer.ID = int(tmpid)
// if LastInsertId == 0, it means the layer already exists
if layer.ID == 0 {
err = tx.QueryRow(getLayerId, layer.Name).Scan(&layer.ID)
if err != nil {
tx.Rollback()
if isErrUniqueViolation(err) {
// Ignore this error, another process collided.
return nil
}
return handleError("getLayerId", err)
}
}
} else {
// Update an existing layer.
_, err = tx.Exec(updateLayer, layer.EngineVersion, namespaceID, layer.ID)
if err != nil {
tx.Rollback()
return handleError("updateLayer", err)
}
// Remove all existing Layer_diff_FeatureVersion.
_, err = tx.Exec(removeLayerDiffFeatureVersion, layer.ID)
if err != nil {
tx.Rollback()
return handleError("removeLayerDiffFeatureVersion", err)
}
}
// Update Layer_diff_FeatureVersion now.
err = mySQL.updateDiffFeatureVersions(tx, &layer, &existingLayer)
if err != nil {
tx.Rollback()
return err
}
// Commit transaction.
err = tx.Commit()
if err != nil {
tx.Rollback()
return handleError("InsertLayer.Commit()", err)
}
return nil
}
func (mySQL *mySQL) updateDiffFeatureVersions(tx *sql.Tx, layer, existingLayer *database.Layer) error {
// add and del are the FeatureVersion diff we should insert.
var add []database.FeatureVersion
var del []database.FeatureVersion
if layer.Parent == nil {
// There is no parent, every Features are added.
add = append(add, layer.Features...)
} else if layer.Parent != nil {
// There is a parent, we need to diff the Features with it.
// Build name:version structures.
layerFeaturesMapNV, layerFeaturesNV := createNV(layer.Features)
parentLayerFeaturesMapNV, parentLayerFeaturesNV := createNV(layer.Parent.Features)
// Calculate the added and deleted FeatureVersions name:version.
addNV := utils.CompareStringLists(layerFeaturesNV, parentLayerFeaturesNV)
delNV := utils.CompareStringLists(parentLayerFeaturesNV, layerFeaturesNV)
// Fill the structures containing the added and deleted FeatureVersions
for _, nv := range addNV {
add = append(add, *layerFeaturesMapNV[nv])
}
for _, nv := range delNV {
del = append(del, *parentLayerFeaturesMapNV[nv])
}
}
// Insert FeatureVersions in the database.
addIDs, err := mySQL.insertFeatureVersions(add)
if err != nil {
return err
}
delIDs, err := mySQL.insertFeatureVersions(del)
if err != nil {
return err
}
// Insert diff in the database.
if len(addIDs) > 0 {
insertLayerDiffFeatureVersionStmt := fmt.Sprintf(insertLayerDiffFeatureVersion, buildInputArray(addIDs))
_, err = tx.Exec(insertLayerDiffFeatureVersionStmt, layer.ID, "add")
if err != nil {
return handleError("insertLayerDiffFeatureVersion.Add", err)
}
}
if len(delIDs) > 0 {
insertLayerDiffFeatureVersionStmt := fmt.Sprintf(insertLayerDiffFeatureVersion, buildInputArray(delIDs))
_, err = tx.Exec(insertLayerDiffFeatureVersionStmt, layer.ID, "del")
if err != nil {
return handleError("insertLayerDiffFeatureVersion.Del", err)
}
}
return nil
}
func createNV(features []database.FeatureVersion) (map[string]*database.FeatureVersion, []string) {
mapNV := make(map[string]*database.FeatureVersion, 0)
sliceNV := make([]string, 0, len(features))
for i := 0; i < len(features); i++ {
featureVersion := &features[i]
nv := featureVersion.Feature.Name + ":" + featureVersion.Version.String()
mapNV[nv] = featureVersion
sliceNV = append(sliceNV, nv)
}
return mapNV, sliceNV
}
func (mySQL *mySQL) DeleteLayer(name string) error {
defer database.ObserveQueryTime("DeleteLayer", "all", time.Now())
result, err := mySQL.Exec(removeLayer, name)
if err != nil {
return handleError("removeLayer", err)
}
affected, err := result.RowsAffected()
if err != nil {
return handleError("removeLayer.RowsAffected()", err)
}
if affected <= 0 {
return cerrors.ErrNotFound
}
return nil
}

View File

@ -0,0 +1,350 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import (
"fmt"
"testing"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
)
func TestFindLayer(t *testing.T) {
datastore, err := OpenForTest("FindLayer", true)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Layer-0: no parent, no namespace, no feature, no vulnerability
layer, err := datastore.FindLayer("layer-0", false, false)
if assert.Nil(t, err) && assert.NotNil(t, layer) {
assert.Equal(t, "layer-0", layer.Name)
assert.Nil(t, layer.Namespace)
assert.Nil(t, layer.Parent)
assert.Equal(t, 1, layer.EngineVersion)
assert.Len(t, layer.Features, 0)
}
layer, err = datastore.FindLayer("layer-0", true, false)
if assert.Nil(t, err) && assert.NotNil(t, layer) {
assert.Len(t, layer.Features, 0)
}
// Layer-1: one parent, adds two features, one vulnerability
layer, err = datastore.FindLayer("layer-1", false, false)
if assert.Nil(t, err) && assert.NotNil(t, layer) {
assert.Equal(t, layer.Name, "layer-1")
assert.Equal(t, "debian:7", layer.Namespace.Name)
if assert.NotNil(t, layer.Parent) {
assert.Equal(t, "layer-0", layer.Parent.Name)
}
assert.Equal(t, 1, layer.EngineVersion)
assert.Len(t, layer.Features, 0)
}
layer, err = datastore.FindLayer("layer-1", true, false)
if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) {
for _, featureVersion := range layer.Features {
assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name)
switch featureVersion.Feature.Name {
case "wechat":
assert.Equal(t, types.NewVersionUnsafe("0.5"), featureVersion.Version)
case "openssl":
assert.Equal(t, types.NewVersionUnsafe("1.0"), featureVersion.Version)
default:
t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name)
}
}
}
layer, err = datastore.FindLayer("layer-1", true, true)
if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) {
for _, featureVersion := range layer.Features {
assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name)
switch featureVersion.Feature.Name {
case "wechat":
assert.Equal(t, types.NewVersionUnsafe("0.5"), featureVersion.Version)
case "openssl":
assert.Equal(t, types.NewVersionUnsafe("1.0"), featureVersion.Version)
if assert.Len(t, featureVersion.AffectedBy, 1) {
assert.Equal(t, "debian:7", featureVersion.AffectedBy[0].Namespace.Name)
assert.Equal(t, "CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Name)
assert.Equal(t, types.High, featureVersion.AffectedBy[0].Severity)
assert.Equal(t, "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", featureVersion.AffectedBy[0].Description)
assert.Equal(t, "http://google.com/#q=CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Link)
assert.Equal(t, types.NewVersionUnsafe("2.0"), featureVersion.AffectedBy[0].FixedBy)
}
default:
t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name)
}
}
}
}
func TestInsertLayer(t *testing.T) {
datastore, err := OpenForTest("InsertLayer", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Insert invalid layer.
testInsertLayerInvalid(t, datastore)
// Insert a layer tree.
testInsertLayerTree(t, datastore)
// Update layer.
testInsertLayerUpdate(t, datastore)
// Delete layer.
testInsertLayerDelete(t, datastore)
}
func testInsertLayerInvalid(t *testing.T, datastore database.Datastore) {
invalidLayers := []database.Layer{
{},
{Name: "layer0", Parent: &database.Layer{}},
{Name: "layer0", Parent: &database.Layer{Name: "UnknownLayer"}},
}
for _, invalidLayer := range invalidLayers {
err := datastore.InsertLayer(invalidLayer)
assert.Error(t, err)
}
}
func testInsertLayerTree(t *testing.T, datastore database.Datastore) {
f1 := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{Name: "TestInsertLayerNamespace2"},
Name: "TestInsertLayerFeature1",
},
Version: types.NewVersionUnsafe("1.0"),
}
f2 := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{Name: "TestInsertLayerNamespace2"},
Name: "TestInsertLayerFeature2",
},
Version: types.NewVersionUnsafe("0.34"),
}
f3 := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{Name: "TestInsertLayerNamespace2"},
Name: "TestInsertLayerFeature3",
},
Version: types.NewVersionUnsafe("0.56"),
}
f4 := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{Name: "TestInsertLayerNamespace3"},
Name: "TestInsertLayerFeature2",
},
Version: types.NewVersionUnsafe("0.34"),
}
f5 := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{Name: "TestInsertLayerNamespace3"},
Name: "TestInsertLayerFeature3",
},
Version: types.NewVersionUnsafe("0.57"),
}
f6 := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{Name: "TestInsertLayerNamespace3"},
Name: "TestInsertLayerFeature4",
},
Version: types.NewVersionUnsafe("0.666"),
}
layers := []database.Layer{
{
Name: "TestInsertLayer1",
},
{
Name: "TestInsertLayer2",
Parent: &database.Layer{Name: "TestInsertLayer1"},
Namespace: &database.Namespace{Name: "TestInsertLayerNamespace1"},
},
// This layer changes the namespace and adds Features.
{
Name: "TestInsertLayer3",
Parent: &database.Layer{Name: "TestInsertLayer2"},
Namespace: &database.Namespace{Name: "TestInsertLayerNamespace2"},
Features: []database.FeatureVersion{f1, f2, f3},
},
// This layer covers the case where the last layer doesn't provide any new Feature.
{
Name: "TestInsertLayer4a",
Parent: &database.Layer{Name: "TestInsertLayer3"},
Features: []database.FeatureVersion{f1, f2, f3},
},
// This layer covers the case where the last layer provides Features.
// It also modifies the Namespace ("upgrade") but keeps some Features not upgraded, their
// Namespaces should then remain unchanged.
{
Name: "TestInsertLayer4b",
Parent: &database.Layer{Name: "TestInsertLayer3"},
Namespace: &database.Namespace{Name: "TestInsertLayerNamespace3"},
Features: []database.FeatureVersion{
// Deletes TestInsertLayerFeature1.
// Keep TestInsertLayerFeature2 (old Namespace should be kept):
f4,
// Upgrades TestInsertLayerFeature3 (with new Namespace):
f5,
// Adds TestInsertLayerFeature4:
f6,
},
},
}
var err error
retrievedLayers := make(map[string]database.Layer)
for _, layer := range layers {
if layer.Parent != nil {
// Retrieve from database its parent and assign.
parent := retrievedLayers[layer.Parent.Name]
layer.Parent = &parent
}
err = datastore.InsertLayer(layer)
assert.Nil(t, err)
retrievedLayers[layer.Name], err = datastore.FindLayer(layer.Name, true, false)
assert.Nil(t, err)
}
l4a := retrievedLayers["TestInsertLayer4a"]
if assert.NotNil(t, l4a.Namespace) {
assert.Equal(t, "TestInsertLayerNamespace2", l4a.Namespace.Name)
}
assert.Len(t, l4a.Features, 3)
for _, featureVersion := range l4a.Features {
if cmpFV(featureVersion, f1) && cmpFV(featureVersion, f2) && cmpFV(featureVersion, f3) {
assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f1, f2, f3))
}
}
l4b := retrievedLayers["TestInsertLayer4b"]
if assert.NotNil(t, l4b.Namespace) {
assert.Equal(t, "TestInsertLayerNamespace3", l4b.Namespace.Name)
}
assert.Len(t, l4b.Features, 3)
for _, featureVersion := range l4b.Features {
if cmpFV(featureVersion, f2) && cmpFV(featureVersion, f5) && cmpFV(featureVersion, f6) {
assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f2, f4, f6))
}
}
}
func testInsertLayerUpdate(t *testing.T, datastore database.Datastore) {
f7 := database.FeatureVersion{
Feature: database.Feature{
Namespace: database.Namespace{Name: "TestInsertLayerNamespace3"},
Name: "TestInsertLayerFeature7",
},
Version: types.NewVersionUnsafe("0.01"),
}
l3, _ := datastore.FindLayer("TestInsertLayer3", true, false)
l3u := database.Layer{
Name: l3.Name,
Parent: l3.Parent,
Namespace: &database.Namespace{Name: "TestInsertLayerNamespaceUpdated1"},
Features: []database.FeatureVersion{f7},
}
l4u := database.Layer{
Name: "TestInsertLayer4",
Parent: &database.Layer{Name: "TestInsertLayer3"},
Features: []database.FeatureVersion{f7},
EngineVersion: 2,
}
// Try to re-insert without increasing the EngineVersion.
err := datastore.InsertLayer(l3u)
assert.Nil(t, err)
l3uf, err := datastore.FindLayer(l3u.Name, true, false)
if assert.Nil(t, err) {
assert.Equal(t, l3.Namespace.Name, l3uf.Namespace.Name)
assert.Equal(t, l3.EngineVersion, l3uf.EngineVersion)
assert.Len(t, l3uf.Features, len(l3.Features))
}
// Update layer l3.
// Verify that the Namespace, EngineVersion and FeatureVersions got updated.
l3u.EngineVersion = 2
err = datastore.InsertLayer(l3u)
assert.Nil(t, err)
l3uf, err = datastore.FindLayer(l3u.Name, true, false)
if assert.Nil(t, err) {
assert.Equal(t, l3u.Namespace.Name, l3uf.Namespace.Name)
assert.Equal(t, l3u.EngineVersion, l3uf.EngineVersion)
if assert.Len(t, l3uf.Features, 1) {
assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l3uf.Features[0])
}
}
// Update layer l4.
// Verify that the Namespace got updated from its new Parent's, and also verify the
// EnginVersion and FeatureVersions.
l4u.Parent = &l3uf
err = datastore.InsertLayer(l4u)
assert.Nil(t, err)
l4uf, err := datastore.FindLayer(l3u.Name, true, false)
if assert.Nil(t, err) {
assert.Equal(t, l3u.Namespace.Name, l4uf.Namespace.Name)
assert.Equal(t, l4u.EngineVersion, l4uf.EngineVersion)
if assert.Len(t, l4uf.Features, 1) {
assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l4uf.Features[0])
}
}
}
func testInsertLayerDelete(t *testing.T, datastore database.Datastore) {
err := datastore.DeleteLayer("TestInsertLayerX")
assert.Equal(t, cerrors.ErrNotFound, err)
err = datastore.DeleteLayer("TestInsertLayer3")
assert.Nil(t, err)
_, err = datastore.FindLayer("TestInsertLayer3", false, false)
assert.Equal(t, cerrors.ErrNotFound, err)
_, err = datastore.FindLayer("TestInsertLayer4a", false, false)
assert.Equal(t, cerrors.ErrNotFound, err)
_, err = datastore.FindLayer("TestInsertLayer4b", true, false)
assert.Equal(t, cerrors.ErrNotFound, err)
}
func cmpFV(a, b database.FeatureVersion) bool {
return a.Feature.Name == b.Feature.Name &&
a.Feature.Namespace.Name == b.Feature.Namespace.Name &&
a.Version.String() == b.Version.String()
}

106
database/mysql/lock.go Normal file
View File

@ -0,0 +1,106 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import (
"time"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
)
// Lock tries to set a temporary lock in the database.
//
// Lock does not block, instead, it returns true and its expiration time
// is the lock has been successfully acquired or false otherwise
func (mySQL *mySQL) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time) {
if name == "" || owner == "" || duration == 0 {
log.Warning("could not create an invalid lock")
return false, time.Time{}
}
defer database.ObserveQueryTime("Lock", "all", time.Now())
// Compute expiration.
until := time.Now().Add(duration)
if renew {
// Renew lock.
r, err := mySQL.Exec(updateLock, until, name, owner)
if err != nil {
handleError("updateLock", err)
return false, until
}
if n, _ := r.RowsAffected(); n > 0 {
// Updated successfully.
return true, until
}
} else {
// Prune locks.
mySQL.pruneLocks()
}
// Lock.
_, err := mySQL.Exec(insertLock, name, owner, until)
if err != nil {
if !isErrUniqueViolation(err) {
handleError("insertLock", err)
}
return false, until
}
return true, until
}
// Unlock unlocks a lock specified by its name if I own it
func (mySQL *mySQL) Unlock(name, owner string) {
if name == "" || owner == "" {
log.Warning("could not delete an invalid lock")
return
}
defer database.ObserveQueryTime("Unlock", "all", time.Now())
mySQL.Exec(removeLock, name, owner)
}
// FindLock returns the owner of a lock specified by its name and its
// expiration time.
func (mySQL *mySQL) FindLock(name string) (string, time.Time, error) {
if name == "" {
log.Warning("could not find an invalid lock")
return "", time.Time{}, cerrors.NewBadRequestError("could not find an invalid lock")
}
defer database.ObserveQueryTime("FindLock", "all", time.Now())
var owner string
var until time.Time
err := mySQL.QueryRow(searchLock, name).Scan(&owner, &until)
if err != nil {
return owner, until, handleError("searchLock", err)
}
return owner, until, nil
}
// pruneLocks removes every expired locks from the database
func (mySQL *mySQL) pruneLocks() {
defer database.ObserveQueryTime("pruneLocks", "all", time.Now())
if _, err := mySQL.Exec(removeLockExpired); err != nil {
handleError("removeLockExpired", err)
}
}

View File

@ -0,0 +1,69 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func TestLock(t *testing.T) {
datastore, err := OpenForTest("InsertNamespace", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
var l bool
var et time.Time
// Create a first lock.
l, _ = datastore.Lock("test1", "owner1", time.Minute, false)
assert.True(t, l)
// Try to lock the same lock with another owner.
l, _ = datastore.Lock("test1", "owner2", time.Minute, true)
assert.False(t, l)
l, _ = datastore.Lock("test1", "owner2", time.Minute, false)
assert.False(t, l)
// Renew the lock.
l, _ = datastore.Lock("test1", "owner1", 2*time.Minute, true)
assert.True(t, l)
// Unlock and then relock by someone else.
datastore.Unlock("test1", "owner1")
l, et = datastore.Lock("test1", "owner2", time.Minute, false)
assert.True(t, l)
// LockInfo
o, et2, err := datastore.FindLock("test1")
assert.Nil(t, err)
assert.Equal(t, "owner2", o)
assert.Equal(t, et.Second(), et2.Second())
// Create a second lock which is actually already expired ...
l, _ = datastore.Lock("test2", "owner1", -time.Minute, false)
assert.True(t, l)
// Take over the lock
l, _ = datastore.Lock("test2", "owner2", time.Minute, false)
assert.True(t, l)
}

View File

@ -0,0 +1,177 @@
-- Copyright 2015 clair authors
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
-- +goose Up
-- -----------------------------------------------------
-- Table Namespace
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Namespace (
id SERIAL PRIMARY KEY,
name VARCHAR(128) NULL);
-- -----------------------------------------------------
-- Table Layer
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Layer (
id SERIAL PRIMARY KEY,
name VARCHAR(128) NOT NULL UNIQUE,
engineversion SMALLINT NOT NULL,
parent_id BIGINT UNSIGNED NULL,
namespace_id BIGINT UNSIGNED NULL,
created_at TIMESTAMP,
FOREIGN KEY(parent_id) REFERENCES Layer(id) ON DELETE CASCADE,
FOREIGN KEY(namespace_id) REFERENCES Namespace(id),
INDEX (parent_id),
INDEX (namespace_id));
-- -----------------------------------------------------
-- Table Feature
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Feature (
id SERIAL PRIMARY KEY,
namespace_id BIGINT UNSIGNED NOT NULL,
name VARCHAR(128) NOT NULL,
FOREIGN KEY(namespace_id) REFERENCES Namespace(ID),
UNIQUE (namespace_id, name));
-- -----------------------------------------------------
-- Table FeatureVersion
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS FeatureVersion (
id SERIAL PRIMARY KEY,
feature_id BIGINT UNSIGNED NOT NULL,
version VARCHAR(128) NOT NULL,
FOREIGN KEY(feature_id) REFERENCES Feature(id),
INDEX (feature_id));
-- -----------------------------------------------------
-- Table Layer_diff_FeatureVersion
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Layer_diff_FeatureVersion (
id SERIAL PRIMARY KEY,
layer_id BIGINT UNSIGNED NOT NULL ,
featureversion_id BIGINT UNSIGNED NOT NULL ,
modification ENUM('add', 'del') NOT NULL,
FOREIGN KEY (layer_id) REFERENCES Layer(id) ON DELETE CASCADE,
FOREIGN KEY (featureversion_id) REFERENCES FeatureVersion(id),
INDEX (layer_id),
INDEX (featureversion_id),
InDEX (featureversion_id, layer_id),
UNIQUE (layer_id, featureversion_id));
-- -----------------------------------------------------
-- Table Vulnerability
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Vulnerability (
id SERIAL PRIMARY KEY,
namespace_id INT NOT NULL REFERENCES Namespace,
name VARCHAR(128) NOT NULL,
description TEXT NULL,
link VARCHAR(128) NULL,
severity ENUM('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1') NOT NULL,
metadata TEXT NULL,
created_at TIMESTAMP,
deleted_at TIMESTAMP NULL);
-- -----------------------------------------------------
-- Table Vulnerability_FixedIn_Feature
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Vulnerability_FixedIn_Feature (
id SERIAL PRIMARY KEY,
vulnerability_id BIGINT UNSIGNED NOT NULL,
feature_id BIGINT UNSIGNED NOT NULL,
version VARCHAR(128) NOT NULL,
INDEX (feature_id, vulnerability_id),
FOREIGN KEY (vulnerability_id) REFERENCES Vulnerability(id) ON DELETE CASCADE,
FOREIGN KEY (feature_id) REFERENCES Feature(id),
UNIQUE (vulnerability_id, feature_id));
-- -----------------------------------------------------
-- Table Vulnerability_Affects_FeatureVersion
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Vulnerability_Affects_FeatureVersion (
id SERIAL PRIMARY KEY,
vulnerability_id BIGINT UNSIGNED NOT NULL,
featureversion_id BIGINT UNSIGNED NOT NULL,
fixedin_id BIGINT UNSIGNED NOT NULL,
INDEX (fixedin_id),
INDEX (featureversion_id, vulnerability_id),
FOREIGN KEY (vulnerability_id) REFERENCES Vulnerability(id) ON DELETE CASCADE,
FOREIGN KEY (fixedin_id) REFERENCES Vulnerability_FixedIn_Feature (id) ON DELETE CASCADE,
FOREIGN KEY (featureversion_id) REFERENCES FeatureVersion (id),
UNIQUE (vulnerability_id, featureversion_id));
-- -----------------------------------------------------
-- Table KeyValue
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS KeyValue (
id SERIAL PRIMARY KEY,
`key` VARCHAR(128) NOT NULL UNIQUE,
`value` TEXT);
-- -----------------------------------------------------
-- Table Lock
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS `Lock` (
id SERIAL PRIMARY KEY,
name VARCHAR(64) NOT NULL UNIQUE,
owner VARCHAR(64) NOT NULL,
until TIMESTAMP,
INDEX (owner));
-- -----------------------------------------------------
-- Table VulnerabilityNotification
-- -----------------------------------------------------
CREATE TABLE IF NOT EXISTS Vulnerability_Notification (
id SERIAL PRIMARY KEY,
name VARCHAR(64) NOT NULL UNIQUE,
created_at TIMESTAMP ,
notified_at TIMESTAMP NULL,
deleted_at TIMESTAMP NULL,
old_vulnerability_id BIGINT UNSIGNED NULL,
new_vulnerability_id BIGINT UNSIGNED NULL,
FOREIGN KEY (old_vulnerability_id) REFERENCES Vulnerability(id) ON DELETE CASCADE,
FOREIGN KEY (new_vulnerability_id) REFERENCES Vulnerability(id) ON DELETE CASCADE,
INDEX (notified_at));
-- +goose Down
DROP TABLE IF EXISTS Namespace,
Layer,
Feature,
FeatureVersion,
Layer_diff_FeatureVersion,
Vulnerability,
Vulnerability_FixedIn_Feature,
Vulnerability_Affects_FeatureVersion,
Vulnerability_Notification,
`KeyValue`,
`Lock`
CASCADE;

264
database/mysql/mysql.go Normal file
View File

@ -0,0 +1,264 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package mysql implements database.Datastore with MySQL.
package mysql
import (
"database/sql"
"fmt"
"io/ioutil"
"os"
"path"
"runtime"
"strings"
"bitbucket.org/liamstask/goose/lib/goose"
"github.com/coreos/clair/config"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/pkg/capnslog"
_ "github.com/go-sql-driver/mysql"
"github.com/hashicorp/golang-lru"
"github.com/pborman/uuid"
)
var (
log = capnslog.NewPackageLogger("github.com/coreos/clair", "mysql")
)
const DATABASENAME = "clair"
const DEFAULTFLAG = "?charset=utf8&parseTime=True"
const DEFAULTSOURCE = DATABASENAME + DEFAULTFLAG
type Queryer interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
Exec(query string, args ...interface{}) (sql.Result, error)
}
type mySQL struct {
*sql.DB
cache *lru.ARCCache
}
func (mySQL *mySQL) Close() {
mySQL.DB.Close()
}
func (mySQL *mySQL) Ping() bool {
return mySQL.DB.Ping() == nil
}
// Open creates a Datastore backed by a PostgreSQL database.
// It will run immediately every necessary migration on the database.
func Open(config *config.DatabaseConfig) (database.Datastore, error) {
source := config.Source
if strings.HasPrefix(source, "mysql://") {
source = strings.TrimPrefix(source, "mysql://")
}
config.Source = source + DEFAULTSOURCE
// Create Database if not exists
err := createDatabase(source, DATABASENAME)
if err != nil {
log.Error(err)
return nil, database.ErrCantOpen
}
return open(config)
}
func open(config *config.DatabaseConfig) (database.Datastore, error) {
// Run migrations.
if err := migrate(config.Source); err != nil {
log.Error(err)
return nil, database.ErrCantOpen
}
// Open database.
db, err := sql.Open("mysql", config.Source)
if err != nil {
log.Error(err)
return nil, database.ErrCantOpen
}
// Initialize cache.
// TODO(Quentin-M): Benchmark with a simple LRU Cache.
var cache *lru.ARCCache
if config.CacheSize > 0 {
cache, _ = lru.NewARC(config.CacheSize)
}
return &mySQL{DB: db, cache: cache}, nil
}
// migrate runs all available migrations on a pgSQL database.
func migrate(dataSource string) error {
log.Info("running database migrations")
_, filename, _, _ := runtime.Caller(1)
migrationDir := path.Join(path.Dir(filename), "/migrations/")
conf := &goose.DBConf{
MigrationsDir: migrationDir,
Driver: goose.DBDriver{
Name: "mysql",
OpenStr: dataSource,
Import: "github.com/go-sql-driver/mysql",
Dialect: &goose.MySqlDialect{},
},
}
// Determine the most recent revision available from the migrations folder.
target, err := goose.GetMostRecentDBVersion(conf.MigrationsDir)
if err != nil {
return err
}
// Run migrations
err = goose.RunMigrations(conf, conf.MigrationsDir, target)
if err != nil {
return err
}
log.Info("database migration ran successfully")
return nil
}
// TODO:
// createDatabase creates a new database.
// The dataSource parameter should not contain a dbname.
func createDatabase(dataSource, databaseName string) error {
// Open database.
log.Info("Create database: ", databaseName)
db, err := sql.Open("mysql", dataSource)
if err != nil {
return fmt.Errorf("could not open database (CreateDatabase): %v", err)
}
defer db.Close()
// Create database.
_, err = db.Exec("CREATE DATABASE IF NOT EXISTS " + databaseName + " DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci")
if err != nil {
return fmt.Errorf("could not create database: %v", err)
}
return nil
}
// dropDatabase drops an existing database.
// The dataSource parameter should not contain a dbname.
func dropDatabase(dataSource, databaseName string) error {
// Open database.
db, err := sql.Open("mysql", dataSource)
if err != nil {
return fmt.Errorf("could not open database (DropDatabase): %v", err)
}
defer db.Close()
// Drop database.
if _, err = db.Exec("DROP DATABASE " + databaseName); err != nil {
return fmt.Errorf("could not drop database: %v", err)
}
return nil
}
// TODO
// pgSQLTest wraps pgSQL for testing purposes.
// Its Close() method drops the database.
type pgSQLTest struct {
*mySQL
dataSourceDefaultDatabase string
dbName string
}
// OpenForTest creates a test Datastore backed by a new PostgreSQL database.
// It creates a new unique and prefixed ("test_") database.
// Using Close() will drop the database.
func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) {
// Define the PostgreSQL connection strings.
dataSource := "root@tcp(127.0.0.1:3306)/"
if dataSourceEnv := os.Getenv("CLAIR_TEST_MYSQL"); dataSourceEnv != "" {
dataSource = dataSourceEnv
}
dbName := "test_" + strings.ToLower(name) + "_" + strings.Replace(uuid.New(), "-", "_", -1)
dataSourceDefaultDatabase := dataSource
dataSourceTestDatabase := dataSource + dbName + "?charset=utf8&parseTime=True"
// Create database.
if err := createDatabase(dataSourceDefaultDatabase, dbName); err != nil {
log.Error(err)
return nil, database.ErrCantOpen
}
// Open database.
db, err := open(&config.DatabaseConfig{Source: dataSourceTestDatabase, CacheSize: 0})
if err != nil {
dropDatabase(dataSourceDefaultDatabase, dbName)
log.Error(err)
return nil, database.ErrCantOpen
}
// Load test data if specified.
if withTestData {
_, filename, _, _ := runtime.Caller(0)
d, _ := ioutil.ReadFile(path.Join(path.Dir(filename)) + "/testdata/data.sql")
queries := strings.Split(fmt.Sprintf("%s", d), ";")
for _, q := range queries {
_, err = db.(*mySQL).Exec(q)
if err != nil {
dropDatabase(dataSourceDefaultDatabase, dbName)
log.Error(err)
return nil, database.ErrCantOpen
}
}
}
return &pgSQLTest{
mySQL: db.(*mySQL),
dataSourceDefaultDatabase: dataSourceDefaultDatabase,
dbName: dbName}, nil
}
func (pgSQL *pgSQLTest) Close() {
pgSQL.DB.Close()
dropDatabase(pgSQL.dataSourceDefaultDatabase, pgSQL.dbName)
}
// handleError logs an error with an extra description and masks the error if it's an SQL one.
// This ensures we never return plain SQL errors and leak anything.
func handleError(desc string, err error) error {
if err == nil {
return nil
}
if err == sql.ErrNoRows {
return cerrors.ErrNotFound
}
log.Errorf("%s: %v", desc, err)
database.PromErrorsTotal.WithLabelValues(desc).Inc()
if err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") {
return database.ErrBackendException
}
return err
}
// isErrUniqueViolation determines is the given error is a duplicate entry error.
func isErrUniqueViolation(err error) bool {
if strings.Contains(fmt.Sprintf("%v", err), "Error 1062") {
return true
}
return false
}

125
database/mysql/namespace.go Normal file
View File

@ -0,0 +1,125 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import (
"time"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
)
func (mySQL *mySQL) insertNamespaceWithTransaction(queryer Queryer, namespace database.Namespace) (int, error) {
if namespace.Name == "" {
return 0, cerrors.NewBadRequestError("could not find/insert invalid Namespace")
}
if mySQL.cache != nil {
database.PromCacheQueriesTotal.WithLabelValues("namespace").Inc()
if id, found := mySQL.cache.Get("namespace:" + namespace.Name); found {
database.PromCacheHitsTotal.WithLabelValues("namespace").Inc()
return id.(int), nil
}
}
// We do `defer database.ObserveQueryTime` here because we don't want to observe cached namespaces.
defer database.ObserveQueryTime("insertNamespace", "all", time.Now())
var id int
res, err := queryer.Exec(insertNamespace, namespace.Name, namespace.Name)
if err != nil {
return 0, handleError("insertNamespace", err)
}
tmpid, err := res.LastInsertId()
if err != nil {
return 0, handleError("insertNamespace", err)
}
id = int(tmpid)
if id == 0 {
err = queryer.QueryRow(soiNamespace, namespace.Name).Scan(&id)
if err != nil {
return 0, handleError("soiNamespace", err)
}
}
if mySQL.cache != nil {
mySQL.cache.Add("namespace:"+namespace.Name, id)
}
return id, nil
}
func (mySQL *mySQL) insertNamespace(namespace database.Namespace) (int, error) {
if namespace.Name == "" {
return 0, cerrors.NewBadRequestError("could not find/insert invalid Namespace")
}
if mySQL.cache != nil {
database.PromCacheQueriesTotal.WithLabelValues("namespace").Inc()
if id, found := mySQL.cache.Get("namespace:" + namespace.Name); found {
database.PromCacheHitsTotal.WithLabelValues("namespace").Inc()
return id.(int), nil
}
}
// We do `defer database.ObserveQueryTime` here because we don't want to observe cached namespaces.
defer database.ObserveQueryTime("insertNamespace", "all", time.Now())
var id int
res, err := mySQL.Exec(insertNamespace, namespace.Name, namespace.Name)
if err != nil {
return 0, handleError("insertNamespace", err)
}
tmpid, err := res.LastInsertId()
if err != nil {
return 0, handleError("insertNamespace", err)
}
id = int(tmpid)
if id == 0 {
err = mySQL.QueryRow(soiNamespace, namespace.Name).Scan(&id)
if err != nil {
return 0, handleError("soiNamespace", err)
}
}
if mySQL.cache != nil {
mySQL.cache.Add("namespace:"+namespace.Name, id)
}
return id, nil
}
func (mySQL *mySQL) ListNamespaces() (namespaces []database.Namespace, err error) {
rows, err := mySQL.Query(listNamespace)
if err != nil {
return namespaces, handleError("listNamespace", err)
}
defer rows.Close()
for rows.Next() {
var namespace database.Namespace
err = rows.Scan(&namespace.ID, &namespace.Name)
if err != nil {
return namespaces, handleError("listNamespace.Scan()", err)
}
namespaces = append(namespaces, namespace)
}
if err = rows.Err(); err != nil {
return namespaces, handleError("listNamespace.Rows()", err)
}
return namespaces, err
}

View File

@ -0,0 +1,66 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import (
"fmt"
"testing"
"github.com/coreos/clair/database"
"github.com/stretchr/testify/assert"
)
func TestInsertNamespace(t *testing.T) {
datastore, err := OpenForTest("InsertNamespace", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Invalid Namespace.
id0, err := datastore.insertNamespace(database.Namespace{})
assert.NotNil(t, err)
assert.Zero(t, id0)
// Insert Namespace and ensure we can find it.
id1, err := datastore.insertNamespace(database.Namespace{Name: "TestInsertNamespace1"})
assert.Nil(t, err)
id2, err := datastore.insertNamespace(database.Namespace{Name: "TestInsertNamespace1"})
assert.Nil(t, err)
assert.Equal(t, id1, id2)
}
func TestListNamespace(t *testing.T) {
datastore, err := OpenForTest("ListNamespaces", true)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
namespaces, err := datastore.ListNamespaces()
assert.Nil(t, err)
if assert.Len(t, namespaces, 2) {
for _, namespace := range namespaces {
switch namespace.Name {
case "debian:7", "debian:8":
continue
default:
assert.Error(t, fmt.Errorf("ListNamespaces should not have returned '%s'", namespace.Name))
}
}
}
}

View File

@ -0,0 +1,214 @@
package mysql
import (
"database/sql"
"time"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
"github.com/guregu/null/zero"
"github.com/pborman/uuid"
)
// do it in tx so we won't insert/update a vuln without notification and vice-versa.
// name and created doesn't matter.
func createNotification(tx *sql.Tx, oldVulnerabilityID, newVulnerabilityID int) error {
defer database.ObserveQueryTime("createNotification", "all", time.Now())
// Insert Notification.
oldVulnerabilityNullableID := sql.NullInt64{Int64: int64(oldVulnerabilityID), Valid: oldVulnerabilityID != 0}
newVulnerabilityNullableID := sql.NullInt64{Int64: int64(newVulnerabilityID), Valid: newVulnerabilityID != 0}
_, err := tx.Exec(insertNotification, uuid.New(), oldVulnerabilityNullableID, newVulnerabilityNullableID)
if err != nil {
tx.Rollback()
return handleError("insertNotification", err)
}
return nil
}
// Get one available notification name (!locked && !deleted && (!notified || notified_but_timed-out)).
// Does not fill new/old vuln.
func (mySQL *mySQL) GetAvailableNotification(renotifyInterval time.Duration) (database.VulnerabilityNotification, error) {
defer database.ObserveQueryTime("GetAvailableNotification", "all", time.Now())
before := time.Now().Add(-renotifyInterval)
row := mySQL.QueryRow(searchNotificationAvailable, before)
notification, err := mySQL.scanNotification(row, false)
return notification, handleError("searchNotificationAvailable", err)
}
func (mySQL *mySQL) GetNotification(name string, limit int, page database.VulnerabilityNotificationPageNumber) (database.VulnerabilityNotification, database.VulnerabilityNotificationPageNumber, error) {
defer database.ObserveQueryTime("GetNotification", "all", time.Now())
// Get Notification.
notification, err := mySQL.scanNotification(mySQL.QueryRow(searchNotification, name), true)
if err != nil {
return notification, page, handleError("searchNotification", err)
}
// Load vulnerabilities' LayersIntroducingVulnerability.
page.OldVulnerability, err = mySQL.loadLayerIntroducingVulnerability(
notification.OldVulnerability,
limit,
page.OldVulnerability,
)
if err != nil {
return notification, page, err
}
page.NewVulnerability, err = mySQL.loadLayerIntroducingVulnerability(
notification.NewVulnerability,
limit,
page.NewVulnerability,
)
if err != nil {
return notification, page, err
}
return notification, page, nil
}
func (mySQL *mySQL) scanNotification(row *sql.Row, hasVulns bool) (database.VulnerabilityNotification, error) {
var notification database.VulnerabilityNotification
var created zero.Time
var notified zero.Time
var deleted zero.Time
var oldVulnerabilityNullableID sql.NullInt64
var newVulnerabilityNullableID sql.NullInt64
// Scan notification.
if hasVulns {
err := row.Scan(
&notification.ID,
&notification.Name,
&created,
&notified,
&deleted,
&oldVulnerabilityNullableID,
&newVulnerabilityNullableID,
)
if err != nil {
return notification, err
}
} else {
err := row.Scan(&notification.ID, &notification.Name, &created, &notified, &deleted)
if err != nil {
return notification, err
}
}
notification.Created = created.Time
notification.Notified = notified.Time
notification.Deleted = deleted.Time
if hasVulns {
if oldVulnerabilityNullableID.Valid {
vulnerability, err := mySQL.findVulnerabilityByIDWithDeleted(int(oldVulnerabilityNullableID.Int64))
if err != nil {
return notification, err
}
notification.OldVulnerability = &vulnerability
}
if newVulnerabilityNullableID.Valid {
vulnerability, err := mySQL.findVulnerabilityByIDWithDeleted(int(newVulnerabilityNullableID.Int64))
if err != nil {
return notification, err
}
notification.NewVulnerability = &vulnerability
}
}
return notification, nil
}
// Fills Vulnerability.LayersIntroducingVulnerability.
// limit -1: won't do anything
// limit 0: will just get the startID of the second page
func (mySQL *mySQL) loadLayerIntroducingVulnerability(vulnerability *database.Vulnerability, limit, startID int) (int, error) {
tf := time.Now()
if vulnerability == nil {
return -1, nil
}
// A startID equals to -1 means that we reached the end already.
if startID == -1 || limit == -1 {
return -1, nil
}
// We do `defer database.ObserveQueryTime` here because we don't want to observe invalid calls.
defer database.ObserveQueryTime("loadLayerIntroducingVulnerability", "all", tf)
// Query with limit + 1, the last item will be used to know the next starting ID.
rows, err := mySQL.Query(searchNotificationLayerIntroducingVulnerability,
vulnerability.ID, startID, limit+1)
if err != nil {
return 0, handleError("searchNotificationLayerIntroducingVulnerability", err)
}
defer rows.Close()
var layers []database.Layer
for rows.Next() {
var layer database.Layer
if err := rows.Scan(&layer.ID, &layer.Name); err != nil {
return -1, handleError("searchNotificationLayerIntroducingVulnerability.Scan()", err)
}
layers = append(layers, layer)
}
if err = rows.Err(); err != nil {
return -1, handleError("searchNotificationLayerIntroducingVulnerability.Rows()", err)
}
size := limit
if len(layers) < limit {
size = len(layers)
}
vulnerability.LayersIntroducingVulnerability = layers[:size]
nextID := -1
if len(layers) > limit {
nextID = layers[limit].ID
}
return nextID, nil
}
func (mySQL *mySQL) SetNotificationNotified(name string) error {
defer database.ObserveQueryTime("SetNotificationNotified", "all", time.Now())
if _, err := mySQL.Exec(updatedNotificationNotified, name); err != nil {
return handleError("updatedNotificationNotified", err)
}
return nil
}
func (mySQL *mySQL) DeleteNotification(name string) error {
defer database.ObserveQueryTime("DeleteNotification", "all", time.Now())
result, err := mySQL.Exec(removeNotification, name)
if err != nil {
return handleError("removeNotification", err)
}
affected, err := result.RowsAffected()
if err != nil {
return handleError("removeNotification.RowsAffected()", err)
}
if affected <= 0 {
return cerrors.ErrNotFound
}
return nil
}

View File

@ -0,0 +1,209 @@
package mysql
import (
"testing"
"time"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
)
func TestNotification(t *testing.T) {
datastore, err := OpenForTest("Notification", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Try to get a notification when there is none.
_, err = datastore.GetAvailableNotification(time.Second)
assert.Equal(t, cerrors.ErrNotFound, err)
// Create some data.
f1 := database.Feature{
Name: "TestNotificationFeature1",
Namespace: database.Namespace{Name: "TestNotificationNamespace1"},
}
f2 := database.Feature{
Name: "TestNotificationFeature2",
Namespace: database.Namespace{Name: "TestNotificationNamespace1"},
}
l1 := database.Layer{
Name: "TestNotificationLayer1",
Features: []database.FeatureVersion{
{
Feature: f1,
Version: types.NewVersionUnsafe("0.1"),
},
},
}
l2 := database.Layer{
Name: "TestNotificationLayer2",
Features: []database.FeatureVersion{
{
Feature: f1,
Version: types.NewVersionUnsafe("0.2"),
},
},
}
l3 := database.Layer{
Name: "TestNotificationLayer3",
Features: []database.FeatureVersion{
{
Feature: f1,
Version: types.NewVersionUnsafe("0.3"),
},
},
}
l4 := database.Layer{
Name: "TestNotificationLayer4",
Features: []database.FeatureVersion{
{
Feature: f2,
Version: types.NewVersionUnsafe("0.1"),
},
},
}
if !assert.Nil(t, datastore.InsertLayer(l1)) ||
!assert.Nil(t, datastore.InsertLayer(l2)) ||
!assert.Nil(t, datastore.InsertLayer(l3)) ||
!assert.Nil(t, datastore.InsertLayer(l4)) {
return
}
// Insert a new vulnerability that is introduced by three layers.
v1 := database.Vulnerability{
Name: "TestNotificationVulnerability1",
Namespace: f1.Namespace,
Description: "TestNotificationDescription1",
Link: "TestNotificationLink1",
Severity: "Unknown",
FixedIn: []database.FeatureVersion{
{
Feature: f1,
Version: types.NewVersionUnsafe("1.0"),
},
},
}
assert.Nil(t, datastore.insertVulnerability(v1, false, true))
// Get the notification associated to the previously inserted vulnerability.
notification, err := datastore.GetAvailableNotification(time.Second)
if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) {
// Verify the renotify behaviour.
if assert.Nil(t, datastore.SetNotificationNotified(notification.Name)) {
_, err := datastore.GetAvailableNotification(time.Second)
assert.Equal(t, cerrors.ErrNotFound, err)
time.Sleep(50 * time.Millisecond)
notificationB, err := datastore.GetAvailableNotification(20 * time.Millisecond)
assert.Nil(t, err)
assert.Equal(t, notification.Name, notificationB.Name)
datastore.SetNotificationNotified(notification.Name)
}
// Get notification.
filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage)
if assert.Nil(t, err) {
assert.NotEqual(t, database.NoVulnerabilityNotificationPage, nextPage)
assert.Nil(t, filledNotification.OldVulnerability)
if assert.NotNil(t, filledNotification.NewVulnerability) {
assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name)
assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 2)
}
}
// Get second page.
filledNotification, nextPage, err = datastore.GetNotification(notification.Name, 2, nextPage)
if assert.Nil(t, err) {
assert.Equal(t, database.NoVulnerabilityNotificationPage, nextPage)
assert.Nil(t, filledNotification.OldVulnerability)
if assert.NotNil(t, filledNotification.NewVulnerability) {
assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name)
assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1)
}
}
// Delete notification.
assert.Nil(t, datastore.DeleteNotification(notification.Name))
_, err = datastore.GetAvailableNotification(time.Millisecond)
assert.Equal(t, cerrors.ErrNotFound, err)
}
// Update a vulnerability and ensure that the old/new vulnerabilities are correct.
v1b := v1
v1b.Severity = types.High
v1b.FixedIn = []database.FeatureVersion{
{
Feature: f1,
Version: types.MinVersion,
},
{
Feature: f2,
Version: types.MaxVersion,
},
}
if assert.Nil(t, datastore.insertVulnerability(v1b, false, true)) {
notification, err = datastore.GetAvailableNotification(time.Second)
assert.Nil(t, err)
assert.NotEmpty(t, notification.Name)
if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) {
filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage)
if assert.Nil(t, err) {
if assert.NotNil(t, filledNotification.OldVulnerability) {
assert.Equal(t, v1.Name, filledNotification.OldVulnerability.Name)
assert.Equal(t, v1.Severity, filledNotification.OldVulnerability.Severity)
assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 2)
}
if assert.NotNil(t, filledNotification.NewVulnerability) {
assert.Equal(t, v1b.Name, filledNotification.NewVulnerability.Name)
assert.Equal(t, v1b.Severity, filledNotification.NewVulnerability.Severity)
assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1)
}
assert.Equal(t, -1, nextPage.NewVulnerability)
}
assert.Nil(t, datastore.DeleteNotification(notification.Name))
}
}
// Delete a vulnerability and verify the notification.
if assert.Nil(t, datastore.DeleteVulnerability(v1b.Namespace.Name, v1b.Name)) {
notification, err = datastore.GetAvailableNotification(time.Second)
assert.Nil(t, err)
assert.NotEmpty(t, notification.Name)
if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) {
filledNotification, _, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage)
if assert.Nil(t, err) {
assert.Nil(t, filledNotification.NewVulnerability)
if assert.NotNil(t, filledNotification.OldVulnerability) {
assert.Equal(t, v1b.Name, filledNotification.OldVulnerability.Name)
assert.Equal(t, v1b.Severity, filledNotification.OldVulnerability.Severity)
assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 1)
}
}
assert.Nil(t, datastore.DeleteNotification(notification.Name))
}
}
}

204
database/mysql/queries.go Normal file
View File

@ -0,0 +1,204 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import "strconv"
const (
// This is for lock the table
lockVulnerabilityAffects = `select count(*) from Vulnerability_Affects_FeatureVersion where vulnerability_id > 0 for update`
// keyvalue.go
updateKeyValue = "UPDATE KeyValue SET `value` = ? WHERE `key` = ?"
insertKeyValue = "INSERT INTO KeyValue(`key`, `value`) VALUES(?, ?)"
searchKeyValue = "SELECT `value` FROM KeyValue WHERE `key` = ?"
// namespace.go
insertNamespace = `insert into Namespace(name) select ? from dual where not exists (select * from Namespace where name = ?)`
soiNamespace = `SELECT id from Namespace WHERE name = ?`
listNamespace = `SELECT id, name FROM Namespace`
searchNamespace = `SELECT id FROM Namespace WHERE name = ?`
// feature.go
insertFeature = `insert into Feature(name, namespace_id) select CAST(? AS CHAR), CAST(? AS UNSIGNED) FROM dual WHERE NOT EXISTS (SELECT id FROM Feature WHERE name = ? AND namespace_id = ?)`
soiFeature = `SELECT id FROM Feature WHERE name = ? AND namespace_id = ?`
insertFeatureVersion = `
Insert into FeatureVersion(feature_id, version) select cast(? AS unsigned),cast(? as char) from dual where not exists (select * from FeatureVersion where feature_id = ? AND version = ?) LIMIT 1`
// TODO: need to handle 'exi'
soiFeatureVersion = `SELECT 'new', id FROM FeatureVersion WHERE feature_id = ? AND version = ?`
searchFeatureVersion = `SELECT id FROM FeatureVersion WHERE feature_id = ? AND version = ?`
searchVulnerabilityFixedInFeature = `
SELECT id, vulnerability_id, version FROM Vulnerability_FixedIn_Feature WHERE feature_id = ?`
insertVulnerabilityAffectsFeatureVersion = `
INSERT INTO Vulnerability_Affects_FeatureVersion(vulnerability_id, featureversion_id, fixedin_id) select ?,?,? from dual where not exists (select * from Vulnerability_Affects_FeatureVersion where vulnerability_id =? and featureversion_id = ? and fixedin_id = ?)`
// layer.go
searchLayer = `
SELECT l.id, l.name, l.engineversion, p.id, p.name, n.id, n.name
FROM Layer l
LEFT JOIN Layer p ON l.parent_id = p.id
LEFT JOIN Namespace n ON l.namespace_id = n.id
WHERE l.name = ?`
//TODO:
searchLayerFeatureVersion = `
SELECT ldf.featureversion_id, ldf.modification, fn.id, fn.name, f.id, f.name, fv.id, fv.version, ltree.id, ltree.name
FROM Layer_diff_FeatureVersion ldf
JOIN (
SELECT h.id,h.name,h.parent_id FROM (SELECT @Id AS tempId,name,(SELECT @Id := parent_id FROM Layer WHERE Id = tempId) AS parent_id FROM(SELECT @Id := ?)initializeVars, Layer h WHERE @Id <> 0)a INNER JOIN Layer h ON h.Id = a.tempId order by h.id desc
) AS ltree ON ldf.layer_id = ltree.id, FeatureVersion fv, Feature f, Namespace fn
WHERE ldf.featureversion_id = fv.id AND fv.feature_id = f.id AND f.namespace_id = fn.id
ORDER BY ltree.id`
searchFeatureVersionVulnerability = `
SELECT vafv.featureversion_id, v.id, v.name, v.description, v.link, v.severity, v.metadata,
vn.name, vfif.version
FROM Vulnerability_Affects_FeatureVersion vafv, Vulnerability v,
Namespace vn, Vulnerability_FixedIn_Feature vfif
WHERE vafv.featureversion_id IN (%s)
AND vfif.vulnerability_id = v.id
AND vafv.fixedin_id = vfif.id
AND v.namespace_id = vn.id
AND v.deleted_at IS NULL`
insertLayer = `
INSERT INTO Layer(name, engineversion, parent_id, namespace_id, created_at)
VALUES(?, ?, ?, ?, CURRENT_TIMESTAMP)
`
getLayerId = `select id from Layer where name=?`
updateLayer = `UPDATE Layer SET engineversion = ?, namespace_id = ? WHERE id = ?`
removeLayerDiffFeatureVersion = `
DELETE FROM Layer_diff_FeatureVersion
WHERE layer_id = ?`
insertLayerDiffFeatureVersion = `
INSERT INTO Layer_diff_FeatureVersion(layer_id, featureversion_id, modification)
SELECT ?, fv.id, ?
FROM FeatureVersion fv
WHERE fv.id in(%s)`
removeLayer = `DELETE FROM Layer WHERE name = ?`
// lock.go
insertLock = "INSERT INTO `Lock`(name, owner, until) VALUES(?, ?, ?)"
searchLock = "SELECT owner, until FROM `Lock` WHERE name = ?"
updateLock = "UPDATE `Lock` SET until = ? WHERE name = ? AND owner = ?"
removeLock = "DELETE FROM `Lock` WHERE name = ? AND owner = ?"
removeLockExpired = "DELETE FROM `Lock` WHERE until < CURRENT_TIMESTAMP"
// vulnerability.go
searchVulnerabilityBase = `
SELECT v.id, v.name, n.id, n.name, v.description, v.link, v.severity, v.metadata
FROM Vulnerability v JOIN Namespace n ON v.namespace_id = n.id`
searchVulnerabilityForUpdate = ` FOR UPDATE `
searchVulnerabilityByNamespaceAndName = ` WHERE n.name = ? AND v.name = ? AND v.deleted_at IS NULL`
searchVulnerabilityByID = ` WHERE v.id = ?`
searchVulnerabilityByNamespace = ` WHERE n.name = ? AND v.deleted_at IS NULL
AND v.id >= ?
ORDER BY v.id
LIMIT ?`
searchVulnerabilityFixedIn = `
SELECT vfif.version, f.id, f.Name
FROM Vulnerability_FixedIn_Feature vfif JOIN Feature f ON vfif.feature_id = f.id
WHERE vfif.vulnerability_id = ?`
insertVulnerability = `
INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at)
VALUES(?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP)
`
getVulId = `select id from Vulnerability where namespace_id=? and name=?`
getVulIdWithNamespaceName = `select id from Vulnerability where namespace_id=(select id from Namespace where name=?) and name=? and deleted_at IS NULL`
insertVulnerabilityFixedInFeature = `
INSERT INTO Vulnerability_FixedIn_Feature(vulnerability_id, feature_id, version)
VALUES(?, ?, ?)
`
findVulnerabilityFixedInFeature = `select id from Vulnerability_FixedIn_Feature where vulnerability_id=? AND feature_id=? AND version=?`
searchFeatureVersionByFeature = `SELECT id, version FROM FeatureVersion WHERE feature_id = ?`
removeVulnerability = `
UPDATE Vulnerability
SET deleted_at = CURRENT_TIMESTAMP
WHERE namespace_id = (SELECT id FROM Namespace WHERE name = ?)
AND name = ?
AND deleted_at IS NULL`
// notification.go
insertNotification = `
INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id)
VALUES(?, CURRENT_TIMESTAMP, ?, ?)`
updatedNotificationNotified = `
UPDATE Vulnerability_Notification
SET notified_at = CURRENT_TIMESTAMP
WHERE name = ?`
removeNotification = `
UPDATE Vulnerability_Notification
SET deleted_at = CURRENT_TIMESTAMP
WHERE name = ?`
searchNotificationAvailable = " SELECT id, name, created_at, notified_at, deleted_at" +
" FROM Vulnerability_Notification" +
" WHERE (notified_at IS NULL OR notified_at < ?)" +
" AND deleted_at IS NULL" +
" AND name NOT IN (SELECT name FROM `Lock`)" +
" ORDER BY Rand()" +
" LIMIT 1"
searchNotification = `
SELECT id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id
FROM Vulnerability_Notification
WHERE name = ?`
searchNotificationLayerIntroducingVulnerability = `
SELECT l.ID, l.name
FROM Vulnerability v, Vulnerability_Affects_FeatureVersion vafv, FeatureVersion fv, Layer_diff_FeatureVersion ldfv, Layer l
WHERE v.id = ?
AND v.id = vafv.vulnerability_id
AND vafv.featureversion_id = fv.id
AND fv.id = ldfv.featureversion_id
AND ldfv.modification = 'add'
AND ldfv.layer_id = l.id
AND l.id >= ?
ORDER BY l.ID
LIMIT ?`
// complex_test.go
searchComplexTestFeatureVersionAffects = `
SELECT v.name
FROM FeatureVersion fv
LEFT JOIN Vulnerability_Affects_FeatureVersion vaf ON fv.id = vaf.featureversion_id
JOIN Vulnerability v ON vaf.vulnerability_id = v.id
WHERE featureversion_id = ?`
)
// buildInputArray constructs a PostgreSQL input array from the specified integers.
// Useful to use the `= ANY($1::integer[])` syntax that let us use a IN clause while using
// a single placeholder.
func buildInputArray(ints []int) string {
str := ""
for i := 0; i < len(ints)-1; i++ {
str = str + strconv.Itoa(ints[i]) + ","
}
str = str + strconv.Itoa(ints[len(ints)-1])
return str
}

55
database/mysql/testdata/data.sql vendored Normal file
View File

@ -0,0 +1,55 @@
-- Copyright 2015 clair authors
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
INSERT INTO Namespace (id, name) VALUES
(1, 'debian:7'),
(2, 'debian:8');
INSERT INTO Feature (id, namespace_id, name) VALUES
(1, 1, 'wechat'),
(2, 1, 'openssl'),
(4, 1, 'libssl'),
(3, 2, 'openssl');
INSERT INTO FeatureVersion (id, feature_id, version) VALUES
(1, 1, '0.5'),
(2, 2, '1.0'),
(3, 2, '2.0'),
(4, 3, '1.0');
INSERT INTO Layer (id, name, engineversion, parent_id, namespace_id) VALUES
(1, 'layer-0', 1, NULL, NULL),
(2, 'layer-1', 1, 1, 1),
(3, 'layer-2', 1, 2, 1),
(4, 'layer-3a', 1, 3, 1),
(5, 'layer-3b', 1, 3, 2);
INSERT INTO Layer_diff_FeatureVersion (id, layer_id, featureversion_id, modification) VALUES
(1, 2, 1, 'add'),
(2, 2, 2, 'add'),
(3, 3, 2, 'del'), -- layer-2: Update Debian:7 OpenSSL 1.0 -> 2.0
(4, 3, 3, 'add'), -- ^
(5, 5, 3, 'del'), -- layer-3b: Delete Debian:7 OpenSSL 2.0
(6, 5, 4, 'add'); -- layer-3b: Add Debian:8 OpenSSL 1.0
INSERT INTO Vulnerability (id, namespace_id, name, description, link, severity) VALUES
(1, 1, 'CVE-OPENSSL-1-DEB7', 'A vulnerability affecting OpenSSL < 2.0 on Debian 7.0', 'http://google.com/#q=CVE-OPENSSL-1-DEB7', 'High'),
(2, 1, 'CVE-NOPE', 'A vulnerability affecting nothing', '', 'Unknown');
INSERT INTO Vulnerability_FixedIn_Feature (id, vulnerability_id, feature_id, version) VALUES
(1, 1, 2, '2.0'),
(2, 1, 4, '1.9-abc');
INSERT INTO Vulnerability_Affects_FeatureVersion (id, vulnerability_id, featureversion_id, fixedin_id) VALUES
(1, 1, 2, 1); -- CVE-OPENSSL-1-DEB7 affects Debian:7 OpenSSL 1.0

View File

@ -0,0 +1,593 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import (
"database/sql"
"encoding/json"
"fmt"
"reflect"
"time"
"github.com/coreos/clair/database"
"github.com/coreos/clair/utils"
cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types"
"github.com/guregu/null/zero"
)
func (mySQL *mySQL) ListVulnerabilities(namespaceName string, limit int, startID int) ([]database.Vulnerability, int, error) {
defer database.ObserveQueryTime("listVulnerabilities", "all", time.Now())
// Query Namespace.
var id int
err := mySQL.QueryRow(searchNamespace, namespaceName).Scan(&id)
if err != nil {
return nil, -1, handleError("searchNamespace", err)
} else if id == 0 {
return nil, -1, cerrors.ErrNotFound
}
// Query.
query := searchVulnerabilityBase + searchVulnerabilityByNamespace
rows, err := mySQL.Query(query, namespaceName, startID, limit+1)
if err != nil {
return nil, -1, handleError("searchVulnerabilityByNamespace", err)
}
defer rows.Close()
var vulns []database.Vulnerability
nextID := -1
size := 0
// Scan query.
for rows.Next() {
var vulnerability database.Vulnerability
err := rows.Scan(
&vulnerability.ID,
&vulnerability.Name,
&vulnerability.Namespace.ID,
&vulnerability.Namespace.Name,
&vulnerability.Description,
&vulnerability.Link,
&vulnerability.Severity,
&vulnerability.Metadata,
)
if err != nil {
return nil, -1, handleError("searchVulnerabilityByNamespace.Scan()", err)
}
size++
if size > limit {
nextID = vulnerability.ID
} else {
vulns = append(vulns, vulnerability)
}
}
if err := rows.Err(); err != nil {
return nil, -1, handleError("searchVulnerabilityByNamespace.Rows()", err)
}
return vulns, nextID, nil
}
func (mySQL *mySQL) FindVulnerability(namespaceName, name string) (database.Vulnerability, error) {
return findVulnerability(mySQL, namespaceName, name, false)
}
func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bool) (database.Vulnerability, error) {
defer database.ObserveQueryTime("findVulnerability", "all", time.Now())
queryName := "searchVulnerabilityBase+searchVulnerabilityByNamespaceAndName"
query := searchVulnerabilityBase + searchVulnerabilityByNamespaceAndName
if forUpdate {
queryName = queryName + "+searchVulnerabilityForUpdate"
query = query + searchVulnerabilityForUpdate
}
return scanVulnerability(queryer, queryName, queryer.QueryRow(query, namespaceName, name))
}
func (mySQL *mySQL) findVulnerabilityByIDWithDeleted(id int) (database.Vulnerability, error) {
defer database.ObserveQueryTime("findVulnerabilityByIDWithDeleted", "all", time.Now())
queryName := "searchVulnerabilityBase+searchVulnerabilityByID"
query := searchVulnerabilityBase + searchVulnerabilityByID
return scanVulnerability(mySQL, queryName, mySQL.QueryRow(query, id))
}
func scanVulnerability(queryer Queryer, queryName string, vulnerabilityRow *sql.Row) (database.Vulnerability, error) {
var vulnerability database.Vulnerability
err := vulnerabilityRow.Scan(
&vulnerability.ID,
&vulnerability.Name,
&vulnerability.Namespace.ID,
&vulnerability.Namespace.Name,
&vulnerability.Description,
&vulnerability.Link,
&vulnerability.Severity,
&vulnerability.Metadata,
)
if err != nil {
return vulnerability, handleError(queryName+".Scan()", err)
}
if vulnerability.ID == 0 {
return vulnerability, cerrors.ErrNotFound
}
// Query the FixedIn FeatureVersion now.
rows, err := queryer.Query(searchVulnerabilityFixedIn, vulnerability.ID)
if err != nil {
return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err)
}
defer rows.Close()
for rows.Next() {
var featureVersionID zero.Int
var featureVersionVersion zero.String
var featureVersionFeatureName zero.String
err := rows.Scan(
&featureVersionVersion,
&featureVersionID,
&featureVersionFeatureName,
)
if err != nil {
return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err)
}
if !featureVersionID.IsZero() {
// Note that the ID we fill in featureVersion is actually a Feature ID, and not
// a FeatureVersion ID.
featureVersion := database.FeatureVersion{
Model: database.Model{ID: int(featureVersionID.Int64)},
Feature: database.Feature{
Model: database.Model{ID: int(featureVersionID.Int64)},
Namespace: vulnerability.Namespace,
Name: featureVersionFeatureName.String,
},
Version: types.NewVersionUnsafe(featureVersionVersion.String),
}
vulnerability.FixedIn = append(vulnerability.FixedIn, featureVersion)
}
}
if err := rows.Err(); err != nil {
return vulnerability, handleError("searchVulnerabilityFixedIn.Rows()", err)
}
return vulnerability, nil
}
// FixedIn.Namespace are not necessary, they are overwritten by the vuln.
// By setting the fixed version to minVersion, we can say that the vuln does'nt affect anymore.
func (mySQL *mySQL) InsertVulnerabilities(vulnerabilities []database.Vulnerability, generateNotifications bool) error {
for _, vulnerability := range vulnerabilities {
err := mySQL.insertVulnerability(vulnerability, false, generateNotifications)
if err != nil {
fmt.Printf("%#v\n", vulnerability)
return err
}
}
return nil
}
func (mySQL *mySQL) insertVulnerability(vulnerability database.Vulnerability, onlyFixedIn, generateNotification bool) error {
tf := time.Now()
// Verify parameters
if vulnerability.Name == "" || vulnerability.Namespace.Name == "" {
return cerrors.NewBadRequestError("insertVulnerability needs at least the Name and the Namespace")
}
if !onlyFixedIn && !vulnerability.Severity.IsValid() {
msg := fmt.Sprintf("could not insert a vulnerability that has an invalid Severity: %s", vulnerability.Severity)
log.Warning(msg)
return cerrors.NewBadRequestError(msg)
}
for i := 0; i < len(vulnerability.FixedIn); i++ {
fifv := &vulnerability.FixedIn[i]
if fifv.Feature.Namespace.Name == "" {
// As there is no Namespace on that FixedIn FeatureVersion, set it to the Vulnerability's
// Namespace.
fifv.Feature.Namespace.Name = vulnerability.Namespace.Name
} else if fifv.Feature.Namespace.Name != vulnerability.Namespace.Name {
msg := "could not insert an invalid vulnerability that contains FixedIn FeatureVersion that are not in the same namespace as the Vulnerability"
log.Warning(msg)
return cerrors.NewBadRequestError(msg)
}
}
// We do `defer database.ObserveQueryTime` here because we don't want to observe invalid vulnerabilities.
defer database.ObserveQueryTime("insertVulnerability", "all", tf)
// Begin transaction.
tx, err := mySQL.Begin()
if err != nil {
tx.Rollback()
return handleError("insertVulnerability.Begin()", err)
}
// Find existing vulnerability and its Vulnerability_FixedIn_Features (for update).
existingVulnerability, err := findVulnerability(tx, vulnerability.Namespace.Name, vulnerability.Name, false)
if err != nil && err != cerrors.ErrNotFound {
tx.Rollback()
return err
}
if onlyFixedIn {
// Because this call tries to update FixedIn FeatureVersion, import all other data from the
// existing one.
if existingVulnerability.ID == 0 {
return cerrors.ErrNotFound
}
fixedIn := vulnerability.FixedIn
vulnerability = existingVulnerability
vulnerability.FixedIn = fixedIn
}
if existingVulnerability.ID != 0 {
updateMetadata := vulnerability.Description != existingVulnerability.Description ||
vulnerability.Link != existingVulnerability.Link ||
vulnerability.Severity != existingVulnerability.Severity ||
!reflect.DeepEqual(castMetadata(vulnerability.Metadata), existingVulnerability.Metadata)
// Construct the entire list of FixedIn FeatureVersion, by using the
// the FixedIn list of the old vulnerability.
//
// TODO(Quentin-M): We could use !updateFixedIn to just copy FixedIn/Affects rows from the
// existing vulnerability in order to make metadata updates much faster.
var updateFixedIn bool
vulnerability.FixedIn, updateFixedIn = applyFixedInDiff(existingVulnerability.FixedIn, vulnerability.FixedIn)
if !updateMetadata && !updateFixedIn {
tx.Commit()
return nil
}
// Mark the old vulnerability as non latest.
_, err = tx.Exec(removeVulnerability, vulnerability.Namespace.Name, vulnerability.Name)
if err != nil {
tx.Rollback()
return handleError("removeVulnerability", err)
}
} else {
// The vulnerability is new, we don't want to have any types.MinVersion as they are only used
// for diffing existing vulnerabilities.
var fixedIn []database.FeatureVersion
for _, fv := range vulnerability.FixedIn {
if fv.Version != types.MinVersion {
fixedIn = append(fixedIn, fv)
}
}
vulnerability.FixedIn = fixedIn
}
// Find or insert Vulnerability's Namespace.
namespaceID, err := mySQL.insertNamespaceWithTransaction(tx, vulnerability.Namespace)
if err != nil {
tx.Rollback()
return err
}
// Insert vulnerability.
res, err := tx.Exec(
insertVulnerability,
namespaceID,
vulnerability.Name,
vulnerability.Description,
vulnerability.Link,
&vulnerability.Severity,
&vulnerability.Metadata,
)
if err != nil {
tx.Rollback()
return handleError("insertVulnerability", err)
}
ID, err := res.LastInsertId()
if err != nil {
tx.Rollback()
return err
}
vulnerability.ID = int(ID)
// vulnerability.ID == 0 means the vulnerability is already exists
if vulnerability.ID == 0 {
err = tx.QueryRow(getVulId, namespaceID, vulnerability.Name).Scan(&vulnerability.ID)
if err != nil {
tx.Rollback()
return err
}
}
// Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now.
err = mySQL.insertVulnerabilityFixedInFeatureVersions(tx, vulnerability.ID, vulnerability.FixedIn)
if err != nil {
tx.Rollback()
return err
}
// Create a notification.
if generateNotification {
err = createNotification(tx, existingVulnerability.ID, vulnerability.ID)
if err != nil {
return err
}
}
// Commit transaction.
err = tx.Commit()
if err != nil {
tx.Rollback()
return handleError("insertVulnerability.Commit()", err)
}
return nil
}
// castMetadata marshals the given database.MetadataMap and unmarshals it again to make sure that
// everything has the interface{} type.
// It is required when comparing crafted MetadataMap against MetadataMap that we get from the
// database.
func castMetadata(m database.MetadataMap) database.MetadataMap {
c := make(database.MetadataMap)
j, _ := json.Marshal(m)
json.Unmarshal(j, &c)
return c
}
// applyFixedInDiff applies a FeatureVersion diff on a FeatureVersion list and returns the result.
func applyFixedInDiff(currentList, diff []database.FeatureVersion) ([]database.FeatureVersion, bool) {
currentMap, currentNames := createFeatureVersionNameMap(currentList)
diffMap, diffNames := createFeatureVersionNameMap(diff)
addedNames := utils.CompareStringLists(diffNames, currentNames)
inBothNames := utils.CompareStringListsInBoth(diffNames, currentNames)
different := false
for _, name := range addedNames {
if diffMap[name].Version == types.MinVersion {
// MinVersion only makes sense when a Feature is already fixed in some version,
// in which case we would be in the "inBothNames".
continue
}
currentMap[name] = diffMap[name]
different = true
}
for _, name := range inBothNames {
fv := diffMap[name]
if fv.Version == types.MinVersion {
// MinVersion means that the Feature doesn't affect the Vulnerability anymore.
delete(currentMap, name)
different = true
} else if fv.Version != currentMap[name].Version {
// The version got updated.
currentMap[name] = diffMap[name]
different = true
}
}
// Convert currentMap to a slice and return it.
var newList []database.FeatureVersion
for _, fv := range currentMap {
newList = append(newList, fv)
}
return newList, different
}
func createFeatureVersionNameMap(features []database.FeatureVersion) (map[string]database.FeatureVersion, []string) {
m := make(map[string]database.FeatureVersion, 0)
s := make([]string, 0, len(features))
for i := 0; i < len(features); i++ {
featureVersion := features[i]
m[featureVersion.Feature.Name] = featureVersion
s = append(s, featureVersion.Feature.Name)
}
return m, s
}
// insertVulnerabilityFixedInFeatureVersions populates Vulnerability_FixedIn_Feature for the given
// vulnerability with the specified database.FeatureVersion list and uses
// linkVulnerabilityToFeatureVersions to propagate the changes on Vulnerability_FixedIn_Feature to
// Vulnerability_Affects_FeatureVersion.
func (mySQL *mySQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulnerabilityID int, fixedIn []database.FeatureVersion) error {
defer database.ObserveQueryTime("insertVulnerabilityFixedInFeatureVersions", "all", time.Now())
// Insert or find the Features.
// TODO(Quentin-M): Batch me.
var err error
var features []*database.Feature
for i := 0; i < len(fixedIn); i++ {
features = append(features, &fixedIn[i].Feature)
}
for _, feature := range features {
if feature.ID == 0 {
if feature.ID, err = mySQL.insertFeatureiWithTransaction(tx, *feature); err != nil {
return err
}
}
}
// Lock Vulnerability_Affects_FeatureVersion exclusively.
// We want to prevent InsertFeatureVersion to modify it.
database.PromConcurrentLockVAFV.Inc()
defer database.PromConcurrentLockVAFV.Dec()
t := time.Now()
var tmp int64
err = tx.QueryRow(lockVulnerabilityAffects).Scan(&tmp)
database.ObserveQueryTime("insertVulnerability", "lock", t)
if err != nil {
tx.Rollback()
return handleError("insertVulnerability.lockVulnerabilityAffects", err)
}
for _, fv := range fixedIn {
var fixedInID int
// Insert Vulnerability_FixedIn_Feature.
_, err = tx.Exec(
insertVulnerabilityFixedInFeature,
vulnerabilityID, fv.Feature.ID,
&fv.Version,
)
if err != nil {
return handleError("insertVulnerabilityFixedInFeature", err)
}
err = tx.QueryRow(findVulnerabilityFixedInFeature, vulnerabilityID, fv.Feature.ID, &fv.Version).Scan(&fixedInID)
if err != nil {
return handleError("findVulnerabilityFixedInFeature", err)
}
// Insert Vulnerability_Affects_FeatureVersion.
err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerabilityID, fv.Feature.ID, fv.Version)
if err != nil {
return err
}
}
return nil
}
func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, fixedInVersion types.Version) error {
// Find every FeatureVersions of the Feature that the vulnerability affects.
// TODO(Quentin-M): LIMIT
rows, err := tx.Query(searchFeatureVersionByFeature, featureID)
if err != nil {
return handleError("searchFeatureVersionByFeature", err)
}
defer rows.Close()
var affecteds []database.FeatureVersion
for rows.Next() {
var affected database.FeatureVersion
err := rows.Scan(&affected.ID, &affected.Version)
if err != nil {
return handleError("searchFeatureVersionByFeature.Scan()", err)
}
if affected.Version.Compare(fixedInVersion) < 0 {
// The version of the FeatureVersion is lower than the fixed version of this vulnerability,
// thus, this FeatureVersion is affected by it.
affecteds = append(affecteds, affected)
}
}
if err = rows.Err(); err != nil {
return handleError("searchFeatureVersionByFeature.Rows()", err)
}
rows.Close()
// Insert into Vulnerability_Affects_FeatureVersion.
for _, affected := range affecteds {
// TODO(Quentin-M): Batch me.
_, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, vulnerabilityID,
affected.ID, fixedInID, vulnerabilityID, affected.ID, fixedInID)
if err != nil {
return handleError("insertVulnerabilityAffectsFeatureVersion", err)
}
}
return nil
}
func (mySQL *mySQL) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []database.FeatureVersion) error {
defer database.ObserveQueryTime("InsertVulnerabilityFixes", "all", time.Now())
v := database.Vulnerability{
Name: vulnerabilityName,
Namespace: database.Namespace{
Name: vulnerabilityNamespace,
},
FixedIn: fixes,
}
return mySQL.insertVulnerability(v, true, true)
}
func (mySQL *mySQL) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error {
defer database.ObserveQueryTime("DeleteVulnerabilityFix", "all", time.Now())
v := database.Vulnerability{
Name: vulnerabilityName,
Namespace: database.Namespace{
Name: vulnerabilityNamespace,
},
FixedIn: []database.FeatureVersion{
{
Feature: database.Feature{
Name: featureName,
Namespace: database.Namespace{
Name: vulnerabilityNamespace,
},
},
Version: types.MinVersion,
},
},
}
return mySQL.insertVulnerability(v, true, true)
}
func (mySQL *mySQL) DeleteVulnerability(namespaceName, name string) error {
defer database.ObserveQueryTime("DeleteVulnerability", "all", time.Now())
// Begin transaction.
tx, err := mySQL.Begin()
if err != nil {
tx.Rollback()
return handleError("DeleteVulnerability.Begin()", err)
}
var vulnerabilityID int
err = tx.QueryRow(getVulIdWithNamespaceName, namespaceName, name).Scan(&vulnerabilityID)
if err != nil {
tx.Rollback()
return handleError("removeVulnerability", err)
}
_, err = tx.Exec(removeVulnerability, namespaceName, name)
if err != nil {
tx.Rollback()
return handleError("removeVulnerability", err)
}
// Create a notification.
err = createNotification(tx, vulnerabilityID, 0)
if err != nil {
return err
}
// Commit transaction.
err = tx.Commit()
if err != nil {
tx.Rollback()
return handleError("DeleteVulnerability.Commit()", err)
}
return nil
}

View File

@ -0,0 +1,276 @@
// Copyright 2015 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package mysql
import (
"reflect"
"testing"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/clair/utils/types"
"github.com/stretchr/testify/assert"
)
func TestFindVulnerability(t *testing.T) {
datastore, err := OpenForTest("FindVulnerability", true)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Find a vulnerability that does not exist.
_, err = datastore.FindVulnerability("", "")
assert.Equal(t, cerrors.ErrNotFound, err)
// Find a normal vulnerability.
v1 := database.Vulnerability{
Name: "CVE-OPENSSL-1-DEB7",
Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0",
Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7",
Severity: types.High,
Namespace: database.Namespace{Name: "debian:7"},
FixedIn: []database.FeatureVersion{
{
Feature: database.Feature{Name: "openssl"},
Version: types.NewVersionUnsafe("2.0"),
},
{
Feature: database.Feature{Name: "libssl"},
Version: types.NewVersionUnsafe("1.9-abc"),
},
},
}
v1f, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7")
if assert.Nil(t, err) {
equalsVuln(t, &v1, &v1f)
}
// Find a vulnerability that has no link, no severity and no FixedIn.
v2 := database.Vulnerability{
Name: "CVE-NOPE",
Description: "A vulnerability affecting nothing",
Namespace: database.Namespace{Name: "debian:7"},
Severity: types.Unknown,
}
v2f, err := datastore.FindVulnerability("debian:7", "CVE-NOPE")
if assert.Nil(t, err) {
equalsVuln(t, &v2, &v2f)
}
}
func TestDeleteVulnerability(t *testing.T) {
datastore, err := OpenForTest("InsertVulnerability", true)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Delete non-existing Vulnerability.
err = datastore.DeleteVulnerability("TestDeleteVulnerabilityNamespace1", "CVE-OPENSSL-1-DEB7")
assert.Equal(t, cerrors.ErrNotFound, err)
err = datastore.DeleteVulnerability("debian:7", "TestDeleteVulnerabilityVulnerability1")
assert.Equal(t, cerrors.ErrNotFound, err)
// Delete Vulnerability.
err = datastore.DeleteVulnerability("debian:7", "CVE-OPENSSL-1-DEB7")
if assert.Nil(t, err) {
_, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7")
assert.Equal(t, cerrors.ErrNotFound, err)
}
}
func TestInsertVulnerability(t *testing.T) {
datastore, err := OpenForTest("InsertVulnerability", false)
if err != nil {
t.Error(err)
return
}
defer datastore.Close()
// Create some data.
n1 := database.Namespace{Name: "TestInsertVulnerabilityNamespace1"}
n2 := database.Namespace{Name: "TestInsertVulnerabilityNamespace2"}
f1 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion1",
Namespace: n1,
},
Version: types.NewVersionUnsafe("1.0"),
}
f2 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion1",
Namespace: n2,
},
Version: types.NewVersionUnsafe("1.0"),
}
f3 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion2",
},
Version: types.MaxVersion,
}
f4 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion2",
},
Version: types.NewVersionUnsafe("1.4"),
}
f5 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion3",
},
Version: types.NewVersionUnsafe("1.5"),
}
f6 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion4",
},
Version: types.NewVersionUnsafe("0.1"),
}
f7 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion5",
},
Version: types.MaxVersion,
}
f8 := database.FeatureVersion{
Feature: database.Feature{
Name: "TestInsertVulnerabilityFeatureVersion5",
},
Version: types.MinVersion,
}
// Insert invalid vulnerabilities.
for _, vulnerability := range []database.Vulnerability{
{
Name: "",
Namespace: n1,
FixedIn: []database.FeatureVersion{f1},
Severity: types.Unknown,
},
{
Name: "TestInsertVulnerability0",
Namespace: database.Namespace{},
FixedIn: []database.FeatureVersion{f1},
Severity: types.Unknown,
},
{
Name: "TestInsertVulnerability0-",
Namespace: database.Namespace{},
FixedIn: []database.FeatureVersion{f1},
},
{
Name: "TestInsertVulnerability0",
Namespace: n1,
FixedIn: []database.FeatureVersion{f1},
Severity: types.Priority(""),
},
{
Name: "TestInsertVulnerability0",
Namespace: n1,
FixedIn: []database.FeatureVersion{f2},
Severity: types.Unknown,
},
} {
err := datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true)
assert.Error(t, err)
}
// Insert a simple vulnerability and find it.
v1meta := make(map[string]interface{})
v1meta["TestInsertVulnerabilityMetadata1"] = "TestInsertVulnerabilityMetadataValue1"
v1meta["TestInsertVulnerabilityMetadata2"] = struct {
Test string
}{
Test: "TestInsertVulnerabilityMetadataValue1",
}
v1 := database.Vulnerability{
Name: "TestInsertVulnerability1",
Namespace: n1,
FixedIn: []database.FeatureVersion{f1, f3, f6, f7},
Severity: types.Low,
Description: "TestInsertVulnerabilityDescription1",
Link: "TestInsertVulnerabilityLink1",
Metadata: v1meta,
}
err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true)
if assert.Nil(t, err) {
v1f, err := datastore.FindVulnerability(n1.Name, v1.Name)
if assert.Nil(t, err) {
equalsVuln(t, &v1, &v1f)
}
}
// Update vulnerability.
v1.Description = "TestInsertVulnerabilityLink2"
v1.Link = "TestInsertVulnerabilityLink2"
v1.Severity = types.High
// Update f3 in f4, add fixed in f5, add fixed in f6 which already exists, removes fixed in f7 by
// adding f8 which is f7 but with MinVersion.
v1.FixedIn = []database.FeatureVersion{f4, f5, f6, f8}
err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true)
if assert.Nil(t, err) {
v1f, err := datastore.FindVulnerability(n1.Name, v1.Name)
if assert.Nil(t, err) {
// We already had f1 before the update.
// Add it to the struct for comparison.
v1.FixedIn = append(v1.FixedIn, f1)
// Removes f8 from the struct for comparison as it was just here to cancel f7.
for i := 0; i < len(v1.FixedIn); i++ {
if v1.FixedIn[i].Feature.Name == f8.Feature.Name {
v1.FixedIn = append(v1.FixedIn[:i], v1.FixedIn[i+1:]...)
}
}
equalsVuln(t, &v1, &v1f)
}
}
}
func equalsVuln(t *testing.T, expected, actual *database.Vulnerability) {
assert.Equal(t, expected.Name, actual.Name)
assert.Equal(t, expected.Namespace.Name, actual.Namespace.Name)
assert.Equal(t, expected.Description, actual.Description)
assert.Equal(t, expected.Link, actual.Link)
assert.Equal(t, expected.Severity, actual.Severity)
assert.True(t, reflect.DeepEqual(castMetadata(expected.Metadata), actual.Metadata), "Got metadata %#v, expected %#v", actual.Metadata, castMetadata(expected.Metadata))
if assert.Len(t, actual.FixedIn, len(expected.FixedIn)) {
for _, actualFeatureVersion := range actual.FixedIn {
found := false
for _, expectedFeatureVersion := range expected.FixedIn {
if expectedFeatureVersion.Feature.Name == actualFeatureVersion.Feature.Name {
found = true
assert.Equal(t, expected.Namespace.Name, actualFeatureVersion.Feature.Namespace.Name)
assert.Equal(t, expectedFeatureVersion.Version, actualFeatureVersion.Version)
}
}
if !found {
t.Errorf("unexpected package %s in %s", actualFeatureVersion.Feature.Name, expected.Name)
}
}
}
}

View File

@ -30,16 +30,16 @@ func (pgSQL *pgSQL) insertFeature(feature database.Feature) (int, error) {
// Do cache lookup.
if pgSQL.cache != nil {
promCacheQueriesTotal.WithLabelValues("feature").Inc()
database.PromCacheQueriesTotal.WithLabelValues("feature").Inc()
id, found := pgSQL.cache.Get("feature:" + feature.Namespace.Name + ":" + feature.Name)
if found {
promCacheHitsTotal.WithLabelValues("feature").Inc()
database.PromCacheHitsTotal.WithLabelValues("feature").Inc()
return id.(int), nil
}
}
// We do `defer observeQueryTime` here because we don't want to observe cached features.
defer observeQueryTime("insertFeature", "all", time.Now())
// We do `defer database.ObserveQueryTime` here because we don't want to observe cached features.
defer database.ObserveQueryTime("insertFeature", "all", time.Now())
// Find or create Namespace.
namespaceID, err := pgSQL.insertNamespace(feature.Namespace)
@ -69,21 +69,21 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
// Do cache lookup.
cacheIndex := "featureversion:" + featureVersion.Feature.Namespace.Name + ":" + featureVersion.Feature.Name + ":" + featureVersion.Version.String()
if pgSQL.cache != nil {
promCacheQueriesTotal.WithLabelValues("featureversion").Inc()
database.PromCacheQueriesTotal.WithLabelValues("featureversion").Inc()
id, found := pgSQL.cache.Get(cacheIndex)
if found {
promCacheHitsTotal.WithLabelValues("featureversion").Inc()
database.PromCacheHitsTotal.WithLabelValues("featureversion").Inc()
return id.(int), nil
}
}
// We do `defer observeQueryTime` here because we don't want to observe cached featureversions.
defer observeQueryTime("insertFeatureVersion", "all", time.Now())
// We do `defer database.ObserveQueryTime` here because we don't want to observe cached featureversions.
defer database.ObserveQueryTime("insertFeatureVersion", "all", time.Now())
// Find or create Feature first.
t := time.Now()
featureID, err := pgSQL.insertFeature(featureVersion.Feature)
observeQueryTime("insertFeatureVersion", "insertFeature", t)
database.ObserveQueryTime("insertFeatureVersion", "insertFeature", t)
if err != nil {
return 0, err
@ -117,11 +117,11 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
// Lock Vulnerability_Affects_FeatureVersion exclusively.
// We want to prevent InsertVulnerability to modify it.
promConcurrentLockVAFV.Inc()
defer promConcurrentLockVAFV.Dec()
database.PromConcurrentLockVAFV.Inc()
defer database.PromConcurrentLockVAFV.Dec()
t = time.Now()
_, err = tx.Exec(lockVulnerabilityAffects)
observeQueryTime("insertFeatureVersion", "lock", t)
database.ObserveQueryTime("insertFeatureVersion", "lock", t)
if err != nil {
tx.Rollback()
@ -134,7 +134,7 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
t = time.Now()
err = tx.QueryRow(soiFeatureVersion, featureID, &featureVersion.Version).
Scan(&newOrExisting, &featureVersion.ID)
observeQueryTime("insertFeatureVersion", "soiFeatureVersion", t)
database.ObserveQueryTime("insertFeatureVersion", "soiFeatureVersion", t)
if err != nil {
tx.Rollback()
@ -156,7 +156,7 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion)
// Vulnerability_Affects_FeatureVersion.
t = time.Now()
err = linkFeatureVersionToVulnerabilities(tx, featureVersion)
observeQueryTime("insertFeatureVersion", "linkFeatureVersionToVulnerabilities", t)
database.ObserveQueryTime("insertFeatureVersion", "linkFeatureVersionToVulnerabilities", t)
if err != nil {
tx.Rollback()

View File

@ -18,6 +18,7 @@ import (
"database/sql"
"time"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
)
@ -28,7 +29,7 @@ func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) {
return cerrors.NewBadRequestError("could not insert a flag which has an empty name or value")
}
defer observeQueryTime("InsertKeyValue", "all", time.Now())
defer database.ObserveQueryTime("InsertKeyValue", "all", time.Now())
// Upsert.
//
@ -67,7 +68,7 @@ func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) {
// GetValue reads a single key / value tuple and returns an empty string if the key doesn't exist.
func (pgSQL *pgSQL) GetKeyValue(key string) (string, error) {
defer observeQueryTime("GetKeyValue", "all", time.Now())
defer database.ObserveQueryTime("GetKeyValue", "all", time.Now())
var value string
err := pgSQL.QueryRow(searchKeyValue, key).Scan(&value)

View File

@ -31,7 +31,7 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo
} else if withVulnerabilities {
subquery += "/features+vulnerabilities"
}
defer observeQueryTime("FindLayer", subquery, time.Now())
defer database.ObserveQueryTime("FindLayer", subquery, time.Now())
// Find the layer
var layer database.Layer
@ -43,8 +43,8 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo
t := time.Now()
err := pgSQL.QueryRow(searchLayer, name).
Scan(&layer.ID, &layer.Name, &layer.EngineVersion, &parentID, &parentName, &namespaceID,
&namespaceName)
observeQueryTime("FindLayer", "searchLayer", t)
&namespaceName)
database.ObserveQueryTime("FindLayer", "searchLayer", t)
if err != nil {
return layer, handleError("searchLayer", err)
@ -89,7 +89,7 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo
t = time.Now()
featureVersions, err := getLayerFeatureVersions(tx, layer.ID)
observeQueryTime("FindLayer", "getLayerFeatureVersions", t)
database.ObserveQueryTime("FindLayer", "getLayerFeatureVersions", t)
if err != nil {
return layer, err
@ -101,7 +101,7 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo
// Load the vulnerabilities that affect the FeatureVersions.
t = time.Now()
err := loadAffectedBy(tx, layer.Features)
observeQueryTime("FindLayer", "loadAffectedBy", t)
database.ObserveQueryTime("FindLayer", "loadAffectedBy", t)
if err != nil {
return layer, err
@ -233,8 +233,8 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error {
layer.ID = existingLayer.ID
}
// We do `defer observeQueryTime` here because we don't want to observe existing layers.
defer observeQueryTime("InsertLayer", "all", tf)
// We do `defer database.ObserveQueryTime` here because we don't want to observe existing layers.
defer database.ObserveQueryTime("InsertLayer", "all", tf)
// Get parent ID.
var parentID zero.Int
@ -386,7 +386,7 @@ func createNV(features []database.FeatureVersion) (map[string]*database.FeatureV
}
func (pgSQL *pgSQL) DeleteLayer(name string) error {
defer observeQueryTime("DeleteLayer", "all", time.Now())
defer database.ObserveQueryTime("DeleteLayer", "all", time.Now())
result, err := pgSQL.Exec(removeLayer, name)
if err != nil {

View File

@ -17,6 +17,7 @@ package pgsql
import (
"time"
"github.com/coreos/clair/database"
cerrors "github.com/coreos/clair/utils/errors"
)
@ -30,7 +31,7 @@ func (pgSQL *pgSQL) Lock(name string, owner string, duration time.Duration, rene
return false, time.Time{}
}
defer observeQueryTime("Lock", "all", time.Now())
defer database.ObserveQueryTime("Lock", "all", time.Now())
// Compute expiration.
until := time.Now().Add(duration)
@ -70,7 +71,7 @@ func (pgSQL *pgSQL) Unlock(name, owner string) {
return
}
defer observeQueryTime("Unlock", "all", time.Now())
defer database.ObserveQueryTime("Unlock", "all", time.Now())
pgSQL.Exec(removeLock, name, owner)
}
@ -83,7 +84,7 @@ func (pgSQL *pgSQL) FindLock(name string) (string, time.Time, error) {
return "", time.Time{}, cerrors.NewBadRequestError("could not find an invalid lock")
}
defer observeQueryTime("FindLock", "all", time.Now())
defer database.ObserveQueryTime("FindLock", "all", time.Now())
var owner string
var until time.Time
@ -97,7 +98,7 @@ func (pgSQL *pgSQL) FindLock(name string) (string, time.Time, error) {
// pruneLocks removes every expired locks from the database
func (pgSQL *pgSQL) pruneLocks() {
defer observeQueryTime("pruneLocks", "all", time.Now())
defer database.ObserveQueryTime("pruneLocks", "all", time.Now())
if _, err := pgSQL.Exec(removeLockExpired); err != nil {
handleError("removeLockExpired", err)

View File

@ -27,15 +27,15 @@ func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (int, error) {
}
if pgSQL.cache != nil {
promCacheQueriesTotal.WithLabelValues("namespace").Inc()
database.PromCacheQueriesTotal.WithLabelValues("namespace").Inc()
if id, found := pgSQL.cache.Get("namespace:" + namespace.Name); found {
promCacheHitsTotal.WithLabelValues("namespace").Inc()
database.PromCacheHitsTotal.WithLabelValues("namespace").Inc()
return id.(int), nil
}
}
// We do `defer observeQueryTime` here because we don't want to observe cached namespaces.
defer observeQueryTime("insertNamespace", "all", time.Now())
defer database.ObserveQueryTime("insertNamespace", "all", time.Now())
var id int
err := pgSQL.QueryRow(soiNamespace, namespace.Name).Scan(&id)

View File

@ -13,7 +13,7 @@ import (
// do it in tx so we won't insert/update a vuln without notification and vice-versa.
// name and created doesn't matter.
func createNotification(tx *sql.Tx, oldVulnerabilityID, newVulnerabilityID int) error {
defer observeQueryTime("createNotification", "all", time.Now())
defer database.ObserveQueryTime("createNotification", "all", time.Now())
// Insert Notification.
oldVulnerabilityNullableID := sql.NullInt64{Int64: int64(oldVulnerabilityID), Valid: oldVulnerabilityID != 0}
@ -30,7 +30,7 @@ func createNotification(tx *sql.Tx, oldVulnerabilityID, newVulnerabilityID int)
// Get one available notification name (!locked && !deleted && (!notified || notified_but_timed-out)).
// Does not fill new/old vuln.
func (pgSQL *pgSQL) GetAvailableNotification(renotifyInterval time.Duration) (database.VulnerabilityNotification, error) {
defer observeQueryTime("GetAvailableNotification", "all", time.Now())
defer database.ObserveQueryTime("GetAvailableNotification", "all", time.Now())
before := time.Now().Add(-renotifyInterval)
row := pgSQL.QueryRow(searchNotificationAvailable, before)
@ -40,7 +40,7 @@ func (pgSQL *pgSQL) GetAvailableNotification(renotifyInterval time.Duration) (da
}
func (pgSQL *pgSQL) GetNotification(name string, limit int, page database.VulnerabilityNotificationPageNumber) (database.VulnerabilityNotification, database.VulnerabilityNotificationPageNumber, error) {
defer observeQueryTime("GetNotification", "all", time.Now())
defer database.ObserveQueryTime("GetNotification", "all", time.Now())
// Get Notification.
notification, err := pgSQL.scanNotification(pgSQL.QueryRow(searchNotification, name), true)
@ -145,8 +145,8 @@ func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vu
return -1, nil
}
// We do `defer observeQueryTime` here because we don't want to observe invalid calls.
defer observeQueryTime("loadLayerIntroducingVulnerability", "all", tf)
// We do `defer database.ObserveQueryTime` here because we don't want to observe invalid calls.
defer database.ObserveQueryTime("loadLayerIntroducingVulnerability", "all", tf)
// Query with limit + 1, the last item will be used to know the next starting ID.
rows, err := pgSQL.Query(searchNotificationLayerIntroducingVulnerability,
@ -185,7 +185,7 @@ func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vu
}
func (pgSQL *pgSQL) SetNotificationNotified(name string) error {
defer observeQueryTime("SetNotificationNotified", "all", time.Now())
defer database.ObserveQueryTime("SetNotificationNotified", "all", time.Now())
if _, err := pgSQL.Exec(updatedNotificationNotified, name); err != nil {
return handleError("updatedNotificationNotified", err)
@ -194,7 +194,7 @@ func (pgSQL *pgSQL) SetNotificationNotified(name string) error {
}
func (pgSQL *pgSQL) DeleteNotification(name string) error {
defer observeQueryTime("DeleteNotification", "all", time.Now())
defer database.ObserveQueryTime("DeleteNotification", "all", time.Now())
result, err := pgSQL.Exec(removeNotification, name)
if err != nil {

View File

@ -23,57 +23,21 @@ import (
"path"
"runtime"
"strings"
"time"
"bitbucket.org/liamstask/goose/lib/goose"
"github.com/coreos/clair/config"
"github.com/coreos/clair/database"
"github.com/coreos/clair/utils"
cerrors "github.com/coreos/clair/utils/errors"
"github.com/coreos/pkg/capnslog"
"github.com/hashicorp/golang-lru"
"github.com/lib/pq"
"github.com/pborman/uuid"
"github.com/prometheus/client_golang/prometheus"
)
var (
log = capnslog.NewPackageLogger("github.com/coreos/clair", "pgsql")
promErrorsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "clair_pgsql_errors_total",
Help: "Number of errors that PostgreSQL requests generated.",
}, []string{"request"})
promCacheHitsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "clair_pgsql_cache_hits_total",
Help: "Number of cache hits that the PostgreSQL backend did.",
}, []string{"object"})
promCacheQueriesTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "clair_pgsql_cache_queries_total",
Help: "Number of cache queries that the PostgreSQL backend did.",
}, []string{"object"})
promQueryDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "clair_pgsql_query_duration_milliseconds",
Help: "Time it takes to execute the database query.",
}, []string{"query", "subquery"})
promConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "clair_pgsql_concurrent_lock_vafv_total",
Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_FeatureVersion lock.",
})
)
func init() {
prometheus.MustRegister(promErrorsTotal)
prometheus.MustRegister(promCacheHitsTotal)
prometheus.MustRegister(promCacheQueriesTotal)
prometheus.MustRegister(promQueryDurationMilliseconds)
prometheus.MustRegister(promConcurrentLockVAFV)
}
type Queryer interface {
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
@ -210,7 +174,7 @@ type pgSQLTest struct {
// Using Close() will drop the database.
func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) {
// Define the PostgreSQL connection strings.
dataSource := "host=127.0.0.1 sslmode=disable user=postgres dbname="
dataSource := "host=127.0.0.1 sslmode=disable user=postgres password=huawei dbname="
if dataSourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); dataSourceEnv != "" {
dataSource = dataSourceEnv + " dbname="
}
@ -267,7 +231,7 @@ func handleError(desc string, err error) error {
}
log.Errorf("%s: %v", desc, err)
promErrorsTotal.WithLabelValues(desc).Inc()
database.PromErrorsTotal.WithLabelValues(desc).Inc()
if _, o := err.(*pq.Error); o || err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") {
return database.ErrBackendException
@ -281,7 +245,3 @@ func isErrUniqueViolation(err error) bool {
pqErr, ok := err.(*pq.Error)
return ok && pqErr.Code == "23505"
}
func observeQueryTime(query, subquery string, start time.Time) {
utils.PrometheusObserveTimeMilliseconds(promQueryDurationMilliseconds.WithLabelValues(query, subquery), start)
}

View File

@ -29,7 +29,7 @@ import (
)
func (pgSQL *pgSQL) ListVulnerabilities(namespaceName string, limit int, startID int) ([]database.Vulnerability, int, error) {
defer observeQueryTime("listVulnerabilities", "all", time.Now())
defer database.ObserveQueryTime("listVulnerabilities", "all", time.Now())
// Query Namespace.
var id int
@ -88,7 +88,7 @@ func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vuln
}
func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bool) (database.Vulnerability, error) {
defer observeQueryTime("findVulnerability", "all", time.Now())
defer database.ObserveQueryTime("findVulnerability", "all", time.Now())
queryName := "searchVulnerabilityBase+searchVulnerabilityByNamespaceAndName"
query := searchVulnerabilityBase + searchVulnerabilityByNamespaceAndName
@ -101,7 +101,7 @@ func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bo
}
func (pgSQL *pgSQL) findVulnerabilityByIDWithDeleted(id int) (database.Vulnerability, error) {
defer observeQueryTime("findVulnerabilityByIDWithDeleted", "all", time.Now())
defer database.ObserveQueryTime("findVulnerabilityByIDWithDeleted", "all", time.Now())
queryName := "searchVulnerabilityBase+searchVulnerabilityByID"
query := searchVulnerabilityBase + searchVulnerabilityByID
@ -215,8 +215,8 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability, on
}
}
// We do `defer observeQueryTime` here because we don't want to observe invalid vulnerabilities.
defer observeQueryTime("insertVulnerability", "all", tf)
// We do `defer database.ObserveQueryTime` here because we don't want to observe invalid vulnerabilities.
defer database.ObserveQueryTime("insertVulnerability", "all", tf)
// Begin transaction.
tx, err := pgSQL.Begin()
@ -401,7 +401,7 @@ func createFeatureVersionNameMap(features []database.FeatureVersion) (map[string
// linkVulnerabilityToFeatureVersions to propagate the changes on Vulnerability_FixedIn_Feature to
// Vulnerability_Affects_FeatureVersion.
func (pgSQL *pgSQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulnerabilityID int, fixedIn []database.FeatureVersion) error {
defer observeQueryTime("insertVulnerabilityFixedInFeatureVersions", "all", time.Now())
defer database.ObserveQueryTime("insertVulnerabilityFixedInFeatureVersions", "all", time.Now())
// Insert or find the Features.
// TODO(Quentin-M): Batch me.
@ -420,11 +420,11 @@ func (pgSQL *pgSQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulner
// Lock Vulnerability_Affects_FeatureVersion exclusively.
// We want to prevent InsertFeatureVersion to modify it.
promConcurrentLockVAFV.Inc()
defer promConcurrentLockVAFV.Dec()
database.PromConcurrentLockVAFV.Inc()
defer database.PromConcurrentLockVAFV.Dec()
t := time.Now()
_, err = tx.Exec(lockVulnerabilityAffects)
observeQueryTime("insertVulnerability", "lock", t)
database.ObserveQueryTime("insertVulnerability", "lock", t)
if err != nil {
tx.Rollback()
@ -498,7 +498,7 @@ func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID,
}
func (pgSQL *pgSQL) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []database.FeatureVersion) error {
defer observeQueryTime("InsertVulnerabilityFixes", "all", time.Now())
defer database.ObserveQueryTime("InsertVulnerabilityFixes", "all", time.Now())
v := database.Vulnerability{
Name: vulnerabilityName,
@ -512,7 +512,7 @@ func (pgSQL *pgSQL) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabili
}
func (pgSQL *pgSQL) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error {
defer observeQueryTime("DeleteVulnerabilityFix", "all", time.Now())
defer database.ObserveQueryTime("DeleteVulnerabilityFix", "all", time.Now())
v := database.Vulnerability{
Name: vulnerabilityName,
@ -536,7 +536,7 @@ func (pgSQL *pgSQL) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerability
}
func (pgSQL *pgSQL) DeleteVulnerability(namespaceName, name string) error {
defer observeQueryTime("DeleteVulnerability", "all", time.Now())
defer database.ObserveQueryTime("DeleteVulnerability", "all", time.Now())
// Begin transaction.
tx, err := pgSQL.Begin()

47
database/prometheus.go Normal file
View File

@ -0,0 +1,47 @@
package database
import (
"time"
"github.com/coreos/clair/utils"
"github.com/prometheus/client_golang/prometheus"
)
var (
PromErrorsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "clair_sql_errors_total",
Help: "Number of errors that SQL requests generated.",
}, []string{"request"})
PromCacheHitsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "clair_sql_cache_hits_total",
Help: "Number of cache hits that the SQL backend did.",
}, []string{"object"})
PromCacheQueriesTotal = prometheus.NewCounterVec(prometheus.CounterOpts{
Name: "clair_sql_cache_queries_total",
Help: "Number of cache queries that the SQL backend did.",
}, []string{"object"})
PromQueryDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "clair_sql_query_duration_milliseconds",
Help: "Time it takes to execute the database query.",
}, []string{"query", "subquery"})
PromConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{
Name: "clair_sql_concurrent_lock_vafv_total",
Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_FeatureVersion lock.",
})
)
func init() {
prometheus.MustRegister(PromErrorsTotal)
prometheus.MustRegister(PromCacheHitsTotal)
prometheus.MustRegister(PromCacheQueriesTotal)
prometheus.MustRegister(PromQueryDurationMilliseconds)
prometheus.MustRegister(PromConcurrentLockVAFV)
}
func ObserveQueryTime(query, subquery string, start time.Time) {
utils.PrometheusObserveTimeMilliseconds(PromQueryDurationMilliseconds.WithLabelValues(query, subquery), start)
}