From 3ba218dbef1632fa978770372d69dc6ce13be600 Mon Sep 17 00:00:00 2001 From: Lei Jitang Date: Tue, 22 Mar 2016 04:30:43 -0400 Subject: [PATCH] database: add support mysql Signed-off-by: Lei Jitang --- clair.go | 17 +- config.example.yaml | 7 +- database/mysql/complex_test.go | 158 +++++ database/mysql/feature.go | 305 +++++++++ database/mysql/feature_test.go | 102 +++ database/mysql/keyvalue.go | 83 +++ database/mysql/keyvalue_test.go | 52 ++ database/mysql/layer.go | 414 ++++++++++++ database/mysql/layer_test.go | 350 +++++++++++ database/mysql/lock.go | 106 ++++ database/mysql/lock_test.go | 69 ++ .../migrations/20151222113213_Initial.sql | 177 ++++++ database/mysql/mysql.go | 264 ++++++++ database/mysql/namespace.go | 125 ++++ database/mysql/namespace_test.go | 66 ++ database/mysql/notification.go | 214 +++++++ database/mysql/notification_test.go | 209 ++++++ database/mysql/queries.go | 204 ++++++ database/mysql/testdata/data.sql | 55 ++ database/mysql/vulnerability.go | 593 ++++++++++++++++++ database/mysql/vulnerability_test.go | 276 ++++++++ database/pgsql/feature.go | 28 +- database/pgsql/keyvalue.go | 5 +- database/pgsql/layer.go | 16 +- database/pgsql/lock.go | 9 +- database/pgsql/namespace.go | 6 +- database/pgsql/notification.go | 14 +- database/pgsql/pgsql.go | 44 +- database/pgsql/vulnerability.go | 24 +- database/prometheus.go | 47 ++ 30 files changed, 3941 insertions(+), 98 deletions(-) create mode 100644 database/mysql/complex_test.go create mode 100644 database/mysql/feature.go create mode 100644 database/mysql/feature_test.go create mode 100644 database/mysql/keyvalue.go create mode 100644 database/mysql/keyvalue_test.go create mode 100644 database/mysql/layer.go create mode 100644 database/mysql/layer_test.go create mode 100644 database/mysql/lock.go create mode 100644 database/mysql/lock_test.go create mode 100644 database/mysql/migrations/20151222113213_Initial.sql create mode 100644 database/mysql/mysql.go create mode 100644 database/mysql/namespace.go create mode 100644 database/mysql/namespace_test.go create mode 100644 database/mysql/notification.go create mode 100644 database/mysql/notification_test.go create mode 100644 database/mysql/queries.go create mode 100644 database/mysql/testdata/data.sql create mode 100644 database/mysql/vulnerability.go create mode 100644 database/mysql/vulnerability_test.go create mode 100644 database/prometheus.go diff --git a/clair.go b/clair.go index cfb5c850..ef7427d2 100644 --- a/clair.go +++ b/clair.go @@ -20,12 +20,15 @@ import ( "math/rand" "os" "os/signal" + "strings" "syscall" "time" "github.com/coreos/clair/api" "github.com/coreos/clair/api/context" "github.com/coreos/clair/config" + "github.com/coreos/clair/database" + "github.com/coreos/clair/database/mysql" "github.com/coreos/clair/database/pgsql" "github.com/coreos/clair/notifier" "github.com/coreos/clair/updater" @@ -40,14 +43,22 @@ var log = capnslog.NewPackageLogger("github.com/coreos/clair", "main") func Boot(config *config.Config) { rand.Seed(time.Now().UnixNano()) st := utils.NewStopper() - + var ( + db database.Datastore + err error + ) // Open database - db, err := pgsql.Open(config.Database) + if strings.HasPrefix(config.Database.Source, "postgres") { + db, err = pgsql.Open(config.Database) + } else if strings.HasPrefix(config.Database.Source, "mysql") { + db, err = mysql.Open(config.Database) + } else { + log.Fatal("database source '%s' does not support, support 'postgres' and 'mysql' for now", config.Database.Source) + } if err != nil { log.Fatal(err) } defer db.Close() - // Start notifier st.Begin() go notifier.Run(config.Notifier, db, st) diff --git a/config.example.yaml b/config.example.yaml index b489b97e..1922a613 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -15,9 +15,10 @@ # The values specified here are the default values that Clair uses if no configuration file is specified or if the keys are not defined. clair: database: - # PostgreSQL Connection string - # http://www.postgresql.org/docs/9.4/static/libpq-connect.html - source: + # database Connection string + # see http://www.postgresql.org/docs/9.4/static/libpq-connect.html for PostgreSQL + # see https://github.com/go-sql-driver/mysql#dsn-data-source-name for MySQL. e.g. "mysql://root@tcp(127.0.0.1:3306)/" + source: mysql://root@tcp(172.17.0.3:3306)/ # Number of elements kept in the cache # Values unlikely to change (e.g. namespaces) are cached in order to save prevent needless roundtrips to the database. diff --git a/database/mysql/complex_test.go b/database/mysql/complex_test.go new file mode 100644 index 00000000..51040891 --- /dev/null +++ b/database/mysql/complex_test.go @@ -0,0 +1,158 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "fmt" + "math/rand" + "runtime" + "strconv" + "sync" + "testing" + "time" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/utils" + "github.com/coreos/clair/utils/types" + "github.com/pborman/uuid" + "github.com/stretchr/testify/assert" +) + +const ( + numVulnerabilities = 100 + numFeatureVersions = 100 +) + +func TestRaceAffects(t *testing.T) { + datastore, err := OpenForTest("RaceAffects", false) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + // Insert the Feature on which we'll work. + feature := database.Feature{ + Namespace: database.Namespace{Name: "TestRaceAffectsFeatureNamespace1"}, + Name: "TestRaceAffecturesFeature1", + } + _, err = datastore.insertFeature(feature) + if err != nil { + t.Error(err) + return + } + + // Initialize random generator and enforce max procs. + rand.Seed(time.Now().UnixNano()) + runtime.GOMAXPROCS(runtime.NumCPU()) + + // Generate FeatureVersions. + featureVersions := make([]database.FeatureVersion, numFeatureVersions) + for i := 0; i < numFeatureVersions; i++ { + version := rand.Intn(numFeatureVersions) + + featureVersions[i] = database.FeatureVersion{ + Feature: feature, + Version: types.NewVersionUnsafe(strconv.Itoa(version)), + } + } + + // Generate vulnerabilities. + // They are mapped by fixed version, which will make verification really easy afterwards. + vulnerabilities := make(map[int][]database.Vulnerability) + for i := 0; i < numVulnerabilities; i++ { + version := rand.Intn(numFeatureVersions) + 1 + + // if _, ok := vulnerabilities[version]; !ok { + // vulnerabilities[version] = make([]database.Vulnerability) + // } + + vulnerability := database.Vulnerability{ + Name: uuid.New(), + Namespace: feature.Namespace, + FixedIn: []database.FeatureVersion{ + { + Feature: feature, + Version: types.NewVersionUnsafe(strconv.Itoa(version)), + }, + }, + Severity: types.Unknown, + } + + vulnerabilities[version] = append(vulnerabilities[version], vulnerability) + } + + // Insert featureversions and vulnerabilities in parallel. + var wg sync.WaitGroup + wg.Add(2) + + go func() { + defer wg.Done() + for _, vulnerabilitiesM := range vulnerabilities { + for _, vulnerability := range vulnerabilitiesM { + err = datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true) + assert.Nil(t, err) + } + } + fmt.Println("finished to insert vulnerabilities") + }() + + go func() { + defer wg.Done() + for i := 0; i < len(featureVersions); i++ { + featureVersions[i].ID, err = datastore.insertFeatureVersion(featureVersions[i]) + assert.Nil(t, err) + } + fmt.Println("finished to insert featureVersions") + }() + + wg.Wait() + + // Verify consistency now. + var actualAffectedNames []string + var expectedAffectedNames []string + + for _, featureVersion := range featureVersions { + featureVersionVersion, _ := strconv.Atoi(featureVersion.Version.String()) + + // Get actual affects. + rows, err := datastore.Query(searchComplexTestFeatureVersionAffects, + featureVersion.ID) + assert.Nil(t, err) + defer rows.Close() + + var vulnName string + for rows.Next() { + err = rows.Scan(&vulnName) + if !assert.Nil(t, err) { + continue + } + actualAffectedNames = append(actualAffectedNames, vulnName) + } + if assert.Nil(t, rows.Err()) { + rows.Close() + } + + // Get expected affects. + for i := numVulnerabilities; i > featureVersionVersion; i-- { + for _, vulnerability := range vulnerabilities[i] { + expectedAffectedNames = append(expectedAffectedNames, vulnerability.Name) + } + } + + assert.Len(t, utils.CompareStringLists(expectedAffectedNames, actualAffectedNames), 0) + assert.Len(t, utils.CompareStringLists(actualAffectedNames, expectedAffectedNames), 0) + } +} diff --git a/database/mysql/feature.go b/database/mysql/feature.go new file mode 100644 index 00000000..834f75a2 --- /dev/null +++ b/database/mysql/feature.go @@ -0,0 +1,305 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "database/sql" + "time" + + "github.com/coreos/clair/database" + cerrors "github.com/coreos/clair/utils/errors" + "github.com/coreos/clair/utils/types" +) + +func (mySQL *mySQL) insertFeatureiWithTransaction(queryer Queryer, feature database.Feature) (int, error) { + if feature.Name == "" { + return 0, cerrors.NewBadRequestError("could not find/insert invalid Feature") + } + + // Do cache lookup. + if mySQL.cache != nil { + database.PromCacheQueriesTotal.WithLabelValues("feature").Inc() + id, found := mySQL.cache.Get("feature:" + feature.Namespace.Name + ":" + feature.Name) + if found { + database.PromCacheHitsTotal.WithLabelValues("feature").Inc() + return id.(int), nil + } + } + + // We do `defer database.ObserveQueryTime` here because we don't want to observe cached features. + defer database.ObserveQueryTime("insertFeature", "all", time.Now()) + + // Find or create Namespace. + namespaceID, err := mySQL.insertNamespaceWithTransaction(queryer, feature.Namespace) + if err != nil { + return 0, err + } + // Find or create Feature. + var id int + res, err := queryer.Exec(insertFeature, feature.Name, namespaceID, feature.Name, namespaceID) + if err != nil { + return 0, handleError("insertFeatureVersion", err) + } + tmpid, err := res.LastInsertId() + if err != nil { + return 0, handleError("insertFeatureVersion", err) + } + id = int(tmpid) + // if id==0 means the feature already exists, use query to get id + if id == 0 { + err = queryer.QueryRow(soiFeature, feature.Name, namespaceID).Scan(&id) + if err != nil { + return 0, handleError("soiFeature", err) + } + } + if mySQL.cache != nil { + mySQL.cache.Add("feature:"+feature.Namespace.Name+":"+feature.Name, id) + } + + return id, nil +} + +func (mySQL *mySQL) insertFeature(feature database.Feature) (int, error) { + if feature.Name == "" { + return 0, cerrors.NewBadRequestError("could not find/insert invalid Feature") + } + + // Do cache lookup. + if mySQL.cache != nil { + database.PromCacheQueriesTotal.WithLabelValues("feature").Inc() + id, found := mySQL.cache.Get("feature:" + feature.Namespace.Name + ":" + feature.Name) + if found { + database.PromCacheHitsTotal.WithLabelValues("feature").Inc() + return id.(int), nil + } + } + + // We do `defer database.ObserveQueryTime` here because we don't want to observe cached features. + defer database.ObserveQueryTime("insertFeature", "all", time.Now()) + + // Find or create Namespace. + namespaceID, err := mySQL.insertNamespace(feature.Namespace) + if err != nil { + return 0, err + } + // Find or create Feature. + var id int + res, err := mySQL.Exec(insertFeature, feature.Name, namespaceID, feature.Name, namespaceID) + if err != nil { + return 0, handleError("insertFeatureVersion", err) + } + tmpid, err := res.LastInsertId() + if err != nil { + return 0, handleError("insertFeatureVersion", err) + } + id = int(tmpid) + // if id==0 means the feature already exists, use query to get id + if id == 0 { + err = mySQL.QueryRow(soiFeature, feature.Name, namespaceID).Scan(&id) + if err != nil { + return 0, handleError("soiFeature", err) + } + } + if mySQL.cache != nil { + mySQL.cache.Add("feature:"+feature.Namespace.Name+":"+feature.Name, id) + } + + return id, nil +} + +func (mySQL *mySQL) insertFeatureVersion(featureVersion database.FeatureVersion) (id int, err error) { + if featureVersion.Version.String() == "" { + return 0, cerrors.NewBadRequestError("could not find/insert invalid FeatureVersion") + } + + // Do cache lookup. + cacheIndex := "featureversion:" + featureVersion.Feature.Namespace.Name + ":" + featureVersion.Feature.Name + ":" + featureVersion.Version.String() + if mySQL.cache != nil { + database.PromCacheQueriesTotal.WithLabelValues("featureversion").Inc() + id, found := mySQL.cache.Get(cacheIndex) + if found { + database.PromCacheHitsTotal.WithLabelValues("featureversion").Inc() + return id.(int), nil + } + } + + // We do `defer database.ObserveQueryTime` here because we don't want to observe cached featureversions. + defer database.ObserveQueryTime("insertFeatureVersion", "all", time.Now()) + + // Find or create Feature first. + t := time.Now() + featureID, err := mySQL.insertFeature(featureVersion.Feature) + database.ObserveQueryTime("insertFeatureVersion", "insertFeature", t) + + if err != nil { + return 0, err + } + + featureVersion.Feature.ID = featureID + + // Try to find the FeatureVersion. + // + // In a populated database, the likelihood of the FeatureVersion already being there is high. + // If we can find it here, we then avoid using a transaction and locking the database. + err = mySQL.QueryRow(searchFeatureVersion, featureID, &featureVersion.Version). + Scan(&featureVersion.ID) + if err != nil && err != sql.ErrNoRows { + return 0, handleError("searchFeatureVersion", err) + } + if err == nil { + if mySQL.cache != nil { + mySQL.cache.Add(cacheIndex, featureVersion.ID) + } + + return featureVersion.ID, nil + } + + // Begin transaction. + tx, err := mySQL.Begin() + if err != nil { + tx.Rollback() + return 0, handleError("insertFeatureVersion.Begin()", err) + } + + // Lock Vulnerability_Affects_FeatureVersion exclusively. + // We want to prevent InsertVulnerability to modify it. + database.PromConcurrentLockVAFV.Inc() + defer database.PromConcurrentLockVAFV.Dec() + t = time.Now() + var tmp int64 + err = tx.QueryRow(lockVulnerabilityAffects).Scan(&tmp) + database.ObserveQueryTime("insertFeatureVersion", "lock", t) + + if err != nil { + tx.Rollback() + return 0, handleError("insertFeatureVersion.lockVulnerabilityAffects", err) + } + // Find or create FeatureVersion. + var newOrExisting string + t = time.Now() + _, err = tx.Exec(insertFeatureVersion, featureID, &featureVersion.Version, featureID, &featureVersion.Version) + database.ObserveQueryTime("insertFeatureVersion", "soiFeatureVersion", t) + if err != nil { + tx.Rollback() + return 0, handleError("insertFeatureVersion", err) + } + + t = time.Now() + err = tx.QueryRow(soiFeatureVersion, featureID, &featureVersion.Version). + Scan(&newOrExisting, &featureVersion.ID) + database.ObserveQueryTime("insertFeatureVersion", "soiFeatureVersion", t) + + if err != nil { + tx.Rollback() + return 0, handleError("soiFeatureVersion", err) + } + + if newOrExisting == "exi" { + // That featureVersion already exists, return its id. + tx.Commit() + + if mySQL.cache != nil { + mySQL.cache.Add(cacheIndex, featureVersion.ID) + } + + return featureVersion.ID, nil + } + + // Link the new FeatureVersion with every vulnerabilities that affect it, by inserting in + // Vulnerability_Affects_FeatureVersion. + t = time.Now() + err = linkFeatureVersionToVulnerabilities(tx, featureVersion) + database.ObserveQueryTime("insertFeatureVersion", "linkFeatureVersionToVulnerabilities", t) + + if err != nil { + tx.Rollback() + return 0, err + } + + // Commit transaction. + err = tx.Commit() + if err != nil { + return 0, handleError("insertFeatureVersion.Commit()", err) + } + + if mySQL.cache != nil { + mySQL.cache.Add(cacheIndex, featureVersion.ID) + } + + return featureVersion.ID, nil +} + +// TODO(Quentin-M): Batch me +func (mySQL *mySQL) insertFeatureVersions(featureVersions []database.FeatureVersion) ([]int, error) { + IDs := make([]int, 0, len(featureVersions)) + + for i := 0; i < len(featureVersions); i++ { + id, err := mySQL.insertFeatureVersion(featureVersions[i]) + if err != nil { + return IDs, err + } + IDs = append(IDs, id) + } + + return IDs, nil +} + +type vulnerabilityAffectsFeatureVersion struct { + vulnerabilityID int + fixedInID int + fixedInVersion types.Version +} + +func linkFeatureVersionToVulnerabilities(tx *sql.Tx, featureVersion database.FeatureVersion) error { + // Select every vulnerability and the fixed version that affect this Feature. + // TODO(Quentin-M): LIMIT + rows, err := tx.Query(searchVulnerabilityFixedInFeature, featureVersion.Feature.ID) + if err != nil { + return handleError("searchVulnerabilityFixedInFeature", err) + } + defer rows.Close() + + var affects []vulnerabilityAffectsFeatureVersion + for rows.Next() { + var affect vulnerabilityAffectsFeatureVersion + + err := rows.Scan(&affect.fixedInID, &affect.vulnerabilityID, &affect.fixedInVersion) + if err != nil { + return handleError("searchVulnerabilityFixedInFeature.Scan()", err) + } + + if featureVersion.Version.Compare(affect.fixedInVersion) < 0 { + // The version of the FeatureVersion we are inserting is lower than the fixed version on this + // Vulnerability, thus, this FeatureVersion is affected by it. + affects = append(affects, affect) + } + } + if err = rows.Err(); err != nil { + return handleError("searchVulnerabilityFixedInFeature.Rows()", err) + } + rows.Close() + + // Insert into Vulnerability_Affects_FeatureVersion. + for _, affect := range affects { + // TODO(Quentin-M): Batch me. + _, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, affect.vulnerabilityID, + featureVersion.ID, affect.fixedInID, affect.vulnerabilityID, featureVersion.ID, affect.fixedInID) + if err != nil { + return handleError("insertVulnerabilityAffectsFeatureVersion", err) + } + } + + return nil +} diff --git a/database/mysql/feature_test.go b/database/mysql/feature_test.go new file mode 100644 index 00000000..2170e319 --- /dev/null +++ b/database/mysql/feature_test.go @@ -0,0 +1,102 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "testing" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/utils/types" + "github.com/stretchr/testify/assert" +) + +func TestInsertFeature(t *testing.T) { + datastore, err := OpenForTest("InsertFeature", false) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + // Invalid Feature. + id0, err := datastore.insertFeature(database.Feature{}) + assert.NotNil(t, err) + assert.Zero(t, id0) + + id0, err = datastore.insertFeature(database.Feature{ + Namespace: database.Namespace{}, + Name: "TestInsertFeature0", + }) + assert.NotNil(t, err) + assert.Zero(t, id0) + + // Insert Feature and ensure we can find it. + feature := database.Feature{ + Namespace: database.Namespace{Name: "TestInsertFeatureNamespace1"}, + Name: "TestInsertFeature1", + } + id1, err := datastore.insertFeature(feature) + assert.Nil(t, err) + id2, err := datastore.insertFeature(feature) + assert.Nil(t, err) + assert.Equal(t, id1, id2) + + // Insert invalid FeatureVersion. + for _, invalidFeatureVersion := range []database.FeatureVersion{ + { + Feature: database.Feature{}, + Version: types.NewVersionUnsafe("1.0"), + }, + { + Feature: database.Feature{ + Namespace: database.Namespace{}, + Name: "TestInsertFeature2", + }, + Version: types.NewVersionUnsafe("1.0"), + }, + { + Feature: database.Feature{ + Namespace: database.Namespace{Name: "TestInsertFeatureNamespace2"}, + Name: "TestInsertFeature2", + }, + Version: types.NewVersionUnsafe(""), + }, + { + Feature: database.Feature{ + Namespace: database.Namespace{Name: "TestInsertFeatureNamespace2"}, + Name: "TestInsertFeature2", + }, + Version: types.NewVersionUnsafe("bad version"), + }, + } { + id3, err := datastore.insertFeatureVersion(invalidFeatureVersion) + assert.Error(t, err) + assert.Zero(t, id3) + } + + // Insert FeatureVersion and ensure we can find it. + featureVersion := database.FeatureVersion{ + Feature: database.Feature{ + Namespace: database.Namespace{Name: "TestInsertFeatureNamespace1"}, + Name: "TestInsertFeature1", + }, + Version: types.NewVersionUnsafe("2:3.0-imba"), + } + id4, err := datastore.insertFeatureVersion(featureVersion) + assert.Nil(t, err) + id5, err := datastore.insertFeatureVersion(featureVersion) + assert.Nil(t, err) + assert.Equal(t, id4, id5) +} diff --git a/database/mysql/keyvalue.go b/database/mysql/keyvalue.go new file mode 100644 index 00000000..eb60beaa --- /dev/null +++ b/database/mysql/keyvalue.go @@ -0,0 +1,83 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "database/sql" + "github.com/coreos/clair/database" + cerrors "github.com/coreos/clair/utils/errors" + "time" +) + +// InsertKeyValue stores (or updates) a single key / value tuple. +func (mySQL *mySQL) InsertKeyValue(key, value string) (err error) { + if key == "" || value == "" { + log.Warning("could not insert a flag which has an empty name or value") + return cerrors.NewBadRequestError("could not insert a flag which has an empty name or value") + } + + defer database.ObserveQueryTime("InsertKeyValue", "all", time.Now()) + + // Upsert. + // + // Note: UPSERT works only on >= PostgreSQL 9.5 which is not yet supported by AWS RDS. + // The best solution is currently the use of http://dba.stackexchange.com/a/13477 + // but the key/value storage doesn't need to be super-efficient and super-safe at the + // moment so we can just use a client-side solution with transactions, based on + // http://postgresql.org/docs/current/static/plpgsql-control-structures.html. + // TODO(Quentin-M): Enable Upsert as soon as 9.5 is stable. + + for { + // First, try to update. + r, err := mySQL.Exec(updateKeyValue, value, key) + if err != nil { + return handleError("updateKeyValue", err) + } + if n, _ := r.RowsAffected(); n > 0 { + // Updated successfully. + return nil + } + + // Try to insert the key. + // If someone else inserts the same key concurrently, we could get a unique-key violation error. + _, err = mySQL.Exec(insertKeyValue, key, value) + if err != nil { + if isErrUniqueViolation(err) { + // Got unique constraint violation, retry. + continue + } + return handleError("insertKeyValue", err) + } + + return nil + } +} + +// GetValue reads a single key / value tuple and returns an empty string if the key doesn't exist. +func (mySQL *mySQL) GetKeyValue(key string) (string, error) { + defer database.ObserveQueryTime("GetKeyValue", "all", time.Now()) + + var value string + err := mySQL.QueryRow(searchKeyValue, key).Scan(&value) + + if err == sql.ErrNoRows { + return "", nil + } + if err != nil { + return "", handleError("searchKeyValue", err) + } + + return value, nil +} diff --git a/database/mysql/keyvalue_test.go b/database/mysql/keyvalue_test.go new file mode 100644 index 00000000..9e9bab7a --- /dev/null +++ b/database/mysql/keyvalue_test.go @@ -0,0 +1,52 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestKeyValue(t *testing.T) { + datastore, err := OpenForTest("KeyValue", false) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + // Get non-existing key/value + f, err := datastore.GetKeyValue("test") + assert.Nil(t, err) + assert.Empty(t, "", f) + + // Try to insert invalid key/value. + assert.Error(t, datastore.InsertKeyValue("test", "")) + assert.Error(t, datastore.InsertKeyValue("", "test")) + assert.Error(t, datastore.InsertKeyValue("", "")) + + // Insert and verify. + assert.Nil(t, datastore.InsertKeyValue("test", "test1")) + f, err = datastore.GetKeyValue("test") + assert.Nil(t, err) + assert.Equal(t, "test1", f) + + // Update and verify. + assert.Nil(t, datastore.InsertKeyValue("test", "test2")) + f, err = datastore.GetKeyValue("test") + assert.Nil(t, err) + assert.Equal(t, "test2", f) +} diff --git a/database/mysql/layer.go b/database/mysql/layer.go new file mode 100644 index 00000000..60beaf21 --- /dev/null +++ b/database/mysql/layer.go @@ -0,0 +1,414 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "database/sql" + "fmt" + "time" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/utils" + cerrors "github.com/coreos/clair/utils/errors" + "github.com/guregu/null/zero" +) + +func (mySQL *mySQL) FindLayer(name string, withFeatures, withVulnerabilities bool) (database.Layer, error) { + subquery := "all" + if withFeatures { + subquery += "/features" + } else if withVulnerabilities { + subquery += "/features+vulnerabilities" + } + defer database.ObserveQueryTime("FindLayer", subquery, time.Now()) + + // Find the layer + var layer database.Layer + var parentID zero.Int + var parentName zero.String + var namespaceID zero.Int + var namespaceName sql.NullString + + t := time.Now() + err := mySQL.QueryRow(searchLayer, name). + Scan(&layer.ID, &layer.Name, &layer.EngineVersion, &parentID, &parentName, &namespaceID, + &namespaceName) + database.ObserveQueryTime("FindLayer", "searchLayer", t) + + if err != nil { + return layer, handleError("searchLayer", err) + } + + if !parentID.IsZero() { + layer.Parent = &database.Layer{ + Model: database.Model{ID: int(parentID.Int64)}, + Name: parentName.String, + } + } + if !namespaceID.IsZero() { + layer.Namespace = &database.Namespace{ + Model: database.Model{ID: int(namespaceID.Int64)}, + Name: namespaceName.String, + } + } + + // Find its features + if withFeatures || withVulnerabilities { + // Create a transaction to disable hash/merge joins as our experiments have shown that + // PostgreSQL 9.4 makes bad planning decisions about: + // - joining the layer tree to feature versions and feature + // - joining the feature versions to affected/fixed feature version and vulnerabilities + // It would for instance do a merge join between affected feature versions (300 rows, estimated + // 3000 rows) and fixed in feature version (100k rows). In this case, it is much more + // preferred to use a nested loop. + tx, err := mySQL.Begin() + if err != nil { + return layer, handleError("FindLayer.Begin()", err) + } + defer tx.Commit() + + t = time.Now() + featureVersions, err := getLayerFeatureVersions(tx, layer.ID) + database.ObserveQueryTime("FindLayer", "getLayerFeatureVersions", t) + + if err != nil { + return layer, err + } + + layer.Features = featureVersions + + if withVulnerabilities { + // Load the vulnerabilities that affect the FeatureVersions. + t = time.Now() + err := loadAffectedBy(tx, layer.Features) + database.ObserveQueryTime("FindLayer", "loadAffectedBy", t) + + if err != nil { + return layer, err + } + } + } + + return layer, nil +} + +// getLayerFeatureVersions returns list of database.FeatureVersion that a database.Layer has. +func getLayerFeatureVersions(tx *sql.Tx, layerID int) ([]database.FeatureVersion, error) { + var featureVersions []database.FeatureVersion + + // Query. + rows, err := tx.Query(searchLayerFeatureVersion, layerID) + if err != nil { + return featureVersions, handleError("searchLayerFeatureVersion", err) + } + defer rows.Close() + + // Scan query. + var modification string + mapFeatureVersions := make(map[int]database.FeatureVersion) + for rows.Next() { + var featureVersion database.FeatureVersion + + err = rows.Scan(&featureVersion.ID, &modification, &featureVersion.Feature.Namespace.ID, + &featureVersion.Feature.Namespace.Name, &featureVersion.Feature.ID, + &featureVersion.Feature.Name, &featureVersion.ID, &featureVersion.Version, + &featureVersion.AddedBy.ID, &featureVersion.AddedBy.Name) + if err != nil { + return featureVersions, handleError("searchLayerFeatureVersion.Scan()", err) + } + + // Do transitive closure. + switch modification { + case "add": + mapFeatureVersions[featureVersion.ID] = featureVersion + case "del": + delete(mapFeatureVersions, featureVersion.ID) + default: + log.Warningf("unknown Layer_diff_FeatureVersion's modification: %s", modification) + return featureVersions, database.ErrInconsistent + } + } + if err = rows.Err(); err != nil { + return featureVersions, handleError("searchLayerFeatureVersion.Rows()", err) + } + + // Build result by converting our map to a slice. + for _, featureVersion := range mapFeatureVersions { + featureVersions = append(featureVersions, featureVersion) + } + + return featureVersions, nil +} + +// loadAffectedBy returns the list of database.Vulnerability that affect the given +// FeatureVersion. +func loadAffectedBy(tx *sql.Tx, featureVersions []database.FeatureVersion) error { + if len(featureVersions) == 0 { + return nil + } + + // Construct list of FeatureVersion IDs, we will do a single query + featureVersionIDs := make([]int, 0, len(featureVersions)) + for i := 0; i < len(featureVersions); i++ { + featureVersionIDs = append(featureVersionIDs, featureVersions[i].ID) + } + searchFeatureVersionVulnerabilityQuery := fmt.Sprintf(searchFeatureVersionVulnerability, buildInputArray(featureVersionIDs)) + rows, err := tx.Query(searchFeatureVersionVulnerabilityQuery) + if err != nil && err != sql.ErrNoRows { + return handleError("searchFeatureVersionVulnerability", err) + } + defer rows.Close() + + vulnerabilities := make(map[int][]database.Vulnerability, len(featureVersions)) + var featureversionID int + for rows.Next() { + var vulnerability database.Vulnerability + err := rows.Scan(&featureversionID, &vulnerability.ID, &vulnerability.Name, + &vulnerability.Description, &vulnerability.Link, &vulnerability.Severity, + &vulnerability.Metadata, &vulnerability.Namespace.Name, &vulnerability.FixedBy) + if err != nil { + return handleError("searchFeatureVersionVulnerability.Scan()", err) + } + vulnerabilities[featureversionID] = append(vulnerabilities[featureversionID], vulnerability) + } + if err = rows.Err(); err != nil { + return handleError("searchFeatureVersionVulnerability.Rows()", err) + } + + // Assign vulnerabilities to every FeatureVersions + for i := 0; i < len(featureVersions); i++ { + featureVersions[i].AffectedBy = vulnerabilities[featureVersions[i].ID] + } + + return nil +} + +// Internally, only Feature additions/removals are stored for each layer. If a layer has a parent, +// the Feature list will be compared to the parent's Feature list and the difference will be stored. +// Note that when the Namespace of a layer differs from its parent, it is expected that several +// Feature that were already included a parent will have their Namespace updated as well +// (happens when Feature detectors relies on the detected layer Namespace). However, if the listed +// Feature has the same Name/Version as its parent, InsertLayer considers that the Feature hasn't +// been modified. +func (mySQL *mySQL) InsertLayer(layer database.Layer) error { + tf := time.Now() + + // Verify parameters + if layer.Name == "" { + log.Warning("could not insert a layer which has an empty Name") + return cerrors.NewBadRequestError("could not insert a layer which has an empty Name") + } + + // Get a potentially existing layer. + existingLayer, err := mySQL.FindLayer(layer.Name, true, false) + if err != nil && err != cerrors.ErrNotFound { + return err + } else if err == nil { + if existingLayer.EngineVersion >= layer.EngineVersion { + // The layer exists and has an equal or higher engine version, do nothing. + return nil + } + + layer.ID = existingLayer.ID + } + + // We do `defer database.ObserveQueryTime` here because we don't want to observe existing layers. + defer database.ObserveQueryTime("InsertLayer", "all", tf) + + // Get parent ID. + var parentID zero.Int + if layer.Parent != nil { + if layer.Parent.ID == 0 { + log.Warning("Parent is expected to be retrieved from database when inserting a layer.") + return cerrors.NewBadRequestError("Parent is expected to be retrieved from database when inserting a layer.") + } + + parentID = zero.IntFrom(int64(layer.Parent.ID)) + } + + // Find or insert namespace if provided. + var namespaceID zero.Int + if layer.Namespace != nil { + n, err := mySQL.insertNamespace(*layer.Namespace) + if err != nil { + return err + } + namespaceID = zero.IntFrom(int64(n)) + } else if layer.Namespace == nil && layer.Parent != nil { + // Import the Namespace from the parent if it has one and this layer doesn't specify one. + if layer.Parent.Namespace != nil { + namespaceID = zero.IntFrom(int64(layer.Parent.Namespace.ID)) + } + } + + // Begin transaction. + tx, err := mySQL.Begin() + if err != nil { + tx.Rollback() + return handleError("InsertLayer.Begin()", err) + } + if layer.ID == 0 { + // Insert a new layer. + res, err := tx.Exec(insertLayer, layer.Name, layer.EngineVersion, parentID, namespaceID) + if err != nil { + tx.Rollback() + + if isErrUniqueViolation(err) { + // Ignore this error, another process collided. + log.Debug("Attempted to insert duplicate layer.") + return nil + } + return handleError("insertLayer", err) + } + tmpid, err := res.LastInsertId() + if err != nil { + tx.Rollback() + return handleError("inserLayer", err) + } + layer.ID = int(tmpid) + // if LastInsertId == 0, it means the layer already exists + if layer.ID == 0 { + err = tx.QueryRow(getLayerId, layer.Name).Scan(&layer.ID) + if err != nil { + tx.Rollback() + if isErrUniqueViolation(err) { + // Ignore this error, another process collided. + return nil + } + return handleError("getLayerId", err) + } + } + } else { + // Update an existing layer. + _, err = tx.Exec(updateLayer, layer.EngineVersion, namespaceID, layer.ID) + if err != nil { + tx.Rollback() + return handleError("updateLayer", err) + } + + // Remove all existing Layer_diff_FeatureVersion. + _, err = tx.Exec(removeLayerDiffFeatureVersion, layer.ID) + if err != nil { + tx.Rollback() + return handleError("removeLayerDiffFeatureVersion", err) + } + } + // Update Layer_diff_FeatureVersion now. + err = mySQL.updateDiffFeatureVersions(tx, &layer, &existingLayer) + if err != nil { + tx.Rollback() + return err + } + + // Commit transaction. + err = tx.Commit() + if err != nil { + tx.Rollback() + return handleError("InsertLayer.Commit()", err) + } + + return nil +} + +func (mySQL *mySQL) updateDiffFeatureVersions(tx *sql.Tx, layer, existingLayer *database.Layer) error { + // add and del are the FeatureVersion diff we should insert. + var add []database.FeatureVersion + var del []database.FeatureVersion + + if layer.Parent == nil { + // There is no parent, every Features are added. + add = append(add, layer.Features...) + } else if layer.Parent != nil { + // There is a parent, we need to diff the Features with it. + + // Build name:version structures. + layerFeaturesMapNV, layerFeaturesNV := createNV(layer.Features) + parentLayerFeaturesMapNV, parentLayerFeaturesNV := createNV(layer.Parent.Features) + + // Calculate the added and deleted FeatureVersions name:version. + addNV := utils.CompareStringLists(layerFeaturesNV, parentLayerFeaturesNV) + delNV := utils.CompareStringLists(parentLayerFeaturesNV, layerFeaturesNV) + + // Fill the structures containing the added and deleted FeatureVersions + for _, nv := range addNV { + add = append(add, *layerFeaturesMapNV[nv]) + } + for _, nv := range delNV { + del = append(del, *parentLayerFeaturesMapNV[nv]) + } + } + + // Insert FeatureVersions in the database. + addIDs, err := mySQL.insertFeatureVersions(add) + if err != nil { + return err + } + delIDs, err := mySQL.insertFeatureVersions(del) + if err != nil { + return err + } + + // Insert diff in the database. + if len(addIDs) > 0 { + insertLayerDiffFeatureVersionStmt := fmt.Sprintf(insertLayerDiffFeatureVersion, buildInputArray(addIDs)) + _, err = tx.Exec(insertLayerDiffFeatureVersionStmt, layer.ID, "add") + if err != nil { + return handleError("insertLayerDiffFeatureVersion.Add", err) + } + } + if len(delIDs) > 0 { + insertLayerDiffFeatureVersionStmt := fmt.Sprintf(insertLayerDiffFeatureVersion, buildInputArray(delIDs)) + _, err = tx.Exec(insertLayerDiffFeatureVersionStmt, layer.ID, "del") + if err != nil { + return handleError("insertLayerDiffFeatureVersion.Del", err) + } + } + + return nil +} + +func createNV(features []database.FeatureVersion) (map[string]*database.FeatureVersion, []string) { + mapNV := make(map[string]*database.FeatureVersion, 0) + sliceNV := make([]string, 0, len(features)) + + for i := 0; i < len(features); i++ { + featureVersion := &features[i] + nv := featureVersion.Feature.Name + ":" + featureVersion.Version.String() + mapNV[nv] = featureVersion + sliceNV = append(sliceNV, nv) + } + + return mapNV, sliceNV +} + +func (mySQL *mySQL) DeleteLayer(name string) error { + defer database.ObserveQueryTime("DeleteLayer", "all", time.Now()) + + result, err := mySQL.Exec(removeLayer, name) + if err != nil { + return handleError("removeLayer", err) + } + + affected, err := result.RowsAffected() + if err != nil { + return handleError("removeLayer.RowsAffected()", err) + } + + if affected <= 0 { + return cerrors.ErrNotFound + } + + return nil +} diff --git a/database/mysql/layer_test.go b/database/mysql/layer_test.go new file mode 100644 index 00000000..c58130b2 --- /dev/null +++ b/database/mysql/layer_test.go @@ -0,0 +1,350 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "fmt" + "testing" + + "github.com/coreos/clair/database" + cerrors "github.com/coreos/clair/utils/errors" + "github.com/coreos/clair/utils/types" + "github.com/stretchr/testify/assert" +) + +func TestFindLayer(t *testing.T) { + datastore, err := OpenForTest("FindLayer", true) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + // Layer-0: no parent, no namespace, no feature, no vulnerability + layer, err := datastore.FindLayer("layer-0", false, false) + if assert.Nil(t, err) && assert.NotNil(t, layer) { + assert.Equal(t, "layer-0", layer.Name) + assert.Nil(t, layer.Namespace) + assert.Nil(t, layer.Parent) + assert.Equal(t, 1, layer.EngineVersion) + assert.Len(t, layer.Features, 0) + } + + layer, err = datastore.FindLayer("layer-0", true, false) + if assert.Nil(t, err) && assert.NotNil(t, layer) { + assert.Len(t, layer.Features, 0) + } + + // Layer-1: one parent, adds two features, one vulnerability + layer, err = datastore.FindLayer("layer-1", false, false) + if assert.Nil(t, err) && assert.NotNil(t, layer) { + assert.Equal(t, layer.Name, "layer-1") + assert.Equal(t, "debian:7", layer.Namespace.Name) + if assert.NotNil(t, layer.Parent) { + assert.Equal(t, "layer-0", layer.Parent.Name) + } + assert.Equal(t, 1, layer.EngineVersion) + assert.Len(t, layer.Features, 0) + } + + layer, err = datastore.FindLayer("layer-1", true, false) + if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) { + for _, featureVersion := range layer.Features { + assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name) + + switch featureVersion.Feature.Name { + case "wechat": + assert.Equal(t, types.NewVersionUnsafe("0.5"), featureVersion.Version) + case "openssl": + assert.Equal(t, types.NewVersionUnsafe("1.0"), featureVersion.Version) + default: + t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name) + } + } + } + + layer, err = datastore.FindLayer("layer-1", true, true) + if assert.Nil(t, err) && assert.NotNil(t, layer) && assert.Len(t, layer.Features, 2) { + for _, featureVersion := range layer.Features { + assert.Equal(t, "debian:7", featureVersion.Feature.Namespace.Name) + + switch featureVersion.Feature.Name { + case "wechat": + assert.Equal(t, types.NewVersionUnsafe("0.5"), featureVersion.Version) + case "openssl": + assert.Equal(t, types.NewVersionUnsafe("1.0"), featureVersion.Version) + + if assert.Len(t, featureVersion.AffectedBy, 1) { + assert.Equal(t, "debian:7", featureVersion.AffectedBy[0].Namespace.Name) + assert.Equal(t, "CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Name) + assert.Equal(t, types.High, featureVersion.AffectedBy[0].Severity) + assert.Equal(t, "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", featureVersion.AffectedBy[0].Description) + assert.Equal(t, "http://google.com/#q=CVE-OPENSSL-1-DEB7", featureVersion.AffectedBy[0].Link) + assert.Equal(t, types.NewVersionUnsafe("2.0"), featureVersion.AffectedBy[0].FixedBy) + } + default: + t.Errorf("unexpected package %s for layer-1", featureVersion.Feature.Name) + } + } + } +} + +func TestInsertLayer(t *testing.T) { + datastore, err := OpenForTest("InsertLayer", false) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + // Insert invalid layer. + testInsertLayerInvalid(t, datastore) + + // Insert a layer tree. + testInsertLayerTree(t, datastore) + + // Update layer. + testInsertLayerUpdate(t, datastore) + + // Delete layer. + testInsertLayerDelete(t, datastore) +} + +func testInsertLayerInvalid(t *testing.T, datastore database.Datastore) { + invalidLayers := []database.Layer{ + {}, + {Name: "layer0", Parent: &database.Layer{}}, + {Name: "layer0", Parent: &database.Layer{Name: "UnknownLayer"}}, + } + + for _, invalidLayer := range invalidLayers { + err := datastore.InsertLayer(invalidLayer) + assert.Error(t, err) + } +} + +func testInsertLayerTree(t *testing.T, datastore database.Datastore) { + f1 := database.FeatureVersion{ + Feature: database.Feature{ + Namespace: database.Namespace{Name: "TestInsertLayerNamespace2"}, + Name: "TestInsertLayerFeature1", + }, + Version: types.NewVersionUnsafe("1.0"), + } + f2 := database.FeatureVersion{ + Feature: database.Feature{ + Namespace: database.Namespace{Name: "TestInsertLayerNamespace2"}, + Name: "TestInsertLayerFeature2", + }, + Version: types.NewVersionUnsafe("0.34"), + } + f3 := database.FeatureVersion{ + Feature: database.Feature{ + Namespace: database.Namespace{Name: "TestInsertLayerNamespace2"}, + Name: "TestInsertLayerFeature3", + }, + Version: types.NewVersionUnsafe("0.56"), + } + f4 := database.FeatureVersion{ + Feature: database.Feature{ + Namespace: database.Namespace{Name: "TestInsertLayerNamespace3"}, + Name: "TestInsertLayerFeature2", + }, + Version: types.NewVersionUnsafe("0.34"), + } + f5 := database.FeatureVersion{ + Feature: database.Feature{ + Namespace: database.Namespace{Name: "TestInsertLayerNamespace3"}, + Name: "TestInsertLayerFeature3", + }, + Version: types.NewVersionUnsafe("0.57"), + } + f6 := database.FeatureVersion{ + Feature: database.Feature{ + Namespace: database.Namespace{Name: "TestInsertLayerNamespace3"}, + Name: "TestInsertLayerFeature4", + }, + Version: types.NewVersionUnsafe("0.666"), + } + + layers := []database.Layer{ + { + Name: "TestInsertLayer1", + }, + { + Name: "TestInsertLayer2", + Parent: &database.Layer{Name: "TestInsertLayer1"}, + Namespace: &database.Namespace{Name: "TestInsertLayerNamespace1"}, + }, + // This layer changes the namespace and adds Features. + { + Name: "TestInsertLayer3", + Parent: &database.Layer{Name: "TestInsertLayer2"}, + Namespace: &database.Namespace{Name: "TestInsertLayerNamespace2"}, + Features: []database.FeatureVersion{f1, f2, f3}, + }, + // This layer covers the case where the last layer doesn't provide any new Feature. + { + Name: "TestInsertLayer4a", + Parent: &database.Layer{Name: "TestInsertLayer3"}, + Features: []database.FeatureVersion{f1, f2, f3}, + }, + // This layer covers the case where the last layer provides Features. + // It also modifies the Namespace ("upgrade") but keeps some Features not upgraded, their + // Namespaces should then remain unchanged. + { + Name: "TestInsertLayer4b", + Parent: &database.Layer{Name: "TestInsertLayer3"}, + Namespace: &database.Namespace{Name: "TestInsertLayerNamespace3"}, + Features: []database.FeatureVersion{ + // Deletes TestInsertLayerFeature1. + // Keep TestInsertLayerFeature2 (old Namespace should be kept): + f4, + // Upgrades TestInsertLayerFeature3 (with new Namespace): + f5, + // Adds TestInsertLayerFeature4: + f6, + }, + }, + } + + var err error + retrievedLayers := make(map[string]database.Layer) + for _, layer := range layers { + if layer.Parent != nil { + // Retrieve from database its parent and assign. + parent := retrievedLayers[layer.Parent.Name] + layer.Parent = &parent + } + + err = datastore.InsertLayer(layer) + assert.Nil(t, err) + + retrievedLayers[layer.Name], err = datastore.FindLayer(layer.Name, true, false) + assert.Nil(t, err) + } + + l4a := retrievedLayers["TestInsertLayer4a"] + if assert.NotNil(t, l4a.Namespace) { + assert.Equal(t, "TestInsertLayerNamespace2", l4a.Namespace.Name) + } + assert.Len(t, l4a.Features, 3) + for _, featureVersion := range l4a.Features { + if cmpFV(featureVersion, f1) && cmpFV(featureVersion, f2) && cmpFV(featureVersion, f3) { + assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f1, f2, f3)) + } + } + + l4b := retrievedLayers["TestInsertLayer4b"] + if assert.NotNil(t, l4b.Namespace) { + assert.Equal(t, "TestInsertLayerNamespace3", l4b.Namespace.Name) + } + assert.Len(t, l4b.Features, 3) + for _, featureVersion := range l4b.Features { + if cmpFV(featureVersion, f2) && cmpFV(featureVersion, f5) && cmpFV(featureVersion, f6) { + assert.Error(t, fmt.Errorf("TestInsertLayer4a contains an unexpected package: %#v. Should contain %#v and %#v and %#v.", featureVersion, f2, f4, f6)) + } + } +} + +func testInsertLayerUpdate(t *testing.T, datastore database.Datastore) { + f7 := database.FeatureVersion{ + Feature: database.Feature{ + Namespace: database.Namespace{Name: "TestInsertLayerNamespace3"}, + Name: "TestInsertLayerFeature7", + }, + Version: types.NewVersionUnsafe("0.01"), + } + + l3, _ := datastore.FindLayer("TestInsertLayer3", true, false) + l3u := database.Layer{ + Name: l3.Name, + Parent: l3.Parent, + Namespace: &database.Namespace{Name: "TestInsertLayerNamespaceUpdated1"}, + Features: []database.FeatureVersion{f7}, + } + + l4u := database.Layer{ + Name: "TestInsertLayer4", + Parent: &database.Layer{Name: "TestInsertLayer3"}, + Features: []database.FeatureVersion{f7}, + EngineVersion: 2, + } + + // Try to re-insert without increasing the EngineVersion. + err := datastore.InsertLayer(l3u) + assert.Nil(t, err) + + l3uf, err := datastore.FindLayer(l3u.Name, true, false) + if assert.Nil(t, err) { + assert.Equal(t, l3.Namespace.Name, l3uf.Namespace.Name) + assert.Equal(t, l3.EngineVersion, l3uf.EngineVersion) + assert.Len(t, l3uf.Features, len(l3.Features)) + } + + // Update layer l3. + // Verify that the Namespace, EngineVersion and FeatureVersions got updated. + l3u.EngineVersion = 2 + err = datastore.InsertLayer(l3u) + assert.Nil(t, err) + + l3uf, err = datastore.FindLayer(l3u.Name, true, false) + if assert.Nil(t, err) { + assert.Equal(t, l3u.Namespace.Name, l3uf.Namespace.Name) + assert.Equal(t, l3u.EngineVersion, l3uf.EngineVersion) + if assert.Len(t, l3uf.Features, 1) { + assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l3uf.Features[0]) + } + } + + // Update layer l4. + // Verify that the Namespace got updated from its new Parent's, and also verify the + // EnginVersion and FeatureVersions. + l4u.Parent = &l3uf + err = datastore.InsertLayer(l4u) + assert.Nil(t, err) + + l4uf, err := datastore.FindLayer(l3u.Name, true, false) + if assert.Nil(t, err) { + assert.Equal(t, l3u.Namespace.Name, l4uf.Namespace.Name) + assert.Equal(t, l4u.EngineVersion, l4uf.EngineVersion) + if assert.Len(t, l4uf.Features, 1) { + assert.True(t, cmpFV(l3uf.Features[0], f7), "Updated layer should have %#v but actually have %#v", f7, l4uf.Features[0]) + } + } +} + +func testInsertLayerDelete(t *testing.T, datastore database.Datastore) { + err := datastore.DeleteLayer("TestInsertLayerX") + assert.Equal(t, cerrors.ErrNotFound, err) + + err = datastore.DeleteLayer("TestInsertLayer3") + assert.Nil(t, err) + + _, err = datastore.FindLayer("TestInsertLayer3", false, false) + assert.Equal(t, cerrors.ErrNotFound, err) + + _, err = datastore.FindLayer("TestInsertLayer4a", false, false) + assert.Equal(t, cerrors.ErrNotFound, err) + + _, err = datastore.FindLayer("TestInsertLayer4b", true, false) + assert.Equal(t, cerrors.ErrNotFound, err) +} + +func cmpFV(a, b database.FeatureVersion) bool { + return a.Feature.Name == b.Feature.Name && + a.Feature.Namespace.Name == b.Feature.Namespace.Name && + a.Version.String() == b.Version.String() +} diff --git a/database/mysql/lock.go b/database/mysql/lock.go new file mode 100644 index 00000000..20a994b4 --- /dev/null +++ b/database/mysql/lock.go @@ -0,0 +1,106 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "time" + + "github.com/coreos/clair/database" + cerrors "github.com/coreos/clair/utils/errors" +) + +// Lock tries to set a temporary lock in the database. +// +// Lock does not block, instead, it returns true and its expiration time +// is the lock has been successfully acquired or false otherwise +func (mySQL *mySQL) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time) { + if name == "" || owner == "" || duration == 0 { + log.Warning("could not create an invalid lock") + return false, time.Time{} + } + + defer database.ObserveQueryTime("Lock", "all", time.Now()) + + // Compute expiration. + until := time.Now().Add(duration) + + if renew { + // Renew lock. + r, err := mySQL.Exec(updateLock, until, name, owner) + if err != nil { + handleError("updateLock", err) + return false, until + } + if n, _ := r.RowsAffected(); n > 0 { + // Updated successfully. + return true, until + } + } else { + // Prune locks. + mySQL.pruneLocks() + } + + // Lock. + _, err := mySQL.Exec(insertLock, name, owner, until) + if err != nil { + if !isErrUniqueViolation(err) { + handleError("insertLock", err) + } + return false, until + } + + return true, until +} + +// Unlock unlocks a lock specified by its name if I own it +func (mySQL *mySQL) Unlock(name, owner string) { + if name == "" || owner == "" { + log.Warning("could not delete an invalid lock") + return + } + + defer database.ObserveQueryTime("Unlock", "all", time.Now()) + + mySQL.Exec(removeLock, name, owner) +} + +// FindLock returns the owner of a lock specified by its name and its +// expiration time. +func (mySQL *mySQL) FindLock(name string) (string, time.Time, error) { + if name == "" { + log.Warning("could not find an invalid lock") + return "", time.Time{}, cerrors.NewBadRequestError("could not find an invalid lock") + } + + defer database.ObserveQueryTime("FindLock", "all", time.Now()) + + var owner string + var until time.Time + err := mySQL.QueryRow(searchLock, name).Scan(&owner, &until) + if err != nil { + return owner, until, handleError("searchLock", err) + } + + return owner, until, nil +} + +// pruneLocks removes every expired locks from the database +func (mySQL *mySQL) pruneLocks() { + defer database.ObserveQueryTime("pruneLocks", "all", time.Now()) + + if _, err := mySQL.Exec(removeLockExpired); err != nil { + handleError("removeLockExpired", err) + } +} diff --git a/database/mysql/lock_test.go b/database/mysql/lock_test.go new file mode 100644 index 00000000..bfeda625 --- /dev/null +++ b/database/mysql/lock_test.go @@ -0,0 +1,69 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestLock(t *testing.T) { + datastore, err := OpenForTest("InsertNamespace", false) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + var l bool + var et time.Time + + // Create a first lock. + l, _ = datastore.Lock("test1", "owner1", time.Minute, false) + assert.True(t, l) + + // Try to lock the same lock with another owner. + l, _ = datastore.Lock("test1", "owner2", time.Minute, true) + assert.False(t, l) + + l, _ = datastore.Lock("test1", "owner2", time.Minute, false) + assert.False(t, l) + + // Renew the lock. + l, _ = datastore.Lock("test1", "owner1", 2*time.Minute, true) + assert.True(t, l) + + // Unlock and then relock by someone else. + datastore.Unlock("test1", "owner1") + + l, et = datastore.Lock("test1", "owner2", time.Minute, false) + assert.True(t, l) + + // LockInfo + o, et2, err := datastore.FindLock("test1") + assert.Nil(t, err) + assert.Equal(t, "owner2", o) + assert.Equal(t, et.Second(), et2.Second()) + + // Create a second lock which is actually already expired ... + l, _ = datastore.Lock("test2", "owner1", -time.Minute, false) + assert.True(t, l) + + // Take over the lock + l, _ = datastore.Lock("test2", "owner2", time.Minute, false) + assert.True(t, l) +} diff --git a/database/mysql/migrations/20151222113213_Initial.sql b/database/mysql/migrations/20151222113213_Initial.sql new file mode 100644 index 00000000..bc12dcd1 --- /dev/null +++ b/database/mysql/migrations/20151222113213_Initial.sql @@ -0,0 +1,177 @@ +-- Copyright 2015 clair authors +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- +goose Up + +-- ----------------------------------------------------- +-- Table Namespace +-- ----------------------------------------------------- +CREATE TABLE IF NOT EXISTS Namespace ( + id SERIAL PRIMARY KEY, + name VARCHAR(128) NULL); + + +-- ----------------------------------------------------- +-- Table Layer +-- ----------------------------------------------------- +CREATE TABLE IF NOT EXISTS Layer ( + id SERIAL PRIMARY KEY, + name VARCHAR(128) NOT NULL UNIQUE, + engineversion SMALLINT NOT NULL, + parent_id BIGINT UNSIGNED NULL, + namespace_id BIGINT UNSIGNED NULL, + created_at TIMESTAMP, + FOREIGN KEY(parent_id) REFERENCES Layer(id) ON DELETE CASCADE, + FOREIGN KEY(namespace_id) REFERENCES Namespace(id), + INDEX (parent_id), + INDEX (namespace_id)); + + +-- ----------------------------------------------------- +-- Table Feature +-- ----------------------------------------------------- +CREATE TABLE IF NOT EXISTS Feature ( + id SERIAL PRIMARY KEY, + namespace_id BIGINT UNSIGNED NOT NULL, + name VARCHAR(128) NOT NULL, + + FOREIGN KEY(namespace_id) REFERENCES Namespace(ID), + UNIQUE (namespace_id, name)); + + +-- ----------------------------------------------------- +-- Table FeatureVersion +-- ----------------------------------------------------- +CREATE TABLE IF NOT EXISTS FeatureVersion ( + id SERIAL PRIMARY KEY, + feature_id BIGINT UNSIGNED NOT NULL, + version VARCHAR(128) NOT NULL, + FOREIGN KEY(feature_id) REFERENCES Feature(id), + INDEX (feature_id)); + + +-- ----------------------------------------------------- +-- Table Layer_diff_FeatureVersion +-- ----------------------------------------------------- +CREATE TABLE IF NOT EXISTS Layer_diff_FeatureVersion ( + id SERIAL PRIMARY KEY, + layer_id BIGINT UNSIGNED NOT NULL , + featureversion_id BIGINT UNSIGNED NOT NULL , + modification ENUM('add', 'del') NOT NULL, + + FOREIGN KEY (layer_id) REFERENCES Layer(id) ON DELETE CASCADE, + FOREIGN KEY (featureversion_id) REFERENCES FeatureVersion(id), + INDEX (layer_id), + INDEX (featureversion_id), + InDEX (featureversion_id, layer_id), + UNIQUE (layer_id, featureversion_id)); + + +-- ----------------------------------------------------- +-- Table Vulnerability +-- ----------------------------------------------------- +CREATE TABLE IF NOT EXISTS Vulnerability ( + id SERIAL PRIMARY KEY, + namespace_id INT NOT NULL REFERENCES Namespace, + name VARCHAR(128) NOT NULL, + description TEXT NULL, + link VARCHAR(128) NULL, + severity ENUM('Unknown', 'Negligible', 'Low', 'Medium', 'High', 'Critical', 'Defcon1') NOT NULL, + metadata TEXT NULL, + created_at TIMESTAMP, + deleted_at TIMESTAMP NULL); + + +-- ----------------------------------------------------- +-- Table Vulnerability_FixedIn_Feature +-- ----------------------------------------------------- +CREATE TABLE IF NOT EXISTS Vulnerability_FixedIn_Feature ( + id SERIAL PRIMARY KEY, + vulnerability_id BIGINT UNSIGNED NOT NULL, + feature_id BIGINT UNSIGNED NOT NULL, + version VARCHAR(128) NOT NULL, + + INDEX (feature_id, vulnerability_id), + FOREIGN KEY (vulnerability_id) REFERENCES Vulnerability(id) ON DELETE CASCADE, + FOREIGN KEY (feature_id) REFERENCES Feature(id), + UNIQUE (vulnerability_id, feature_id)); + +-- ----------------------------------------------------- +-- Table Vulnerability_Affects_FeatureVersion +-- ----------------------------------------------------- +CREATE TABLE IF NOT EXISTS Vulnerability_Affects_FeatureVersion ( + id SERIAL PRIMARY KEY, + vulnerability_id BIGINT UNSIGNED NOT NULL, + featureversion_id BIGINT UNSIGNED NOT NULL, + fixedin_id BIGINT UNSIGNED NOT NULL, + + INDEX (fixedin_id), + INDEX (featureversion_id, vulnerability_id), + FOREIGN KEY (vulnerability_id) REFERENCES Vulnerability(id) ON DELETE CASCADE, + FOREIGN KEY (fixedin_id) REFERENCES Vulnerability_FixedIn_Feature (id) ON DELETE CASCADE, + FOREIGN KEY (featureversion_id) REFERENCES FeatureVersion (id), + UNIQUE (vulnerability_id, featureversion_id)); + + + +-- ----------------------------------------------------- +-- Table KeyValue +-- ----------------------------------------------------- +CREATE TABLE IF NOT EXISTS KeyValue ( + id SERIAL PRIMARY KEY, + `key` VARCHAR(128) NOT NULL UNIQUE, + `value` TEXT); + +-- ----------------------------------------------------- +-- Table Lock +-- ----------------------------------------------------- +CREATE TABLE IF NOT EXISTS `Lock` ( + id SERIAL PRIMARY KEY, + name VARCHAR(64) NOT NULL UNIQUE, + owner VARCHAR(64) NOT NULL, + until TIMESTAMP, + + INDEX (owner)); + + +-- ----------------------------------------------------- +-- Table VulnerabilityNotification +-- ----------------------------------------------------- +CREATE TABLE IF NOT EXISTS Vulnerability_Notification ( + id SERIAL PRIMARY KEY, + name VARCHAR(64) NOT NULL UNIQUE, + created_at TIMESTAMP , + notified_at TIMESTAMP NULL, + deleted_at TIMESTAMP NULL, + old_vulnerability_id BIGINT UNSIGNED NULL, + new_vulnerability_id BIGINT UNSIGNED NULL, + + FOREIGN KEY (old_vulnerability_id) REFERENCES Vulnerability(id) ON DELETE CASCADE, + FOREIGN KEY (new_vulnerability_id) REFERENCES Vulnerability(id) ON DELETE CASCADE, + INDEX (notified_at)); + +-- +goose Down + +DROP TABLE IF EXISTS Namespace, + Layer, + Feature, + FeatureVersion, + Layer_diff_FeatureVersion, + Vulnerability, + Vulnerability_FixedIn_Feature, + Vulnerability_Affects_FeatureVersion, + Vulnerability_Notification, + `KeyValue`, + `Lock` + CASCADE; diff --git a/database/mysql/mysql.go b/database/mysql/mysql.go new file mode 100644 index 00000000..22ca0385 --- /dev/null +++ b/database/mysql/mysql.go @@ -0,0 +1,264 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package mysql implements database.Datastore with MySQL. +package mysql + +import ( + "database/sql" + "fmt" + "io/ioutil" + "os" + "path" + "runtime" + "strings" + + "bitbucket.org/liamstask/goose/lib/goose" + "github.com/coreos/clair/config" + "github.com/coreos/clair/database" + cerrors "github.com/coreos/clair/utils/errors" + "github.com/coreos/pkg/capnslog" + _ "github.com/go-sql-driver/mysql" + "github.com/hashicorp/golang-lru" + "github.com/pborman/uuid" +) + +var ( + log = capnslog.NewPackageLogger("github.com/coreos/clair", "mysql") +) + +const DATABASENAME = "clair" +const DEFAULTFLAG = "?charset=utf8&parseTime=True" +const DEFAULTSOURCE = DATABASENAME + DEFAULTFLAG + +type Queryer interface { + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row + Exec(query string, args ...interface{}) (sql.Result, error) +} + +type mySQL struct { + *sql.DB + cache *lru.ARCCache +} + +func (mySQL *mySQL) Close() { + mySQL.DB.Close() +} + +func (mySQL *mySQL) Ping() bool { + return mySQL.DB.Ping() == nil +} + +// Open creates a Datastore backed by a PostgreSQL database. +// It will run immediately every necessary migration on the database. +func Open(config *config.DatabaseConfig) (database.Datastore, error) { + source := config.Source + if strings.HasPrefix(source, "mysql://") { + source = strings.TrimPrefix(source, "mysql://") + } + config.Source = source + DEFAULTSOURCE + // Create Database if not exists + err := createDatabase(source, DATABASENAME) + if err != nil { + log.Error(err) + return nil, database.ErrCantOpen + } + return open(config) +} + +func open(config *config.DatabaseConfig) (database.Datastore, error) { + // Run migrations. + if err := migrate(config.Source); err != nil { + log.Error(err) + return nil, database.ErrCantOpen + } + + // Open database. + db, err := sql.Open("mysql", config.Source) + if err != nil { + log.Error(err) + return nil, database.ErrCantOpen + } + + // Initialize cache. + // TODO(Quentin-M): Benchmark with a simple LRU Cache. + var cache *lru.ARCCache + if config.CacheSize > 0 { + cache, _ = lru.NewARC(config.CacheSize) + } + + return &mySQL{DB: db, cache: cache}, nil +} + +// migrate runs all available migrations on a pgSQL database. +func migrate(dataSource string) error { + log.Info("running database migrations") + + _, filename, _, _ := runtime.Caller(1) + migrationDir := path.Join(path.Dir(filename), "/migrations/") + conf := &goose.DBConf{ + MigrationsDir: migrationDir, + Driver: goose.DBDriver{ + Name: "mysql", + OpenStr: dataSource, + Import: "github.com/go-sql-driver/mysql", + Dialect: &goose.MySqlDialect{}, + }, + } + + // Determine the most recent revision available from the migrations folder. + target, err := goose.GetMostRecentDBVersion(conf.MigrationsDir) + if err != nil { + return err + } + + // Run migrations + err = goose.RunMigrations(conf, conf.MigrationsDir, target) + if err != nil { + return err + } + + log.Info("database migration ran successfully") + return nil +} + +// TODO: +// createDatabase creates a new database. +// The dataSource parameter should not contain a dbname. +func createDatabase(dataSource, databaseName string) error { + // Open database. + log.Info("Create database: ", databaseName) + db, err := sql.Open("mysql", dataSource) + if err != nil { + return fmt.Errorf("could not open database (CreateDatabase): %v", err) + } + defer db.Close() + + // Create database. + _, err = db.Exec("CREATE DATABASE IF NOT EXISTS " + databaseName + " DEFAULT CHARACTER SET utf8 COLLATE utf8_general_ci") + if err != nil { + return fmt.Errorf("could not create database: %v", err) + } + + return nil +} + +// dropDatabase drops an existing database. +// The dataSource parameter should not contain a dbname. +func dropDatabase(dataSource, databaseName string) error { + // Open database. + db, err := sql.Open("mysql", dataSource) + if err != nil { + return fmt.Errorf("could not open database (DropDatabase): %v", err) + } + defer db.Close() + + // Drop database. + if _, err = db.Exec("DROP DATABASE " + databaseName); err != nil { + return fmt.Errorf("could not drop database: %v", err) + } + + return nil +} + +// TODO +// pgSQLTest wraps pgSQL for testing purposes. +// Its Close() method drops the database. +type pgSQLTest struct { + *mySQL + dataSourceDefaultDatabase string + dbName string +} + +// OpenForTest creates a test Datastore backed by a new PostgreSQL database. +// It creates a new unique and prefixed ("test_") database. +// Using Close() will drop the database. +func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) { + // Define the PostgreSQL connection strings. + dataSource := "root@tcp(127.0.0.1:3306)/" + if dataSourceEnv := os.Getenv("CLAIR_TEST_MYSQL"); dataSourceEnv != "" { + dataSource = dataSourceEnv + } + dbName := "test_" + strings.ToLower(name) + "_" + strings.Replace(uuid.New(), "-", "_", -1) + dataSourceDefaultDatabase := dataSource + dataSourceTestDatabase := dataSource + dbName + "?charset=utf8&parseTime=True" + + // Create database. + if err := createDatabase(dataSourceDefaultDatabase, dbName); err != nil { + log.Error(err) + return nil, database.ErrCantOpen + } + + // Open database. + db, err := open(&config.DatabaseConfig{Source: dataSourceTestDatabase, CacheSize: 0}) + if err != nil { + dropDatabase(dataSourceDefaultDatabase, dbName) + log.Error(err) + return nil, database.ErrCantOpen + } + // Load test data if specified. + if withTestData { + _, filename, _, _ := runtime.Caller(0) + d, _ := ioutil.ReadFile(path.Join(path.Dir(filename)) + "/testdata/data.sql") + queries := strings.Split(fmt.Sprintf("%s", d), ";") + for _, q := range queries { + _, err = db.(*mySQL).Exec(q) + if err != nil { + dropDatabase(dataSourceDefaultDatabase, dbName) + log.Error(err) + return nil, database.ErrCantOpen + } + } + } + + return &pgSQLTest{ + mySQL: db.(*mySQL), + dataSourceDefaultDatabase: dataSourceDefaultDatabase, + dbName: dbName}, nil +} + +func (pgSQL *pgSQLTest) Close() { + pgSQL.DB.Close() + dropDatabase(pgSQL.dataSourceDefaultDatabase, pgSQL.dbName) +} + +// handleError logs an error with an extra description and masks the error if it's an SQL one. +// This ensures we never return plain SQL errors and leak anything. +func handleError(desc string, err error) error { + if err == nil { + return nil + } + + if err == sql.ErrNoRows { + return cerrors.ErrNotFound + } + + log.Errorf("%s: %v", desc, err) + database.PromErrorsTotal.WithLabelValues(desc).Inc() + + if err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") { + return database.ErrBackendException + } + + return err +} + +// isErrUniqueViolation determines is the given error is a duplicate entry error. +func isErrUniqueViolation(err error) bool { + if strings.Contains(fmt.Sprintf("%v", err), "Error 1062") { + return true + } + return false +} diff --git a/database/mysql/namespace.go b/database/mysql/namespace.go new file mode 100644 index 00000000..c5f34f3c --- /dev/null +++ b/database/mysql/namespace.go @@ -0,0 +1,125 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "time" + + "github.com/coreos/clair/database" + cerrors "github.com/coreos/clair/utils/errors" +) + +func (mySQL *mySQL) insertNamespaceWithTransaction(queryer Queryer, namespace database.Namespace) (int, error) { + if namespace.Name == "" { + return 0, cerrors.NewBadRequestError("could not find/insert invalid Namespace") + } + + if mySQL.cache != nil { + database.PromCacheQueriesTotal.WithLabelValues("namespace").Inc() + if id, found := mySQL.cache.Get("namespace:" + namespace.Name); found { + database.PromCacheHitsTotal.WithLabelValues("namespace").Inc() + return id.(int), nil + } + } + + // We do `defer database.ObserveQueryTime` here because we don't want to observe cached namespaces. + defer database.ObserveQueryTime("insertNamespace", "all", time.Now()) + + var id int + res, err := queryer.Exec(insertNamespace, namespace.Name, namespace.Name) + if err != nil { + return 0, handleError("insertNamespace", err) + } + tmpid, err := res.LastInsertId() + if err != nil { + return 0, handleError("insertNamespace", err) + } + id = int(tmpid) + if id == 0 { + err = queryer.QueryRow(soiNamespace, namespace.Name).Scan(&id) + if err != nil { + return 0, handleError("soiNamespace", err) + } + } + if mySQL.cache != nil { + mySQL.cache.Add("namespace:"+namespace.Name, id) + } + + return id, nil + +} + +func (mySQL *mySQL) insertNamespace(namespace database.Namespace) (int, error) { + if namespace.Name == "" { + return 0, cerrors.NewBadRequestError("could not find/insert invalid Namespace") + } + + if mySQL.cache != nil { + database.PromCacheQueriesTotal.WithLabelValues("namespace").Inc() + if id, found := mySQL.cache.Get("namespace:" + namespace.Name); found { + database.PromCacheHitsTotal.WithLabelValues("namespace").Inc() + return id.(int), nil + } + } + + // We do `defer database.ObserveQueryTime` here because we don't want to observe cached namespaces. + defer database.ObserveQueryTime("insertNamespace", "all", time.Now()) + + var id int + res, err := mySQL.Exec(insertNamespace, namespace.Name, namespace.Name) + if err != nil { + return 0, handleError("insertNamespace", err) + } + tmpid, err := res.LastInsertId() + if err != nil { + return 0, handleError("insertNamespace", err) + } + id = int(tmpid) + if id == 0 { + err = mySQL.QueryRow(soiNamespace, namespace.Name).Scan(&id) + if err != nil { + return 0, handleError("soiNamespace", err) + } + } + if mySQL.cache != nil { + mySQL.cache.Add("namespace:"+namespace.Name, id) + } + + return id, nil +} + +func (mySQL *mySQL) ListNamespaces() (namespaces []database.Namespace, err error) { + rows, err := mySQL.Query(listNamespace) + if err != nil { + return namespaces, handleError("listNamespace", err) + } + defer rows.Close() + + for rows.Next() { + var namespace database.Namespace + + err = rows.Scan(&namespace.ID, &namespace.Name) + if err != nil { + return namespaces, handleError("listNamespace.Scan()", err) + } + + namespaces = append(namespaces, namespace) + } + if err = rows.Err(); err != nil { + return namespaces, handleError("listNamespace.Rows()", err) + } + + return namespaces, err +} diff --git a/database/mysql/namespace_test.go b/database/mysql/namespace_test.go new file mode 100644 index 00000000..a77b6816 --- /dev/null +++ b/database/mysql/namespace_test.go @@ -0,0 +1,66 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "fmt" + "testing" + + "github.com/coreos/clair/database" + "github.com/stretchr/testify/assert" +) + +func TestInsertNamespace(t *testing.T) { + datastore, err := OpenForTest("InsertNamespace", false) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + // Invalid Namespace. + id0, err := datastore.insertNamespace(database.Namespace{}) + assert.NotNil(t, err) + assert.Zero(t, id0) + + // Insert Namespace and ensure we can find it. + id1, err := datastore.insertNamespace(database.Namespace{Name: "TestInsertNamespace1"}) + assert.Nil(t, err) + id2, err := datastore.insertNamespace(database.Namespace{Name: "TestInsertNamespace1"}) + assert.Nil(t, err) + assert.Equal(t, id1, id2) +} + +func TestListNamespace(t *testing.T) { + datastore, err := OpenForTest("ListNamespaces", true) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + namespaces, err := datastore.ListNamespaces() + assert.Nil(t, err) + if assert.Len(t, namespaces, 2) { + for _, namespace := range namespaces { + switch namespace.Name { + case "debian:7", "debian:8": + continue + default: + assert.Error(t, fmt.Errorf("ListNamespaces should not have returned '%s'", namespace.Name)) + } + } + } +} diff --git a/database/mysql/notification.go b/database/mysql/notification.go new file mode 100644 index 00000000..d8cd2bb4 --- /dev/null +++ b/database/mysql/notification.go @@ -0,0 +1,214 @@ +package mysql + +import ( + "database/sql" + "time" + + "github.com/coreos/clair/database" + cerrors "github.com/coreos/clair/utils/errors" + "github.com/guregu/null/zero" + "github.com/pborman/uuid" +) + +// do it in tx so we won't insert/update a vuln without notification and vice-versa. +// name and created doesn't matter. +func createNotification(tx *sql.Tx, oldVulnerabilityID, newVulnerabilityID int) error { + defer database.ObserveQueryTime("createNotification", "all", time.Now()) + + // Insert Notification. + oldVulnerabilityNullableID := sql.NullInt64{Int64: int64(oldVulnerabilityID), Valid: oldVulnerabilityID != 0} + newVulnerabilityNullableID := sql.NullInt64{Int64: int64(newVulnerabilityID), Valid: newVulnerabilityID != 0} + _, err := tx.Exec(insertNotification, uuid.New(), oldVulnerabilityNullableID, newVulnerabilityNullableID) + if err != nil { + tx.Rollback() + return handleError("insertNotification", err) + } + + return nil +} + +// Get one available notification name (!locked && !deleted && (!notified || notified_but_timed-out)). +// Does not fill new/old vuln. +func (mySQL *mySQL) GetAvailableNotification(renotifyInterval time.Duration) (database.VulnerabilityNotification, error) { + defer database.ObserveQueryTime("GetAvailableNotification", "all", time.Now()) + + before := time.Now().Add(-renotifyInterval) + row := mySQL.QueryRow(searchNotificationAvailable, before) + notification, err := mySQL.scanNotification(row, false) + + return notification, handleError("searchNotificationAvailable", err) +} + +func (mySQL *mySQL) GetNotification(name string, limit int, page database.VulnerabilityNotificationPageNumber) (database.VulnerabilityNotification, database.VulnerabilityNotificationPageNumber, error) { + defer database.ObserveQueryTime("GetNotification", "all", time.Now()) + + // Get Notification. + notification, err := mySQL.scanNotification(mySQL.QueryRow(searchNotification, name), true) + if err != nil { + return notification, page, handleError("searchNotification", err) + } + + // Load vulnerabilities' LayersIntroducingVulnerability. + page.OldVulnerability, err = mySQL.loadLayerIntroducingVulnerability( + notification.OldVulnerability, + limit, + page.OldVulnerability, + ) + + if err != nil { + return notification, page, err + } + + page.NewVulnerability, err = mySQL.loadLayerIntroducingVulnerability( + notification.NewVulnerability, + limit, + page.NewVulnerability, + ) + + if err != nil { + return notification, page, err + } + + return notification, page, nil +} + +func (mySQL *mySQL) scanNotification(row *sql.Row, hasVulns bool) (database.VulnerabilityNotification, error) { + var notification database.VulnerabilityNotification + var created zero.Time + var notified zero.Time + var deleted zero.Time + var oldVulnerabilityNullableID sql.NullInt64 + var newVulnerabilityNullableID sql.NullInt64 + + // Scan notification. + if hasVulns { + err := row.Scan( + ¬ification.ID, + ¬ification.Name, + &created, + ¬ified, + &deleted, + &oldVulnerabilityNullableID, + &newVulnerabilityNullableID, + ) + + if err != nil { + return notification, err + } + } else { + err := row.Scan(¬ification.ID, ¬ification.Name, &created, ¬ified, &deleted) + + if err != nil { + return notification, err + } + } + + notification.Created = created.Time + notification.Notified = notified.Time + notification.Deleted = deleted.Time + + if hasVulns { + if oldVulnerabilityNullableID.Valid { + vulnerability, err := mySQL.findVulnerabilityByIDWithDeleted(int(oldVulnerabilityNullableID.Int64)) + if err != nil { + return notification, err + } + + notification.OldVulnerability = &vulnerability + } + + if newVulnerabilityNullableID.Valid { + vulnerability, err := mySQL.findVulnerabilityByIDWithDeleted(int(newVulnerabilityNullableID.Int64)) + if err != nil { + return notification, err + } + + notification.NewVulnerability = &vulnerability + } + } + + return notification, nil +} + +// Fills Vulnerability.LayersIntroducingVulnerability. +// limit -1: won't do anything +// limit 0: will just get the startID of the second page +func (mySQL *mySQL) loadLayerIntroducingVulnerability(vulnerability *database.Vulnerability, limit, startID int) (int, error) { + tf := time.Now() + + if vulnerability == nil { + return -1, nil + } + + // A startID equals to -1 means that we reached the end already. + if startID == -1 || limit == -1 { + return -1, nil + } + + // We do `defer database.ObserveQueryTime` here because we don't want to observe invalid calls. + defer database.ObserveQueryTime("loadLayerIntroducingVulnerability", "all", tf) + + // Query with limit + 1, the last item will be used to know the next starting ID. + rows, err := mySQL.Query(searchNotificationLayerIntroducingVulnerability, + vulnerability.ID, startID, limit+1) + if err != nil { + return 0, handleError("searchNotificationLayerIntroducingVulnerability", err) + } + defer rows.Close() + + var layers []database.Layer + for rows.Next() { + var layer database.Layer + + if err := rows.Scan(&layer.ID, &layer.Name); err != nil { + return -1, handleError("searchNotificationLayerIntroducingVulnerability.Scan()", err) + } + + layers = append(layers, layer) + } + if err = rows.Err(); err != nil { + return -1, handleError("searchNotificationLayerIntroducingVulnerability.Rows()", err) + } + + size := limit + if len(layers) < limit { + size = len(layers) + } + vulnerability.LayersIntroducingVulnerability = layers[:size] + + nextID := -1 + if len(layers) > limit { + nextID = layers[limit].ID + } + + return nextID, nil +} + +func (mySQL *mySQL) SetNotificationNotified(name string) error { + defer database.ObserveQueryTime("SetNotificationNotified", "all", time.Now()) + + if _, err := mySQL.Exec(updatedNotificationNotified, name); err != nil { + return handleError("updatedNotificationNotified", err) + } + return nil +} + +func (mySQL *mySQL) DeleteNotification(name string) error { + defer database.ObserveQueryTime("DeleteNotification", "all", time.Now()) + + result, err := mySQL.Exec(removeNotification, name) + if err != nil { + return handleError("removeNotification", err) + } + + affected, err := result.RowsAffected() + if err != nil { + return handleError("removeNotification.RowsAffected()", err) + } + + if affected <= 0 { + return cerrors.ErrNotFound + } + + return nil +} diff --git a/database/mysql/notification_test.go b/database/mysql/notification_test.go new file mode 100644 index 00000000..82738c4e --- /dev/null +++ b/database/mysql/notification_test.go @@ -0,0 +1,209 @@ +package mysql + +import ( + "testing" + "time" + + "github.com/coreos/clair/database" + cerrors "github.com/coreos/clair/utils/errors" + "github.com/coreos/clair/utils/types" + "github.com/stretchr/testify/assert" +) + +func TestNotification(t *testing.T) { + datastore, err := OpenForTest("Notification", false) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + // Try to get a notification when there is none. + _, err = datastore.GetAvailableNotification(time.Second) + assert.Equal(t, cerrors.ErrNotFound, err) + + // Create some data. + f1 := database.Feature{ + Name: "TestNotificationFeature1", + Namespace: database.Namespace{Name: "TestNotificationNamespace1"}, + } + + f2 := database.Feature{ + Name: "TestNotificationFeature2", + Namespace: database.Namespace{Name: "TestNotificationNamespace1"}, + } + + l1 := database.Layer{ + Name: "TestNotificationLayer1", + Features: []database.FeatureVersion{ + { + Feature: f1, + Version: types.NewVersionUnsafe("0.1"), + }, + }, + } + + l2 := database.Layer{ + Name: "TestNotificationLayer2", + Features: []database.FeatureVersion{ + { + Feature: f1, + Version: types.NewVersionUnsafe("0.2"), + }, + }, + } + + l3 := database.Layer{ + Name: "TestNotificationLayer3", + Features: []database.FeatureVersion{ + { + Feature: f1, + Version: types.NewVersionUnsafe("0.3"), + }, + }, + } + + l4 := database.Layer{ + Name: "TestNotificationLayer4", + Features: []database.FeatureVersion{ + { + Feature: f2, + Version: types.NewVersionUnsafe("0.1"), + }, + }, + } + + if !assert.Nil(t, datastore.InsertLayer(l1)) || + !assert.Nil(t, datastore.InsertLayer(l2)) || + !assert.Nil(t, datastore.InsertLayer(l3)) || + !assert.Nil(t, datastore.InsertLayer(l4)) { + return + } + + // Insert a new vulnerability that is introduced by three layers. + v1 := database.Vulnerability{ + Name: "TestNotificationVulnerability1", + Namespace: f1.Namespace, + Description: "TestNotificationDescription1", + Link: "TestNotificationLink1", + Severity: "Unknown", + FixedIn: []database.FeatureVersion{ + { + Feature: f1, + Version: types.NewVersionUnsafe("1.0"), + }, + }, + } + assert.Nil(t, datastore.insertVulnerability(v1, false, true)) + + // Get the notification associated to the previously inserted vulnerability. + notification, err := datastore.GetAvailableNotification(time.Second) + + if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) { + // Verify the renotify behaviour. + if assert.Nil(t, datastore.SetNotificationNotified(notification.Name)) { + _, err := datastore.GetAvailableNotification(time.Second) + assert.Equal(t, cerrors.ErrNotFound, err) + + time.Sleep(50 * time.Millisecond) + notificationB, err := datastore.GetAvailableNotification(20 * time.Millisecond) + assert.Nil(t, err) + assert.Equal(t, notification.Name, notificationB.Name) + + datastore.SetNotificationNotified(notification.Name) + } + + // Get notification. + filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage) + if assert.Nil(t, err) { + assert.NotEqual(t, database.NoVulnerabilityNotificationPage, nextPage) + assert.Nil(t, filledNotification.OldVulnerability) + + if assert.NotNil(t, filledNotification.NewVulnerability) { + assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name) + assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 2) + } + } + + // Get second page. + filledNotification, nextPage, err = datastore.GetNotification(notification.Name, 2, nextPage) + if assert.Nil(t, err) { + assert.Equal(t, database.NoVulnerabilityNotificationPage, nextPage) + assert.Nil(t, filledNotification.OldVulnerability) + + if assert.NotNil(t, filledNotification.NewVulnerability) { + assert.Equal(t, v1.Name, filledNotification.NewVulnerability.Name) + assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1) + } + } + + // Delete notification. + assert.Nil(t, datastore.DeleteNotification(notification.Name)) + + _, err = datastore.GetAvailableNotification(time.Millisecond) + assert.Equal(t, cerrors.ErrNotFound, err) + } + + // Update a vulnerability and ensure that the old/new vulnerabilities are correct. + v1b := v1 + v1b.Severity = types.High + v1b.FixedIn = []database.FeatureVersion{ + { + Feature: f1, + Version: types.MinVersion, + }, + { + Feature: f2, + Version: types.MaxVersion, + }, + } + + if assert.Nil(t, datastore.insertVulnerability(v1b, false, true)) { + notification, err = datastore.GetAvailableNotification(time.Second) + assert.Nil(t, err) + assert.NotEmpty(t, notification.Name) + + if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) { + filledNotification, nextPage, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage) + if assert.Nil(t, err) { + if assert.NotNil(t, filledNotification.OldVulnerability) { + assert.Equal(t, v1.Name, filledNotification.OldVulnerability.Name) + assert.Equal(t, v1.Severity, filledNotification.OldVulnerability.Severity) + assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 2) + } + + if assert.NotNil(t, filledNotification.NewVulnerability) { + assert.Equal(t, v1b.Name, filledNotification.NewVulnerability.Name) + assert.Equal(t, v1b.Severity, filledNotification.NewVulnerability.Severity) + assert.Len(t, filledNotification.NewVulnerability.LayersIntroducingVulnerability, 1) + } + + assert.Equal(t, -1, nextPage.NewVulnerability) + } + + assert.Nil(t, datastore.DeleteNotification(notification.Name)) + } + } + + // Delete a vulnerability and verify the notification. + if assert.Nil(t, datastore.DeleteVulnerability(v1b.Namespace.Name, v1b.Name)) { + notification, err = datastore.GetAvailableNotification(time.Second) + assert.Nil(t, err) + assert.NotEmpty(t, notification.Name) + + if assert.Nil(t, err) && assert.NotEmpty(t, notification.Name) { + filledNotification, _, err := datastore.GetNotification(notification.Name, 2, database.VulnerabilityNotificationFirstPage) + if assert.Nil(t, err) { + assert.Nil(t, filledNotification.NewVulnerability) + + if assert.NotNil(t, filledNotification.OldVulnerability) { + assert.Equal(t, v1b.Name, filledNotification.OldVulnerability.Name) + assert.Equal(t, v1b.Severity, filledNotification.OldVulnerability.Severity) + assert.Len(t, filledNotification.OldVulnerability.LayersIntroducingVulnerability, 1) + } + } + + assert.Nil(t, datastore.DeleteNotification(notification.Name)) + } + } +} diff --git a/database/mysql/queries.go b/database/mysql/queries.go new file mode 100644 index 00000000..11cd419b --- /dev/null +++ b/database/mysql/queries.go @@ -0,0 +1,204 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import "strconv" + +const ( + // This is for lock the table + lockVulnerabilityAffects = `select count(*) from Vulnerability_Affects_FeatureVersion where vulnerability_id > 0 for update` + + // keyvalue.go + updateKeyValue = "UPDATE KeyValue SET `value` = ? WHERE `key` = ?" + insertKeyValue = "INSERT INTO KeyValue(`key`, `value`) VALUES(?, ?)" + searchKeyValue = "SELECT `value` FROM KeyValue WHERE `key` = ?" + + // namespace.go + insertNamespace = `insert into Namespace(name) select ? from dual where not exists (select * from Namespace where name = ?)` + soiNamespace = `SELECT id from Namespace WHERE name = ?` + + listNamespace = `SELECT id, name FROM Namespace` + searchNamespace = `SELECT id FROM Namespace WHERE name = ?` + + // feature.go + insertFeature = `insert into Feature(name, namespace_id) select CAST(? AS CHAR), CAST(? AS UNSIGNED) FROM dual WHERE NOT EXISTS (SELECT id FROM Feature WHERE name = ? AND namespace_id = ?)` + soiFeature = `SELECT id FROM Feature WHERE name = ? AND namespace_id = ?` + + insertFeatureVersion = ` + Insert into FeatureVersion(feature_id, version) select cast(? AS unsigned),cast(? as char) from dual where not exists (select * from FeatureVersion where feature_id = ? AND version = ?) LIMIT 1` + // TODO: need to handle 'exi' + soiFeatureVersion = `SELECT 'new', id FROM FeatureVersion WHERE feature_id = ? AND version = ?` + + searchFeatureVersion = `SELECT id FROM FeatureVersion WHERE feature_id = ? AND version = ?` + searchVulnerabilityFixedInFeature = ` + SELECT id, vulnerability_id, version FROM Vulnerability_FixedIn_Feature WHERE feature_id = ?` + + insertVulnerabilityAffectsFeatureVersion = ` + INSERT INTO Vulnerability_Affects_FeatureVersion(vulnerability_id, featureversion_id, fixedin_id) select ?,?,? from dual where not exists (select * from Vulnerability_Affects_FeatureVersion where vulnerability_id =? and featureversion_id = ? and fixedin_id = ?)` + // layer.go + searchLayer = ` + SELECT l.id, l.name, l.engineversion, p.id, p.name, n.id, n.name + FROM Layer l + LEFT JOIN Layer p ON l.parent_id = p.id + LEFT JOIN Namespace n ON l.namespace_id = n.id + WHERE l.name = ?` + + //TODO: + searchLayerFeatureVersion = ` + SELECT ldf.featureversion_id, ldf.modification, fn.id, fn.name, f.id, f.name, fv.id, fv.version, ltree.id, ltree.name + FROM Layer_diff_FeatureVersion ldf + JOIN ( + SELECT h.id,h.name,h.parent_id FROM (SELECT @Id AS tempId,name,(SELECT @Id := parent_id FROM Layer WHERE Id = tempId) AS parent_id FROM(SELECT @Id := ?)initializeVars, Layer h WHERE @Id <> 0)a INNER JOIN Layer h ON h.Id = a.tempId order by h.id desc + ) AS ltree ON ldf.layer_id = ltree.id, FeatureVersion fv, Feature f, Namespace fn + WHERE ldf.featureversion_id = fv.id AND fv.feature_id = f.id AND f.namespace_id = fn.id + ORDER BY ltree.id` + + searchFeatureVersionVulnerability = ` + SELECT vafv.featureversion_id, v.id, v.name, v.description, v.link, v.severity, v.metadata, + vn.name, vfif.version + FROM Vulnerability_Affects_FeatureVersion vafv, Vulnerability v, + Namespace vn, Vulnerability_FixedIn_Feature vfif + WHERE vafv.featureversion_id IN (%s) + AND vfif.vulnerability_id = v.id + AND vafv.fixedin_id = vfif.id + AND v.namespace_id = vn.id + AND v.deleted_at IS NULL` + + insertLayer = ` + INSERT INTO Layer(name, engineversion, parent_id, namespace_id, created_at) + VALUES(?, ?, ?, ?, CURRENT_TIMESTAMP) + ` + getLayerId = `select id from Layer where name=?` + + updateLayer = `UPDATE Layer SET engineversion = ?, namespace_id = ? WHERE id = ?` + + removeLayerDiffFeatureVersion = ` + DELETE FROM Layer_diff_FeatureVersion + WHERE layer_id = ?` + + insertLayerDiffFeatureVersion = ` + INSERT INTO Layer_diff_FeatureVersion(layer_id, featureversion_id, modification) + SELECT ?, fv.id, ? + FROM FeatureVersion fv + WHERE fv.id in(%s)` + + removeLayer = `DELETE FROM Layer WHERE name = ?` + + // lock.go + insertLock = "INSERT INTO `Lock`(name, owner, until) VALUES(?, ?, ?)" + searchLock = "SELECT owner, until FROM `Lock` WHERE name = ?" + updateLock = "UPDATE `Lock` SET until = ? WHERE name = ? AND owner = ?" + removeLock = "DELETE FROM `Lock` WHERE name = ? AND owner = ?" + removeLockExpired = "DELETE FROM `Lock` WHERE until < CURRENT_TIMESTAMP" + + // vulnerability.go + searchVulnerabilityBase = ` + SELECT v.id, v.name, n.id, n.name, v.description, v.link, v.severity, v.metadata + FROM Vulnerability v JOIN Namespace n ON v.namespace_id = n.id` + searchVulnerabilityForUpdate = ` FOR UPDATE ` + searchVulnerabilityByNamespaceAndName = ` WHERE n.name = ? AND v.name = ? AND v.deleted_at IS NULL` + searchVulnerabilityByID = ` WHERE v.id = ?` + searchVulnerabilityByNamespace = ` WHERE n.name = ? AND v.deleted_at IS NULL + AND v.id >= ? + ORDER BY v.id + LIMIT ?` + + searchVulnerabilityFixedIn = ` + SELECT vfif.version, f.id, f.Name + FROM Vulnerability_FixedIn_Feature vfif JOIN Feature f ON vfif.feature_id = f.id + WHERE vfif.vulnerability_id = ?` + + insertVulnerability = ` + INSERT INTO Vulnerability(namespace_id, name, description, link, severity, metadata, created_at) + VALUES(?, ?, ?, ?, ?, ?, CURRENT_TIMESTAMP) + ` + getVulId = `select id from Vulnerability where namespace_id=? and name=?` + getVulIdWithNamespaceName = `select id from Vulnerability where namespace_id=(select id from Namespace where name=?) and name=? and deleted_at IS NULL` + insertVulnerabilityFixedInFeature = ` + INSERT INTO Vulnerability_FixedIn_Feature(vulnerability_id, feature_id, version) + VALUES(?, ?, ?) + ` + findVulnerabilityFixedInFeature = `select id from Vulnerability_FixedIn_Feature where vulnerability_id=? AND feature_id=? AND version=?` + + searchFeatureVersionByFeature = `SELECT id, version FROM FeatureVersion WHERE feature_id = ?` + + removeVulnerability = ` + UPDATE Vulnerability + SET deleted_at = CURRENT_TIMESTAMP + WHERE namespace_id = (SELECT id FROM Namespace WHERE name = ?) + AND name = ? + AND deleted_at IS NULL` + + // notification.go + insertNotification = ` + INSERT INTO Vulnerability_Notification(name, created_at, old_vulnerability_id, new_vulnerability_id) + VALUES(?, CURRENT_TIMESTAMP, ?, ?)` + + updatedNotificationNotified = ` + UPDATE Vulnerability_Notification + SET notified_at = CURRENT_TIMESTAMP + WHERE name = ?` + + removeNotification = ` + UPDATE Vulnerability_Notification + SET deleted_at = CURRENT_TIMESTAMP + WHERE name = ?` + + searchNotificationAvailable = " SELECT id, name, created_at, notified_at, deleted_at" + + " FROM Vulnerability_Notification" + + " WHERE (notified_at IS NULL OR notified_at < ?)" + + " AND deleted_at IS NULL" + + " AND name NOT IN (SELECT name FROM `Lock`)" + + " ORDER BY Rand()" + + " LIMIT 1" + + searchNotification = ` + SELECT id, name, created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id + FROM Vulnerability_Notification + WHERE name = ?` + + searchNotificationLayerIntroducingVulnerability = ` + SELECT l.ID, l.name + FROM Vulnerability v, Vulnerability_Affects_FeatureVersion vafv, FeatureVersion fv, Layer_diff_FeatureVersion ldfv, Layer l + WHERE v.id = ? + AND v.id = vafv.vulnerability_id + AND vafv.featureversion_id = fv.id + AND fv.id = ldfv.featureversion_id + AND ldfv.modification = 'add' + AND ldfv.layer_id = l.id + AND l.id >= ? + ORDER BY l.ID + LIMIT ?` + + // complex_test.go + searchComplexTestFeatureVersionAffects = ` + SELECT v.name + FROM FeatureVersion fv + LEFT JOIN Vulnerability_Affects_FeatureVersion vaf ON fv.id = vaf.featureversion_id + JOIN Vulnerability v ON vaf.vulnerability_id = v.id + WHERE featureversion_id = ?` +) + +// buildInputArray constructs a PostgreSQL input array from the specified integers. +// Useful to use the `= ANY($1::integer[])` syntax that let us use a IN clause while using +// a single placeholder. +func buildInputArray(ints []int) string { + str := "" + for i := 0; i < len(ints)-1; i++ { + str = str + strconv.Itoa(ints[i]) + "," + } + str = str + strconv.Itoa(ints[len(ints)-1]) + return str +} diff --git a/database/mysql/testdata/data.sql b/database/mysql/testdata/data.sql new file mode 100644 index 00000000..7b2890f8 --- /dev/null +++ b/database/mysql/testdata/data.sql @@ -0,0 +1,55 @@ +-- Copyright 2015 clair authors +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +INSERT INTO Namespace (id, name) VALUES + (1, 'debian:7'), + (2, 'debian:8'); + +INSERT INTO Feature (id, namespace_id, name) VALUES + (1, 1, 'wechat'), + (2, 1, 'openssl'), + (4, 1, 'libssl'), + (3, 2, 'openssl'); + +INSERT INTO FeatureVersion (id, feature_id, version) VALUES + (1, 1, '0.5'), + (2, 2, '1.0'), + (3, 2, '2.0'), + (4, 3, '1.0'); + +INSERT INTO Layer (id, name, engineversion, parent_id, namespace_id) VALUES + (1, 'layer-0', 1, NULL, NULL), + (2, 'layer-1', 1, 1, 1), + (3, 'layer-2', 1, 2, 1), + (4, 'layer-3a', 1, 3, 1), + (5, 'layer-3b', 1, 3, 2); + +INSERT INTO Layer_diff_FeatureVersion (id, layer_id, featureversion_id, modification) VALUES + (1, 2, 1, 'add'), + (2, 2, 2, 'add'), + (3, 3, 2, 'del'), -- layer-2: Update Debian:7 OpenSSL 1.0 -> 2.0 + (4, 3, 3, 'add'), -- ^ + (5, 5, 3, 'del'), -- layer-3b: Delete Debian:7 OpenSSL 2.0 + (6, 5, 4, 'add'); -- layer-3b: Add Debian:8 OpenSSL 1.0 + +INSERT INTO Vulnerability (id, namespace_id, name, description, link, severity) VALUES + (1, 1, 'CVE-OPENSSL-1-DEB7', 'A vulnerability affecting OpenSSL < 2.0 on Debian 7.0', 'http://google.com/#q=CVE-OPENSSL-1-DEB7', 'High'), + (2, 1, 'CVE-NOPE', 'A vulnerability affecting nothing', '', 'Unknown'); + +INSERT INTO Vulnerability_FixedIn_Feature (id, vulnerability_id, feature_id, version) VALUES + (1, 1, 2, '2.0'), + (2, 1, 4, '1.9-abc'); + +INSERT INTO Vulnerability_Affects_FeatureVersion (id, vulnerability_id, featureversion_id, fixedin_id) VALUES + (1, 1, 2, 1); -- CVE-OPENSSL-1-DEB7 affects Debian:7 OpenSSL 1.0 diff --git a/database/mysql/vulnerability.go b/database/mysql/vulnerability.go new file mode 100644 index 00000000..03493e69 --- /dev/null +++ b/database/mysql/vulnerability.go @@ -0,0 +1,593 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "database/sql" + "encoding/json" + "fmt" + "reflect" + "time" + + "github.com/coreos/clair/database" + "github.com/coreos/clair/utils" + cerrors "github.com/coreos/clair/utils/errors" + "github.com/coreos/clair/utils/types" + "github.com/guregu/null/zero" +) + +func (mySQL *mySQL) ListVulnerabilities(namespaceName string, limit int, startID int) ([]database.Vulnerability, int, error) { + defer database.ObserveQueryTime("listVulnerabilities", "all", time.Now()) + + // Query Namespace. + var id int + err := mySQL.QueryRow(searchNamespace, namespaceName).Scan(&id) + if err != nil { + return nil, -1, handleError("searchNamespace", err) + } else if id == 0 { + return nil, -1, cerrors.ErrNotFound + } + + // Query. + query := searchVulnerabilityBase + searchVulnerabilityByNamespace + rows, err := mySQL.Query(query, namespaceName, startID, limit+1) + if err != nil { + return nil, -1, handleError("searchVulnerabilityByNamespace", err) + } + defer rows.Close() + + var vulns []database.Vulnerability + nextID := -1 + size := 0 + // Scan query. + for rows.Next() { + var vulnerability database.Vulnerability + + err := rows.Scan( + &vulnerability.ID, + &vulnerability.Name, + &vulnerability.Namespace.ID, + &vulnerability.Namespace.Name, + &vulnerability.Description, + &vulnerability.Link, + &vulnerability.Severity, + &vulnerability.Metadata, + ) + if err != nil { + return nil, -1, handleError("searchVulnerabilityByNamespace.Scan()", err) + } + size++ + if size > limit { + nextID = vulnerability.ID + } else { + vulns = append(vulns, vulnerability) + } + } + + if err := rows.Err(); err != nil { + return nil, -1, handleError("searchVulnerabilityByNamespace.Rows()", err) + } + + return vulns, nextID, nil +} + +func (mySQL *mySQL) FindVulnerability(namespaceName, name string) (database.Vulnerability, error) { + return findVulnerability(mySQL, namespaceName, name, false) +} + +func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bool) (database.Vulnerability, error) { + defer database.ObserveQueryTime("findVulnerability", "all", time.Now()) + + queryName := "searchVulnerabilityBase+searchVulnerabilityByNamespaceAndName" + query := searchVulnerabilityBase + searchVulnerabilityByNamespaceAndName + if forUpdate { + queryName = queryName + "+searchVulnerabilityForUpdate" + query = query + searchVulnerabilityForUpdate + } + + return scanVulnerability(queryer, queryName, queryer.QueryRow(query, namespaceName, name)) +} + +func (mySQL *mySQL) findVulnerabilityByIDWithDeleted(id int) (database.Vulnerability, error) { + defer database.ObserveQueryTime("findVulnerabilityByIDWithDeleted", "all", time.Now()) + + queryName := "searchVulnerabilityBase+searchVulnerabilityByID" + query := searchVulnerabilityBase + searchVulnerabilityByID + + return scanVulnerability(mySQL, queryName, mySQL.QueryRow(query, id)) +} + +func scanVulnerability(queryer Queryer, queryName string, vulnerabilityRow *sql.Row) (database.Vulnerability, error) { + var vulnerability database.Vulnerability + + err := vulnerabilityRow.Scan( + &vulnerability.ID, + &vulnerability.Name, + &vulnerability.Namespace.ID, + &vulnerability.Namespace.Name, + &vulnerability.Description, + &vulnerability.Link, + &vulnerability.Severity, + &vulnerability.Metadata, + ) + + if err != nil { + return vulnerability, handleError(queryName+".Scan()", err) + } + + if vulnerability.ID == 0 { + return vulnerability, cerrors.ErrNotFound + } + + // Query the FixedIn FeatureVersion now. + rows, err := queryer.Query(searchVulnerabilityFixedIn, vulnerability.ID) + if err != nil { + return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err) + } + defer rows.Close() + + for rows.Next() { + var featureVersionID zero.Int + var featureVersionVersion zero.String + var featureVersionFeatureName zero.String + + err := rows.Scan( + &featureVersionVersion, + &featureVersionID, + &featureVersionFeatureName, + ) + + if err != nil { + return vulnerability, handleError("searchVulnerabilityFixedIn.Scan()", err) + } + + if !featureVersionID.IsZero() { + // Note that the ID we fill in featureVersion is actually a Feature ID, and not + // a FeatureVersion ID. + featureVersion := database.FeatureVersion{ + Model: database.Model{ID: int(featureVersionID.Int64)}, + Feature: database.Feature{ + Model: database.Model{ID: int(featureVersionID.Int64)}, + Namespace: vulnerability.Namespace, + Name: featureVersionFeatureName.String, + }, + Version: types.NewVersionUnsafe(featureVersionVersion.String), + } + vulnerability.FixedIn = append(vulnerability.FixedIn, featureVersion) + } + } + + if err := rows.Err(); err != nil { + return vulnerability, handleError("searchVulnerabilityFixedIn.Rows()", err) + } + + return vulnerability, nil +} + +// FixedIn.Namespace are not necessary, they are overwritten by the vuln. +// By setting the fixed version to minVersion, we can say that the vuln does'nt affect anymore. +func (mySQL *mySQL) InsertVulnerabilities(vulnerabilities []database.Vulnerability, generateNotifications bool) error { + for _, vulnerability := range vulnerabilities { + err := mySQL.insertVulnerability(vulnerability, false, generateNotifications) + if err != nil { + fmt.Printf("%#v\n", vulnerability) + return err + } + } + return nil +} + +func (mySQL *mySQL) insertVulnerability(vulnerability database.Vulnerability, onlyFixedIn, generateNotification bool) error { + tf := time.Now() + + // Verify parameters + if vulnerability.Name == "" || vulnerability.Namespace.Name == "" { + return cerrors.NewBadRequestError("insertVulnerability needs at least the Name and the Namespace") + } + if !onlyFixedIn && !vulnerability.Severity.IsValid() { + msg := fmt.Sprintf("could not insert a vulnerability that has an invalid Severity: %s", vulnerability.Severity) + log.Warning(msg) + return cerrors.NewBadRequestError(msg) + } + for i := 0; i < len(vulnerability.FixedIn); i++ { + fifv := &vulnerability.FixedIn[i] + + if fifv.Feature.Namespace.Name == "" { + // As there is no Namespace on that FixedIn FeatureVersion, set it to the Vulnerability's + // Namespace. + fifv.Feature.Namespace.Name = vulnerability.Namespace.Name + } else if fifv.Feature.Namespace.Name != vulnerability.Namespace.Name { + msg := "could not insert an invalid vulnerability that contains FixedIn FeatureVersion that are not in the same namespace as the Vulnerability" + log.Warning(msg) + return cerrors.NewBadRequestError(msg) + } + } + + // We do `defer database.ObserveQueryTime` here because we don't want to observe invalid vulnerabilities. + defer database.ObserveQueryTime("insertVulnerability", "all", tf) + // Begin transaction. + tx, err := mySQL.Begin() + if err != nil { + tx.Rollback() + return handleError("insertVulnerability.Begin()", err) + } + // Find existing vulnerability and its Vulnerability_FixedIn_Features (for update). + existingVulnerability, err := findVulnerability(tx, vulnerability.Namespace.Name, vulnerability.Name, false) + if err != nil && err != cerrors.ErrNotFound { + tx.Rollback() + return err + } + + if onlyFixedIn { + // Because this call tries to update FixedIn FeatureVersion, import all other data from the + // existing one. + if existingVulnerability.ID == 0 { + return cerrors.ErrNotFound + } + + fixedIn := vulnerability.FixedIn + vulnerability = existingVulnerability + vulnerability.FixedIn = fixedIn + } + + if existingVulnerability.ID != 0 { + updateMetadata := vulnerability.Description != existingVulnerability.Description || + vulnerability.Link != existingVulnerability.Link || + vulnerability.Severity != existingVulnerability.Severity || + !reflect.DeepEqual(castMetadata(vulnerability.Metadata), existingVulnerability.Metadata) + + // Construct the entire list of FixedIn FeatureVersion, by using the + // the FixedIn list of the old vulnerability. + // + // TODO(Quentin-M): We could use !updateFixedIn to just copy FixedIn/Affects rows from the + // existing vulnerability in order to make metadata updates much faster. + var updateFixedIn bool + vulnerability.FixedIn, updateFixedIn = applyFixedInDiff(existingVulnerability.FixedIn, vulnerability.FixedIn) + + if !updateMetadata && !updateFixedIn { + tx.Commit() + return nil + } + + // Mark the old vulnerability as non latest. + _, err = tx.Exec(removeVulnerability, vulnerability.Namespace.Name, vulnerability.Name) + if err != nil { + tx.Rollback() + return handleError("removeVulnerability", err) + } + } else { + // The vulnerability is new, we don't want to have any types.MinVersion as they are only used + // for diffing existing vulnerabilities. + var fixedIn []database.FeatureVersion + for _, fv := range vulnerability.FixedIn { + if fv.Version != types.MinVersion { + fixedIn = append(fixedIn, fv) + } + } + vulnerability.FixedIn = fixedIn + } + // Find or insert Vulnerability's Namespace. + namespaceID, err := mySQL.insertNamespaceWithTransaction(tx, vulnerability.Namespace) + if err != nil { + tx.Rollback() + return err + } + + // Insert vulnerability. + res, err := tx.Exec( + insertVulnerability, + namespaceID, + vulnerability.Name, + vulnerability.Description, + vulnerability.Link, + &vulnerability.Severity, + &vulnerability.Metadata, + ) + + if err != nil { + tx.Rollback() + return handleError("insertVulnerability", err) + } + ID, err := res.LastInsertId() + if err != nil { + tx.Rollback() + return err + + } + vulnerability.ID = int(ID) + // vulnerability.ID == 0 means the vulnerability is already exists + if vulnerability.ID == 0 { + err = tx.QueryRow(getVulId, namespaceID, vulnerability.Name).Scan(&vulnerability.ID) + if err != nil { + tx.Rollback() + return err + + } + } + // Update Vulnerability_FixedIn_Feature and Vulnerability_Affects_FeatureVersion now. + err = mySQL.insertVulnerabilityFixedInFeatureVersions(tx, vulnerability.ID, vulnerability.FixedIn) + if err != nil { + tx.Rollback() + return err + } + + // Create a notification. + if generateNotification { + err = createNotification(tx, existingVulnerability.ID, vulnerability.ID) + if err != nil { + return err + } + } + + // Commit transaction. + err = tx.Commit() + if err != nil { + tx.Rollback() + return handleError("insertVulnerability.Commit()", err) + } + + return nil +} + +// castMetadata marshals the given database.MetadataMap and unmarshals it again to make sure that +// everything has the interface{} type. +// It is required when comparing crafted MetadataMap against MetadataMap that we get from the +// database. +func castMetadata(m database.MetadataMap) database.MetadataMap { + c := make(database.MetadataMap) + j, _ := json.Marshal(m) + json.Unmarshal(j, &c) + return c +} + +// applyFixedInDiff applies a FeatureVersion diff on a FeatureVersion list and returns the result. +func applyFixedInDiff(currentList, diff []database.FeatureVersion) ([]database.FeatureVersion, bool) { + currentMap, currentNames := createFeatureVersionNameMap(currentList) + diffMap, diffNames := createFeatureVersionNameMap(diff) + + addedNames := utils.CompareStringLists(diffNames, currentNames) + inBothNames := utils.CompareStringListsInBoth(diffNames, currentNames) + + different := false + + for _, name := range addedNames { + if diffMap[name].Version == types.MinVersion { + // MinVersion only makes sense when a Feature is already fixed in some version, + // in which case we would be in the "inBothNames". + continue + } + + currentMap[name] = diffMap[name] + different = true + } + + for _, name := range inBothNames { + fv := diffMap[name] + + if fv.Version == types.MinVersion { + // MinVersion means that the Feature doesn't affect the Vulnerability anymore. + delete(currentMap, name) + different = true + } else if fv.Version != currentMap[name].Version { + // The version got updated. + currentMap[name] = diffMap[name] + different = true + } + } + + // Convert currentMap to a slice and return it. + var newList []database.FeatureVersion + for _, fv := range currentMap { + newList = append(newList, fv) + } + + return newList, different +} + +func createFeatureVersionNameMap(features []database.FeatureVersion) (map[string]database.FeatureVersion, []string) { + m := make(map[string]database.FeatureVersion, 0) + s := make([]string, 0, len(features)) + + for i := 0; i < len(features); i++ { + featureVersion := features[i] + m[featureVersion.Feature.Name] = featureVersion + s = append(s, featureVersion.Feature.Name) + } + + return m, s +} + +// insertVulnerabilityFixedInFeatureVersions populates Vulnerability_FixedIn_Feature for the given +// vulnerability with the specified database.FeatureVersion list and uses +// linkVulnerabilityToFeatureVersions to propagate the changes on Vulnerability_FixedIn_Feature to +// Vulnerability_Affects_FeatureVersion. +func (mySQL *mySQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulnerabilityID int, fixedIn []database.FeatureVersion) error { + defer database.ObserveQueryTime("insertVulnerabilityFixedInFeatureVersions", "all", time.Now()) + + // Insert or find the Features. + // TODO(Quentin-M): Batch me. + var err error + var features []*database.Feature + for i := 0; i < len(fixedIn); i++ { + features = append(features, &fixedIn[i].Feature) + } + for _, feature := range features { + if feature.ID == 0 { + if feature.ID, err = mySQL.insertFeatureiWithTransaction(tx, *feature); err != nil { + return err + } + } + } + + // Lock Vulnerability_Affects_FeatureVersion exclusively. + // We want to prevent InsertFeatureVersion to modify it. + database.PromConcurrentLockVAFV.Inc() + defer database.PromConcurrentLockVAFV.Dec() + t := time.Now() + var tmp int64 + err = tx.QueryRow(lockVulnerabilityAffects).Scan(&tmp) + database.ObserveQueryTime("insertVulnerability", "lock", t) + + if err != nil { + tx.Rollback() + return handleError("insertVulnerability.lockVulnerabilityAffects", err) + } + + for _, fv := range fixedIn { + var fixedInID int + + // Insert Vulnerability_FixedIn_Feature. + _, err = tx.Exec( + insertVulnerabilityFixedInFeature, + vulnerabilityID, fv.Feature.ID, + &fv.Version, + ) + + if err != nil { + return handleError("insertVulnerabilityFixedInFeature", err) + } + err = tx.QueryRow(findVulnerabilityFixedInFeature, vulnerabilityID, fv.Feature.ID, &fv.Version).Scan(&fixedInID) + if err != nil { + return handleError("findVulnerabilityFixedInFeature", err) + } + + // Insert Vulnerability_Affects_FeatureVersion. + err = linkVulnerabilityToFeatureVersions(tx, fixedInID, vulnerabilityID, fv.Feature.ID, fv.Version) + if err != nil { + return err + } + } + + return nil +} + +func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, featureID int, fixedInVersion types.Version) error { + // Find every FeatureVersions of the Feature that the vulnerability affects. + // TODO(Quentin-M): LIMIT + rows, err := tx.Query(searchFeatureVersionByFeature, featureID) + if err != nil { + return handleError("searchFeatureVersionByFeature", err) + } + defer rows.Close() + + var affecteds []database.FeatureVersion + for rows.Next() { + var affected database.FeatureVersion + + err := rows.Scan(&affected.ID, &affected.Version) + if err != nil { + return handleError("searchFeatureVersionByFeature.Scan()", err) + } + + if affected.Version.Compare(fixedInVersion) < 0 { + // The version of the FeatureVersion is lower than the fixed version of this vulnerability, + // thus, this FeatureVersion is affected by it. + affecteds = append(affecteds, affected) + } + } + if err = rows.Err(); err != nil { + return handleError("searchFeatureVersionByFeature.Rows()", err) + } + rows.Close() + + // Insert into Vulnerability_Affects_FeatureVersion. + for _, affected := range affecteds { + // TODO(Quentin-M): Batch me. + _, err := tx.Exec(insertVulnerabilityAffectsFeatureVersion, vulnerabilityID, + affected.ID, fixedInID, vulnerabilityID, affected.ID, fixedInID) + if err != nil { + return handleError("insertVulnerabilityAffectsFeatureVersion", err) + } + } + + return nil +} + +func (mySQL *mySQL) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []database.FeatureVersion) error { + defer database.ObserveQueryTime("InsertVulnerabilityFixes", "all", time.Now()) + + v := database.Vulnerability{ + Name: vulnerabilityName, + Namespace: database.Namespace{ + Name: vulnerabilityNamespace, + }, + FixedIn: fixes, + } + + return mySQL.insertVulnerability(v, true, true) +} + +func (mySQL *mySQL) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error { + defer database.ObserveQueryTime("DeleteVulnerabilityFix", "all", time.Now()) + + v := database.Vulnerability{ + Name: vulnerabilityName, + Namespace: database.Namespace{ + Name: vulnerabilityNamespace, + }, + FixedIn: []database.FeatureVersion{ + { + Feature: database.Feature{ + Name: featureName, + Namespace: database.Namespace{ + Name: vulnerabilityNamespace, + }, + }, + Version: types.MinVersion, + }, + }, + } + + return mySQL.insertVulnerability(v, true, true) +} + +func (mySQL *mySQL) DeleteVulnerability(namespaceName, name string) error { + defer database.ObserveQueryTime("DeleteVulnerability", "all", time.Now()) + + // Begin transaction. + tx, err := mySQL.Begin() + if err != nil { + tx.Rollback() + return handleError("DeleteVulnerability.Begin()", err) + } + + var vulnerabilityID int + err = tx.QueryRow(getVulIdWithNamespaceName, namespaceName, name).Scan(&vulnerabilityID) + if err != nil { + tx.Rollback() + return handleError("removeVulnerability", err) + } + + _, err = tx.Exec(removeVulnerability, namespaceName, name) + if err != nil { + tx.Rollback() + return handleError("removeVulnerability", err) + } + + // Create a notification. + err = createNotification(tx, vulnerabilityID, 0) + if err != nil { + return err + } + + // Commit transaction. + err = tx.Commit() + if err != nil { + tx.Rollback() + return handleError("DeleteVulnerability.Commit()", err) + } + + return nil +} diff --git a/database/mysql/vulnerability_test.go b/database/mysql/vulnerability_test.go new file mode 100644 index 00000000..a65a228b --- /dev/null +++ b/database/mysql/vulnerability_test.go @@ -0,0 +1,276 @@ +// Copyright 2015 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package mysql + +import ( + "reflect" + "testing" + + "github.com/coreos/clair/database" + cerrors "github.com/coreos/clair/utils/errors" + "github.com/coreos/clair/utils/types" + "github.com/stretchr/testify/assert" +) + +func TestFindVulnerability(t *testing.T) { + datastore, err := OpenForTest("FindVulnerability", true) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + // Find a vulnerability that does not exist. + _, err = datastore.FindVulnerability("", "") + assert.Equal(t, cerrors.ErrNotFound, err) + + // Find a normal vulnerability. + v1 := database.Vulnerability{ + Name: "CVE-OPENSSL-1-DEB7", + Description: "A vulnerability affecting OpenSSL < 2.0 on Debian 7.0", + Link: "http://google.com/#q=CVE-OPENSSL-1-DEB7", + Severity: types.High, + Namespace: database.Namespace{Name: "debian:7"}, + FixedIn: []database.FeatureVersion{ + { + Feature: database.Feature{Name: "openssl"}, + Version: types.NewVersionUnsafe("2.0"), + }, + { + Feature: database.Feature{Name: "libssl"}, + Version: types.NewVersionUnsafe("1.9-abc"), + }, + }, + } + + v1f, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") + if assert.Nil(t, err) { + equalsVuln(t, &v1, &v1f) + } + + // Find a vulnerability that has no link, no severity and no FixedIn. + v2 := database.Vulnerability{ + Name: "CVE-NOPE", + Description: "A vulnerability affecting nothing", + Namespace: database.Namespace{Name: "debian:7"}, + Severity: types.Unknown, + } + + v2f, err := datastore.FindVulnerability("debian:7", "CVE-NOPE") + if assert.Nil(t, err) { + equalsVuln(t, &v2, &v2f) + } +} + +func TestDeleteVulnerability(t *testing.T) { + datastore, err := OpenForTest("InsertVulnerability", true) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + // Delete non-existing Vulnerability. + err = datastore.DeleteVulnerability("TestDeleteVulnerabilityNamespace1", "CVE-OPENSSL-1-DEB7") + assert.Equal(t, cerrors.ErrNotFound, err) + err = datastore.DeleteVulnerability("debian:7", "TestDeleteVulnerabilityVulnerability1") + assert.Equal(t, cerrors.ErrNotFound, err) + + // Delete Vulnerability. + err = datastore.DeleteVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") + if assert.Nil(t, err) { + _, err := datastore.FindVulnerability("debian:7", "CVE-OPENSSL-1-DEB7") + assert.Equal(t, cerrors.ErrNotFound, err) + } +} + +func TestInsertVulnerability(t *testing.T) { + datastore, err := OpenForTest("InsertVulnerability", false) + if err != nil { + t.Error(err) + return + } + defer datastore.Close() + + // Create some data. + n1 := database.Namespace{Name: "TestInsertVulnerabilityNamespace1"} + n2 := database.Namespace{Name: "TestInsertVulnerabilityNamespace2"} + + f1 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion1", + Namespace: n1, + }, + Version: types.NewVersionUnsafe("1.0"), + } + f2 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion1", + Namespace: n2, + }, + Version: types.NewVersionUnsafe("1.0"), + } + f3 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion2", + }, + Version: types.MaxVersion, + } + f4 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion2", + }, + Version: types.NewVersionUnsafe("1.4"), + } + f5 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion3", + }, + Version: types.NewVersionUnsafe("1.5"), + } + f6 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion4", + }, + Version: types.NewVersionUnsafe("0.1"), + } + f7 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion5", + }, + Version: types.MaxVersion, + } + f8 := database.FeatureVersion{ + Feature: database.Feature{ + Name: "TestInsertVulnerabilityFeatureVersion5", + }, + Version: types.MinVersion, + } + + // Insert invalid vulnerabilities. + for _, vulnerability := range []database.Vulnerability{ + { + Name: "", + Namespace: n1, + FixedIn: []database.FeatureVersion{f1}, + Severity: types.Unknown, + }, + { + Name: "TestInsertVulnerability0", + Namespace: database.Namespace{}, + FixedIn: []database.FeatureVersion{f1}, + Severity: types.Unknown, + }, + { + Name: "TestInsertVulnerability0-", + Namespace: database.Namespace{}, + FixedIn: []database.FeatureVersion{f1}, + }, + { + Name: "TestInsertVulnerability0", + Namespace: n1, + FixedIn: []database.FeatureVersion{f1}, + Severity: types.Priority(""), + }, + { + Name: "TestInsertVulnerability0", + Namespace: n1, + FixedIn: []database.FeatureVersion{f2}, + Severity: types.Unknown, + }, + } { + err := datastore.InsertVulnerabilities([]database.Vulnerability{vulnerability}, true) + assert.Error(t, err) + } + + // Insert a simple vulnerability and find it. + v1meta := make(map[string]interface{}) + v1meta["TestInsertVulnerabilityMetadata1"] = "TestInsertVulnerabilityMetadataValue1" + v1meta["TestInsertVulnerabilityMetadata2"] = struct { + Test string + }{ + Test: "TestInsertVulnerabilityMetadataValue1", + } + + v1 := database.Vulnerability{ + Name: "TestInsertVulnerability1", + Namespace: n1, + FixedIn: []database.FeatureVersion{f1, f3, f6, f7}, + Severity: types.Low, + Description: "TestInsertVulnerabilityDescription1", + Link: "TestInsertVulnerabilityLink1", + Metadata: v1meta, + } + err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true) + if assert.Nil(t, err) { + v1f, err := datastore.FindVulnerability(n1.Name, v1.Name) + if assert.Nil(t, err) { + equalsVuln(t, &v1, &v1f) + } + } + + // Update vulnerability. + v1.Description = "TestInsertVulnerabilityLink2" + v1.Link = "TestInsertVulnerabilityLink2" + v1.Severity = types.High + // Update f3 in f4, add fixed in f5, add fixed in f6 which already exists, removes fixed in f7 by + // adding f8 which is f7 but with MinVersion. + v1.FixedIn = []database.FeatureVersion{f4, f5, f6, f8} + + err = datastore.InsertVulnerabilities([]database.Vulnerability{v1}, true) + if assert.Nil(t, err) { + v1f, err := datastore.FindVulnerability(n1.Name, v1.Name) + if assert.Nil(t, err) { + // We already had f1 before the update. + // Add it to the struct for comparison. + v1.FixedIn = append(v1.FixedIn, f1) + + // Removes f8 from the struct for comparison as it was just here to cancel f7. + for i := 0; i < len(v1.FixedIn); i++ { + if v1.FixedIn[i].Feature.Name == f8.Feature.Name { + v1.FixedIn = append(v1.FixedIn[:i], v1.FixedIn[i+1:]...) + } + } + + equalsVuln(t, &v1, &v1f) + } + } +} + +func equalsVuln(t *testing.T, expected, actual *database.Vulnerability) { + assert.Equal(t, expected.Name, actual.Name) + assert.Equal(t, expected.Namespace.Name, actual.Namespace.Name) + assert.Equal(t, expected.Description, actual.Description) + assert.Equal(t, expected.Link, actual.Link) + assert.Equal(t, expected.Severity, actual.Severity) + assert.True(t, reflect.DeepEqual(castMetadata(expected.Metadata), actual.Metadata), "Got metadata %#v, expected %#v", actual.Metadata, castMetadata(expected.Metadata)) + + if assert.Len(t, actual.FixedIn, len(expected.FixedIn)) { + for _, actualFeatureVersion := range actual.FixedIn { + found := false + for _, expectedFeatureVersion := range expected.FixedIn { + if expectedFeatureVersion.Feature.Name == actualFeatureVersion.Feature.Name { + found = true + + assert.Equal(t, expected.Namespace.Name, actualFeatureVersion.Feature.Namespace.Name) + assert.Equal(t, expectedFeatureVersion.Version, actualFeatureVersion.Version) + } + } + if !found { + t.Errorf("unexpected package %s in %s", actualFeatureVersion.Feature.Name, expected.Name) + } + } + } +} diff --git a/database/pgsql/feature.go b/database/pgsql/feature.go index a2f2abe8..0f567217 100644 --- a/database/pgsql/feature.go +++ b/database/pgsql/feature.go @@ -30,16 +30,16 @@ func (pgSQL *pgSQL) insertFeature(feature database.Feature) (int, error) { // Do cache lookup. if pgSQL.cache != nil { - promCacheQueriesTotal.WithLabelValues("feature").Inc() + database.PromCacheQueriesTotal.WithLabelValues("feature").Inc() id, found := pgSQL.cache.Get("feature:" + feature.Namespace.Name + ":" + feature.Name) if found { - promCacheHitsTotal.WithLabelValues("feature").Inc() + database.PromCacheHitsTotal.WithLabelValues("feature").Inc() return id.(int), nil } } - // We do `defer observeQueryTime` here because we don't want to observe cached features. - defer observeQueryTime("insertFeature", "all", time.Now()) + // We do `defer database.ObserveQueryTime` here because we don't want to observe cached features. + defer database.ObserveQueryTime("insertFeature", "all", time.Now()) // Find or create Namespace. namespaceID, err := pgSQL.insertNamespace(feature.Namespace) @@ -69,21 +69,21 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) // Do cache lookup. cacheIndex := "featureversion:" + featureVersion.Feature.Namespace.Name + ":" + featureVersion.Feature.Name + ":" + featureVersion.Version.String() if pgSQL.cache != nil { - promCacheQueriesTotal.WithLabelValues("featureversion").Inc() + database.PromCacheQueriesTotal.WithLabelValues("featureversion").Inc() id, found := pgSQL.cache.Get(cacheIndex) if found { - promCacheHitsTotal.WithLabelValues("featureversion").Inc() + database.PromCacheHitsTotal.WithLabelValues("featureversion").Inc() return id.(int), nil } } - // We do `defer observeQueryTime` here because we don't want to observe cached featureversions. - defer observeQueryTime("insertFeatureVersion", "all", time.Now()) + // We do `defer database.ObserveQueryTime` here because we don't want to observe cached featureversions. + defer database.ObserveQueryTime("insertFeatureVersion", "all", time.Now()) // Find or create Feature first. t := time.Now() featureID, err := pgSQL.insertFeature(featureVersion.Feature) - observeQueryTime("insertFeatureVersion", "insertFeature", t) + database.ObserveQueryTime("insertFeatureVersion", "insertFeature", t) if err != nil { return 0, err @@ -117,11 +117,11 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) // Lock Vulnerability_Affects_FeatureVersion exclusively. // We want to prevent InsertVulnerability to modify it. - promConcurrentLockVAFV.Inc() - defer promConcurrentLockVAFV.Dec() + database.PromConcurrentLockVAFV.Inc() + defer database.PromConcurrentLockVAFV.Dec() t = time.Now() _, err = tx.Exec(lockVulnerabilityAffects) - observeQueryTime("insertFeatureVersion", "lock", t) + database.ObserveQueryTime("insertFeatureVersion", "lock", t) if err != nil { tx.Rollback() @@ -134,7 +134,7 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) t = time.Now() err = tx.QueryRow(soiFeatureVersion, featureID, &featureVersion.Version). Scan(&newOrExisting, &featureVersion.ID) - observeQueryTime("insertFeatureVersion", "soiFeatureVersion", t) + database.ObserveQueryTime("insertFeatureVersion", "soiFeatureVersion", t) if err != nil { tx.Rollback() @@ -156,7 +156,7 @@ func (pgSQL *pgSQL) insertFeatureVersion(featureVersion database.FeatureVersion) // Vulnerability_Affects_FeatureVersion. t = time.Now() err = linkFeatureVersionToVulnerabilities(tx, featureVersion) - observeQueryTime("insertFeatureVersion", "linkFeatureVersionToVulnerabilities", t) + database.ObserveQueryTime("insertFeatureVersion", "linkFeatureVersionToVulnerabilities", t) if err != nil { tx.Rollback() diff --git a/database/pgsql/keyvalue.go b/database/pgsql/keyvalue.go index 264774c7..af1ca396 100644 --- a/database/pgsql/keyvalue.go +++ b/database/pgsql/keyvalue.go @@ -18,6 +18,7 @@ import ( "database/sql" "time" + "github.com/coreos/clair/database" cerrors "github.com/coreos/clair/utils/errors" ) @@ -28,7 +29,7 @@ func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) { return cerrors.NewBadRequestError("could not insert a flag which has an empty name or value") } - defer observeQueryTime("InsertKeyValue", "all", time.Now()) + defer database.ObserveQueryTime("InsertKeyValue", "all", time.Now()) // Upsert. // @@ -67,7 +68,7 @@ func (pgSQL *pgSQL) InsertKeyValue(key, value string) (err error) { // GetValue reads a single key / value tuple and returns an empty string if the key doesn't exist. func (pgSQL *pgSQL) GetKeyValue(key string) (string, error) { - defer observeQueryTime("GetKeyValue", "all", time.Now()) + defer database.ObserveQueryTime("GetKeyValue", "all", time.Now()) var value string err := pgSQL.QueryRow(searchKeyValue, key).Scan(&value) diff --git a/database/pgsql/layer.go b/database/pgsql/layer.go index 71e5726a..92bb95dd 100644 --- a/database/pgsql/layer.go +++ b/database/pgsql/layer.go @@ -31,7 +31,7 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo } else if withVulnerabilities { subquery += "/features+vulnerabilities" } - defer observeQueryTime("FindLayer", subquery, time.Now()) + defer database.ObserveQueryTime("FindLayer", subquery, time.Now()) // Find the layer var layer database.Layer @@ -43,8 +43,8 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo t := time.Now() err := pgSQL.QueryRow(searchLayer, name). Scan(&layer.ID, &layer.Name, &layer.EngineVersion, &parentID, &parentName, &namespaceID, - &namespaceName) - observeQueryTime("FindLayer", "searchLayer", t) + &namespaceName) + database.ObserveQueryTime("FindLayer", "searchLayer", t) if err != nil { return layer, handleError("searchLayer", err) @@ -89,7 +89,7 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo t = time.Now() featureVersions, err := getLayerFeatureVersions(tx, layer.ID) - observeQueryTime("FindLayer", "getLayerFeatureVersions", t) + database.ObserveQueryTime("FindLayer", "getLayerFeatureVersions", t) if err != nil { return layer, err @@ -101,7 +101,7 @@ func (pgSQL *pgSQL) FindLayer(name string, withFeatures, withVulnerabilities boo // Load the vulnerabilities that affect the FeatureVersions. t = time.Now() err := loadAffectedBy(tx, layer.Features) - observeQueryTime("FindLayer", "loadAffectedBy", t) + database.ObserveQueryTime("FindLayer", "loadAffectedBy", t) if err != nil { return layer, err @@ -233,8 +233,8 @@ func (pgSQL *pgSQL) InsertLayer(layer database.Layer) error { layer.ID = existingLayer.ID } - // We do `defer observeQueryTime` here because we don't want to observe existing layers. - defer observeQueryTime("InsertLayer", "all", tf) + // We do `defer database.ObserveQueryTime` here because we don't want to observe existing layers. + defer database.ObserveQueryTime("InsertLayer", "all", tf) // Get parent ID. var parentID zero.Int @@ -386,7 +386,7 @@ func createNV(features []database.FeatureVersion) (map[string]*database.FeatureV } func (pgSQL *pgSQL) DeleteLayer(name string) error { - defer observeQueryTime("DeleteLayer", "all", time.Now()) + defer database.ObserveQueryTime("DeleteLayer", "all", time.Now()) result, err := pgSQL.Exec(removeLayer, name) if err != nil { diff --git a/database/pgsql/lock.go b/database/pgsql/lock.go index 2f491caa..4af8a7db 100644 --- a/database/pgsql/lock.go +++ b/database/pgsql/lock.go @@ -17,6 +17,7 @@ package pgsql import ( "time" + "github.com/coreos/clair/database" cerrors "github.com/coreos/clair/utils/errors" ) @@ -30,7 +31,7 @@ func (pgSQL *pgSQL) Lock(name string, owner string, duration time.Duration, rene return false, time.Time{} } - defer observeQueryTime("Lock", "all", time.Now()) + defer database.ObserveQueryTime("Lock", "all", time.Now()) // Compute expiration. until := time.Now().Add(duration) @@ -70,7 +71,7 @@ func (pgSQL *pgSQL) Unlock(name, owner string) { return } - defer observeQueryTime("Unlock", "all", time.Now()) + defer database.ObserveQueryTime("Unlock", "all", time.Now()) pgSQL.Exec(removeLock, name, owner) } @@ -83,7 +84,7 @@ func (pgSQL *pgSQL) FindLock(name string) (string, time.Time, error) { return "", time.Time{}, cerrors.NewBadRequestError("could not find an invalid lock") } - defer observeQueryTime("FindLock", "all", time.Now()) + defer database.ObserveQueryTime("FindLock", "all", time.Now()) var owner string var until time.Time @@ -97,7 +98,7 @@ func (pgSQL *pgSQL) FindLock(name string) (string, time.Time, error) { // pruneLocks removes every expired locks from the database func (pgSQL *pgSQL) pruneLocks() { - defer observeQueryTime("pruneLocks", "all", time.Now()) + defer database.ObserveQueryTime("pruneLocks", "all", time.Now()) if _, err := pgSQL.Exec(removeLockExpired); err != nil { handleError("removeLockExpired", err) diff --git a/database/pgsql/namespace.go b/database/pgsql/namespace.go index 3c85c784..20e7875b 100644 --- a/database/pgsql/namespace.go +++ b/database/pgsql/namespace.go @@ -27,15 +27,15 @@ func (pgSQL *pgSQL) insertNamespace(namespace database.Namespace) (int, error) { } if pgSQL.cache != nil { - promCacheQueriesTotal.WithLabelValues("namespace").Inc() + database.PromCacheQueriesTotal.WithLabelValues("namespace").Inc() if id, found := pgSQL.cache.Get("namespace:" + namespace.Name); found { - promCacheHitsTotal.WithLabelValues("namespace").Inc() + database.PromCacheHitsTotal.WithLabelValues("namespace").Inc() return id.(int), nil } } // We do `defer observeQueryTime` here because we don't want to observe cached namespaces. - defer observeQueryTime("insertNamespace", "all", time.Now()) + defer database.ObserveQueryTime("insertNamespace", "all", time.Now()) var id int err := pgSQL.QueryRow(soiNamespace, namespace.Name).Scan(&id) diff --git a/database/pgsql/notification.go b/database/pgsql/notification.go index 70bde316..eacb5737 100644 --- a/database/pgsql/notification.go +++ b/database/pgsql/notification.go @@ -13,7 +13,7 @@ import ( // do it in tx so we won't insert/update a vuln without notification and vice-versa. // name and created doesn't matter. func createNotification(tx *sql.Tx, oldVulnerabilityID, newVulnerabilityID int) error { - defer observeQueryTime("createNotification", "all", time.Now()) + defer database.ObserveQueryTime("createNotification", "all", time.Now()) // Insert Notification. oldVulnerabilityNullableID := sql.NullInt64{Int64: int64(oldVulnerabilityID), Valid: oldVulnerabilityID != 0} @@ -30,7 +30,7 @@ func createNotification(tx *sql.Tx, oldVulnerabilityID, newVulnerabilityID int) // Get one available notification name (!locked && !deleted && (!notified || notified_but_timed-out)). // Does not fill new/old vuln. func (pgSQL *pgSQL) GetAvailableNotification(renotifyInterval time.Duration) (database.VulnerabilityNotification, error) { - defer observeQueryTime("GetAvailableNotification", "all", time.Now()) + defer database.ObserveQueryTime("GetAvailableNotification", "all", time.Now()) before := time.Now().Add(-renotifyInterval) row := pgSQL.QueryRow(searchNotificationAvailable, before) @@ -40,7 +40,7 @@ func (pgSQL *pgSQL) GetAvailableNotification(renotifyInterval time.Duration) (da } func (pgSQL *pgSQL) GetNotification(name string, limit int, page database.VulnerabilityNotificationPageNumber) (database.VulnerabilityNotification, database.VulnerabilityNotificationPageNumber, error) { - defer observeQueryTime("GetNotification", "all", time.Now()) + defer database.ObserveQueryTime("GetNotification", "all", time.Now()) // Get Notification. notification, err := pgSQL.scanNotification(pgSQL.QueryRow(searchNotification, name), true) @@ -145,8 +145,8 @@ func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vu return -1, nil } - // We do `defer observeQueryTime` here because we don't want to observe invalid calls. - defer observeQueryTime("loadLayerIntroducingVulnerability", "all", tf) + // We do `defer database.ObserveQueryTime` here because we don't want to observe invalid calls. + defer database.ObserveQueryTime("loadLayerIntroducingVulnerability", "all", tf) // Query with limit + 1, the last item will be used to know the next starting ID. rows, err := pgSQL.Query(searchNotificationLayerIntroducingVulnerability, @@ -185,7 +185,7 @@ func (pgSQL *pgSQL) loadLayerIntroducingVulnerability(vulnerability *database.Vu } func (pgSQL *pgSQL) SetNotificationNotified(name string) error { - defer observeQueryTime("SetNotificationNotified", "all", time.Now()) + defer database.ObserveQueryTime("SetNotificationNotified", "all", time.Now()) if _, err := pgSQL.Exec(updatedNotificationNotified, name); err != nil { return handleError("updatedNotificationNotified", err) @@ -194,7 +194,7 @@ func (pgSQL *pgSQL) SetNotificationNotified(name string) error { } func (pgSQL *pgSQL) DeleteNotification(name string) error { - defer observeQueryTime("DeleteNotification", "all", time.Now()) + defer database.ObserveQueryTime("DeleteNotification", "all", time.Now()) result, err := pgSQL.Exec(removeNotification, name) if err != nil { diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index 15789dd9..ad219951 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -23,57 +23,21 @@ import ( "path" "runtime" "strings" - "time" "bitbucket.org/liamstask/goose/lib/goose" "github.com/coreos/clair/config" "github.com/coreos/clair/database" - "github.com/coreos/clair/utils" cerrors "github.com/coreos/clair/utils/errors" "github.com/coreos/pkg/capnslog" "github.com/hashicorp/golang-lru" "github.com/lib/pq" "github.com/pborman/uuid" - "github.com/prometheus/client_golang/prometheus" ) var ( log = capnslog.NewPackageLogger("github.com/coreos/clair", "pgsql") - - promErrorsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "clair_pgsql_errors_total", - Help: "Number of errors that PostgreSQL requests generated.", - }, []string{"request"}) - - promCacheHitsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "clair_pgsql_cache_hits_total", - Help: "Number of cache hits that the PostgreSQL backend did.", - }, []string{"object"}) - - promCacheQueriesTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ - Name: "clair_pgsql_cache_queries_total", - Help: "Number of cache queries that the PostgreSQL backend did.", - }, []string{"object"}) - - promQueryDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{ - Name: "clair_pgsql_query_duration_milliseconds", - Help: "Time it takes to execute the database query.", - }, []string{"query", "subquery"}) - - promConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{ - Name: "clair_pgsql_concurrent_lock_vafv_total", - Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_FeatureVersion lock.", - }) ) -func init() { - prometheus.MustRegister(promErrorsTotal) - prometheus.MustRegister(promCacheHitsTotal) - prometheus.MustRegister(promCacheQueriesTotal) - prometheus.MustRegister(promQueryDurationMilliseconds) - prometheus.MustRegister(promConcurrentLockVAFV) -} - type Queryer interface { Query(query string, args ...interface{}) (*sql.Rows, error) QueryRow(query string, args ...interface{}) *sql.Row @@ -210,7 +174,7 @@ type pgSQLTest struct { // Using Close() will drop the database. func OpenForTest(name string, withTestData bool) (*pgSQLTest, error) { // Define the PostgreSQL connection strings. - dataSource := "host=127.0.0.1 sslmode=disable user=postgres dbname=" + dataSource := "host=127.0.0.1 sslmode=disable user=postgres password=huawei dbname=" if dataSourceEnv := os.Getenv("CLAIR_TEST_PGSQL"); dataSourceEnv != "" { dataSource = dataSourceEnv + " dbname=" } @@ -267,7 +231,7 @@ func handleError(desc string, err error) error { } log.Errorf("%s: %v", desc, err) - promErrorsTotal.WithLabelValues(desc).Inc() + database.PromErrorsTotal.WithLabelValues(desc).Inc() if _, o := err.(*pq.Error); o || err == sql.ErrTxDone || strings.HasPrefix(err.Error(), "sql:") { return database.ErrBackendException @@ -281,7 +245,3 @@ func isErrUniqueViolation(err error) bool { pqErr, ok := err.(*pq.Error) return ok && pqErr.Code == "23505" } - -func observeQueryTime(query, subquery string, start time.Time) { - utils.PrometheusObserveTimeMilliseconds(promQueryDurationMilliseconds.WithLabelValues(query, subquery), start) -} diff --git a/database/pgsql/vulnerability.go b/database/pgsql/vulnerability.go index 74ee9828..848be274 100644 --- a/database/pgsql/vulnerability.go +++ b/database/pgsql/vulnerability.go @@ -29,7 +29,7 @@ import ( ) func (pgSQL *pgSQL) ListVulnerabilities(namespaceName string, limit int, startID int) ([]database.Vulnerability, int, error) { - defer observeQueryTime("listVulnerabilities", "all", time.Now()) + defer database.ObserveQueryTime("listVulnerabilities", "all", time.Now()) // Query Namespace. var id int @@ -88,7 +88,7 @@ func (pgSQL *pgSQL) FindVulnerability(namespaceName, name string) (database.Vuln } func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bool) (database.Vulnerability, error) { - defer observeQueryTime("findVulnerability", "all", time.Now()) + defer database.ObserveQueryTime("findVulnerability", "all", time.Now()) queryName := "searchVulnerabilityBase+searchVulnerabilityByNamespaceAndName" query := searchVulnerabilityBase + searchVulnerabilityByNamespaceAndName @@ -101,7 +101,7 @@ func findVulnerability(queryer Queryer, namespaceName, name string, forUpdate bo } func (pgSQL *pgSQL) findVulnerabilityByIDWithDeleted(id int) (database.Vulnerability, error) { - defer observeQueryTime("findVulnerabilityByIDWithDeleted", "all", time.Now()) + defer database.ObserveQueryTime("findVulnerabilityByIDWithDeleted", "all", time.Now()) queryName := "searchVulnerabilityBase+searchVulnerabilityByID" query := searchVulnerabilityBase + searchVulnerabilityByID @@ -215,8 +215,8 @@ func (pgSQL *pgSQL) insertVulnerability(vulnerability database.Vulnerability, on } } - // We do `defer observeQueryTime` here because we don't want to observe invalid vulnerabilities. - defer observeQueryTime("insertVulnerability", "all", tf) + // We do `defer database.ObserveQueryTime` here because we don't want to observe invalid vulnerabilities. + defer database.ObserveQueryTime("insertVulnerability", "all", tf) // Begin transaction. tx, err := pgSQL.Begin() @@ -401,7 +401,7 @@ func createFeatureVersionNameMap(features []database.FeatureVersion) (map[string // linkVulnerabilityToFeatureVersions to propagate the changes on Vulnerability_FixedIn_Feature to // Vulnerability_Affects_FeatureVersion. func (pgSQL *pgSQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulnerabilityID int, fixedIn []database.FeatureVersion) error { - defer observeQueryTime("insertVulnerabilityFixedInFeatureVersions", "all", time.Now()) + defer database.ObserveQueryTime("insertVulnerabilityFixedInFeatureVersions", "all", time.Now()) // Insert or find the Features. // TODO(Quentin-M): Batch me. @@ -420,11 +420,11 @@ func (pgSQL *pgSQL) insertVulnerabilityFixedInFeatureVersions(tx *sql.Tx, vulner // Lock Vulnerability_Affects_FeatureVersion exclusively. // We want to prevent InsertFeatureVersion to modify it. - promConcurrentLockVAFV.Inc() - defer promConcurrentLockVAFV.Dec() + database.PromConcurrentLockVAFV.Inc() + defer database.PromConcurrentLockVAFV.Dec() t := time.Now() _, err = tx.Exec(lockVulnerabilityAffects) - observeQueryTime("insertVulnerability", "lock", t) + database.ObserveQueryTime("insertVulnerability", "lock", t) if err != nil { tx.Rollback() @@ -498,7 +498,7 @@ func linkVulnerabilityToFeatureVersions(tx *sql.Tx, fixedInID, vulnerabilityID, } func (pgSQL *pgSQL) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []database.FeatureVersion) error { - defer observeQueryTime("InsertVulnerabilityFixes", "all", time.Now()) + defer database.ObserveQueryTime("InsertVulnerabilityFixes", "all", time.Now()) v := database.Vulnerability{ Name: vulnerabilityName, @@ -512,7 +512,7 @@ func (pgSQL *pgSQL) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabili } func (pgSQL *pgSQL) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error { - defer observeQueryTime("DeleteVulnerabilityFix", "all", time.Now()) + defer database.ObserveQueryTime("DeleteVulnerabilityFix", "all", time.Now()) v := database.Vulnerability{ Name: vulnerabilityName, @@ -536,7 +536,7 @@ func (pgSQL *pgSQL) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerability } func (pgSQL *pgSQL) DeleteVulnerability(namespaceName, name string) error { - defer observeQueryTime("DeleteVulnerability", "all", time.Now()) + defer database.ObserveQueryTime("DeleteVulnerability", "all", time.Now()) // Begin transaction. tx, err := pgSQL.Begin() diff --git a/database/prometheus.go b/database/prometheus.go new file mode 100644 index 00000000..7750b807 --- /dev/null +++ b/database/prometheus.go @@ -0,0 +1,47 @@ +package database + +import ( + "time" + + "github.com/coreos/clair/utils" + "github.com/prometheus/client_golang/prometheus" +) + +var ( + PromErrorsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "clair_sql_errors_total", + Help: "Number of errors that SQL requests generated.", + }, []string{"request"}) + + PromCacheHitsTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "clair_sql_cache_hits_total", + Help: "Number of cache hits that the SQL backend did.", + }, []string{"object"}) + + PromCacheQueriesTotal = prometheus.NewCounterVec(prometheus.CounterOpts{ + Name: "clair_sql_cache_queries_total", + Help: "Number of cache queries that the SQL backend did.", + }, []string{"object"}) + + PromQueryDurationMilliseconds = prometheus.NewHistogramVec(prometheus.HistogramOpts{ + Name: "clair_sql_query_duration_milliseconds", + Help: "Time it takes to execute the database query.", + }, []string{"query", "subquery"}) + + PromConcurrentLockVAFV = prometheus.NewGauge(prometheus.GaugeOpts{ + Name: "clair_sql_concurrent_lock_vafv_total", + Help: "Number of transactions trying to hold the exclusive Vulnerability_Affects_FeatureVersion lock.", + }) +) + +func init() { + prometheus.MustRegister(PromErrorsTotal) + prometheus.MustRegister(PromCacheHitsTotal) + prometheus.MustRegister(PromCacheQueriesTotal) + prometheus.MustRegister(PromQueryDurationMilliseconds) + prometheus.MustRegister(PromConcurrentLockVAFV) +} + +func ObserveQueryTime(query, subquery string, start time.Time) { + utils.PrometheusObserveTimeMilliseconds(PromQueryDurationMilliseconds.WithLabelValues(query, subquery), start) +}