diff --git a/cmd/clair/config.go b/cmd/clair/config.go index cccbebfb..e1628f05 100644 --- a/cmd/clair/config.go +++ b/cmd/clair/config.go @@ -20,13 +20,17 @@ import ( "os" "time" + "github.com/fernet/fernet-go" + log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" "github.com/coreos/clair" "github.com/coreos/clair/api" "github.com/coreos/clair/database" + "github.com/coreos/clair/ext/featurefmt" + "github.com/coreos/clair/ext/featurens" "github.com/coreos/clair/ext/notification" - "github.com/fernet/fernet-go" + "github.com/coreos/clair/ext/vulnsrc" ) // ErrDatasourceNotLoaded is returned when the datasource variable in the @@ -43,6 +47,7 @@ type File struct { type Config struct { Database database.RegistrableComponentConfig Updater *clair.UpdaterConfig + Worker *clair.WorkerConfig Notifier *notification.Config API *api.Config } @@ -54,12 +59,16 @@ func DefaultConfig() Config { Type: "pgsql", }, Updater: &clair.UpdaterConfig{ - Interval: 1 * time.Hour, + EnabledUpdaters: vulnsrc.ListUpdaters(), + Interval: 1 * time.Hour, + }, + Worker: &clair.WorkerConfig{ + EnabledDetectors: featurens.ListDetectors(), + EnabledListers: featurefmt.ListListers(), }, API: &api.Config{ - Port: 6060, HealthPort: 6061, - GrpcPort: 6070, + GrpcPort: 6060, Timeout: 900 * time.Second, }, Notifier: ¬ification.Config{ @@ -97,14 +106,15 @@ func LoadConfig(path string) (config *Config, err error) { config = &cfgFile.Clair // Generate a pagination key if none is provided. 
- if config.API.PaginationKey == "" { + if v, ok := config.Database.Options["paginationkey"]; !ok || v == nil || v.(string) == "" { + log.Warn("pagination key is empty, generating...") var key fernet.Key if err = key.Generate(); err != nil { return } - config.API.PaginationKey = key.Encode() + config.Database.Options["paginationkey"] = key.Encode() } else { - _, err = fernet.DecodeKey(config.API.PaginationKey) + _, err = fernet.DecodeKey(config.Database.Options["paginationkey"].(string)) if err != nil { err = errors.New("Invalid Pagination key; must be 32-bit URL-safe base64") return diff --git a/cmd/clair/main.go b/cmd/clair/main.go index 0408a732..fbf5d256 100644 --- a/cmd/clair/main.go +++ b/cmd/clair/main.go @@ -30,9 +30,13 @@ import ( "github.com/coreos/clair" "github.com/coreos/clair/api" "github.com/coreos/clair/database" + "github.com/coreos/clair/ext/featurefmt" + "github.com/coreos/clair/ext/featurens" "github.com/coreos/clair/ext/imagefmt" + "github.com/coreos/clair/ext/vulnsrc" "github.com/coreos/clair/pkg/formatter" "github.com/coreos/clair/pkg/stopper" + "github.com/coreos/clair/pkg/strutil" // Register database driver. 
_ "github.com/coreos/clair/database/pgsql" @@ -85,6 +89,43 @@ func stopCPUProfiling(f *os.File) { log.Info("stopped CPU profiling") } +func configClairVersion(config *Config) { + listers := featurefmt.ListListers() + detectors := featurens.ListDetectors() + updaters := vulnsrc.ListUpdaters() + + log.WithFields(log.Fields{ + "Listers": strings.Join(listers, ","), + "Detectors": strings.Join(detectors, ","), + "Updaters": strings.Join(updaters, ","), + }).Info("Clair registered components") + + unregDetectors := strutil.CompareStringLists(config.Worker.EnabledDetectors, detectors) + unregListers := strutil.CompareStringLists(config.Worker.EnabledListers, listers) + unregUpdaters := strutil.CompareStringLists(config.Updater.EnabledUpdaters, updaters) + if len(unregDetectors) != 0 || len(unregListers) != 0 || len(unregUpdaters) != 0 { + log.WithFields(log.Fields{ + "Unknown Detectors": strings.Join(unregDetectors, ","), + "Unknown Listers": strings.Join(unregListers, ","), + "Unknown Updaters": strings.Join(unregUpdaters, ","), + "Available Listers": strings.Join(featurefmt.ListListers(), ","), + "Available Detectors": strings.Join(featurens.ListDetectors(), ","), + "Available Updaters": strings.Join(vulnsrc.ListUpdaters(), ","), + }).Fatal("Unknown or unregistered components are configured") + } + + // All configured extensions are registered at this point (unknown ones + // were rejected fatally above); enable exactly the intersection of the + // configured and registered detectors, listers, and updaters. + + clair.Processors = database.Processors{ + Detectors: strutil.CompareStringListsInBoth(config.Worker.EnabledDetectors, detectors), + Listers: strutil.CompareStringListsInBoth(config.Worker.EnabledListers, listers), + } + + clair.EnabledUpdaters = strutil.CompareStringListsInBoth(config.Updater.EnabledUpdaters, updaters) +} + // Boot starts Clair instance with the provided config. 
func Boot(config *Config) { rand.Seed(time.Now().UnixNano()) @@ -102,9 +143,8 @@ func Boot(config *Config) { go clair.RunNotifier(config.Notifier, db, st) // Start API - st.Begin() - go api.Run(config.API, db, st) go api.RunV2(config.API, db) + st.Begin() go api.RunHealth(config.API, db, st) @@ -135,19 +175,17 @@ func main() { } } - // Load configuration - config, err := LoadConfig(*flagConfigPath) - if err != nil { - log.WithError(err).Fatal("failed to load configuration") - } - // Initialize logging system - logLevel, err := log.ParseLevel(strings.ToUpper(*flagLogLevel)) log.SetLevel(logLevel) log.SetOutput(os.Stdout) log.SetFormatter(&formatter.JSONExtendedFormatter{ShowLn: true}) + config, err := LoadConfig(*flagConfigPath) + if err != nil { + log.WithError(err).Fatal("failed to load configuration") + } + // Enable CPU Profiling if specified if *flagCPUProfilePath != "" { defer stopCPUProfiling(startCPUProfiling(*flagCPUProfilePath)) @@ -159,5 +197,8 @@ func main() { imagefmt.SetInsecureTLS(*flagInsecureTLS) } + // configure updater and worker + configClairVersion(config) + Boot(config) } diff --git a/config.example.yaml b/config.example.yaml index ab47886c..c45833c3 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -25,11 +25,15 @@ clair: # Number of elements kept in the cache # Values unlikely to change (e.g. namespaces) are cached in order to save prevent needless roundtrips to the database. cachesize: 16384 + # 32-byte URL-safe base64 key used to encrypt pagination tokens + # If one is not provided, it will be generated. + # Multiple clair instances in the same cluster need the same value. + paginationkey: api: - # API server port - port: 6060 - grpcPort: 6070 + # v2 grpc/RESTful API server port + grpcport: 6060 + # Health server port # This is an unencrypted endpoint useful for load balancers to check to healthiness of the clair server. 
healthport: 6061 @@ -37,11 +41,6 @@ clair: # Deadline before an API request will respond with a 503 timeout: 900s - # 32-bit URL-safe base64 key used to encrypt pagination tokens - # If one is not provided, it will be generated. - # Multiple clair instances in the same cluster need the same value. - paginationkey: - # Optional PKI configuration # If you want to easily generate client certificates and CAs, try the following projects: # https://github.com/coreos/etcd-ca @@ -51,10 +50,29 @@ clair: keyfile: certfile: + worker: + namespace_detectors: + - os-release + - lsb-release + - apt-sources + - alpine-release + - redhat-release + + feature_listers: + - apk + - dpkg + - rpm + updater: # Frequency the database will be updated with vulnerabilities from the default data sources # The value 0 disables the updater entirely. interval: 2h + enabledupdaters: + - debian + - ubuntu + - rhel + - oracle + - alpine notifier: # Number of attempts before the notification is marked as failed to be sent @@ -72,9 +90,9 @@ clair: # https://github.com/cloudflare/cfssl # https://github.com/coreos/etcd-ca servername: - cafile: - keyfile: - certfile: + cafile: + keyfile: + certfile: # Optional HTTP Proxy: must be a valid URL (including the scheme). proxy: diff --git a/database/mock.go b/database/mock.go index 9a0963c8..966e9c88 100644 --- a/database/mock.go +++ b/database/mock.go @@ -16,161 +16,240 @@ package database import "time" -// MockDatastore implements Datastore and enables overriding each available method. +// MockSession implements Session and enables overriding each available method. // The default behavior of each method is to simply panic. 
-type MockDatastore struct { - FctListNamespaces func() ([]Namespace, error) - FctInsertLayer func(Layer) error - FctFindLayer func(name string, withFeatures, withVulnerabilities bool) (Layer, error) - FctDeleteLayer func(name string) error - FctListVulnerabilities func(namespaceName string, limit int, page int) ([]Vulnerability, int, error) - FctInsertVulnerabilities func(vulnerabilities []Vulnerability, createNotification bool) error - FctFindVulnerability func(namespaceName, name string) (Vulnerability, error) - FctDeleteVulnerability func(namespaceName, name string) error - FctInsertVulnerabilityFixes func(vulnerabilityNamespace, vulnerabilityName string, fixes []FeatureVersion) error - FctDeleteVulnerabilityFix func(vulnerabilityNamespace, vulnerabilityName, featureName string) error - FctGetAvailableNotification func(renotifyInterval time.Duration) (VulnerabilityNotification, error) - FctGetNotification func(name string, limit int, page VulnerabilityNotificationPageNumber) (VulnerabilityNotification, VulnerabilityNotificationPageNumber, error) - FctSetNotificationNotified func(name string) error +type MockSession struct { + FctCommit func() error + FctRollback func() error + FctUpsertAncestry func(Ancestry, []NamespacedFeature, Processors) error + FctFindAncestry func(name string) (Ancestry, Processors, bool, error) + FctFindAncestryFeatures func(name string) (AncestryWithFeatures, bool, error) + FctFindAffectedNamespacedFeatures func(features []NamespacedFeature) ([]NullableAffectedNamespacedFeature, error) + FctPersistNamespaces func([]Namespace) error + FctPersistFeatures func([]Feature) error + FctPersistNamespacedFeatures func([]NamespacedFeature) error + FctCacheAffectedNamespacedFeatures func([]NamespacedFeature) error + FctPersistLayer func(Layer) error + FctPersistLayerContent func(hash string, namespaces []Namespace, features []Feature, processedBy Processors) error + FctFindLayer func(name string) (Layer, Processors, bool, error) + 
FctFindLayerWithContent func(name string) (LayerWithContent, bool, error) + FctInsertVulnerabilities func([]VulnerabilityWithAffected) error + FctFindVulnerabilities func([]VulnerabilityID) ([]NullableVulnerability, error) + FctDeleteVulnerabilities func([]VulnerabilityID) error + FctInsertVulnerabilityNotifications func([]VulnerabilityNotification) error + FctFindNewNotification func(lastNotified time.Time) (NotificationHook, bool, error) + FctFindVulnerabilityNotification func(name string, limit int, oldPage PageNumber, newPage PageNumber) ( + vuln VulnerabilityNotificationWithVulnerable, ok bool, err error) + FctMarkNotificationNotified func(name string) error FctDeleteNotification func(name string) error - FctInsertKeyValue func(key, value string) error - FctGetKeyValue func(key string) (string, error) - FctLock func(name string, owner string, duration time.Duration, renew bool) (bool, time.Time) - FctUnlock func(name, owner string) - FctFindLock func(name string) (string, time.Time, error) - FctPing func() bool - FctClose func() + FctUpdateKeyValue func(key, value string) error + FctFindKeyValue func(key string) (string, bool, error) + FctLock func(name string, owner string, duration time.Duration, renew bool) (bool, time.Time, error) + FctUnlock func(name, owner string) error + FctFindLock func(name string) (string, time.Time, bool, error) } -func (mds *MockDatastore) ListNamespaces() ([]Namespace, error) { - if mds.FctListNamespaces != nil { - return mds.FctListNamespaces() +func (ms *MockSession) Commit() error { + if ms.FctCommit != nil { + return ms.FctCommit() } panic("required mock function not implemented") } -func (mds *MockDatastore) InsertLayer(layer Layer) error { - if mds.FctInsertLayer != nil { - return mds.FctInsertLayer(layer) +func (ms *MockSession) Rollback() error { + if ms.FctRollback != nil { + return ms.FctRollback() } panic("required mock function not implemented") } -func (mds *MockDatastore) FindLayer(name string, withFeatures, 
withVulnerabilities bool) (Layer, error) { - if mds.FctFindLayer != nil { - return mds.FctFindLayer(name, withFeatures, withVulnerabilities) +func (ms *MockSession) UpsertAncestry(ancestry Ancestry, features []NamespacedFeature, processedBy Processors) error { + if ms.FctUpsertAncestry != nil { + return ms.FctUpsertAncestry(ancestry, features, processedBy) } panic("required mock function not implemented") } -func (mds *MockDatastore) DeleteLayer(name string) error { - if mds.FctDeleteLayer != nil { - return mds.FctDeleteLayer(name) +func (ms *MockSession) FindAncestry(name string) (Ancestry, Processors, bool, error) { + if ms.FctFindAncestry != nil { + return ms.FctFindAncestry(name) } panic("required mock function not implemented") } -func (mds *MockDatastore) ListVulnerabilities(namespaceName string, limit int, page int) ([]Vulnerability, int, error) { - if mds.FctListVulnerabilities != nil { - return mds.FctListVulnerabilities(namespaceName, limit, page) +func (ms *MockSession) FindAncestryFeatures(name string) (AncestryWithFeatures, bool, error) { + if ms.FctFindAncestryFeatures != nil { + return ms.FctFindAncestryFeatures(name) } panic("required mock function not implemented") } -func (mds *MockDatastore) InsertVulnerabilities(vulnerabilities []Vulnerability, createNotification bool) error { - if mds.FctInsertVulnerabilities != nil { - return mds.FctInsertVulnerabilities(vulnerabilities, createNotification) +func (ms *MockSession) FindAffectedNamespacedFeatures(features []NamespacedFeature) ([]NullableAffectedNamespacedFeature, error) { + if ms.FctFindAffectedNamespacedFeatures != nil { + return ms.FctFindAffectedNamespacedFeatures(features) } panic("required mock function not implemented") } -func (mds *MockDatastore) FindVulnerability(namespaceName, name string) (Vulnerability, error) { - if mds.FctFindVulnerability != nil { - return mds.FctFindVulnerability(namespaceName, name) +func (ms *MockSession) PersistNamespaces(namespaces []Namespace) error { + if 
ms.FctPersistNamespaces != nil { + return ms.FctPersistNamespaces(namespaces) } panic("required mock function not implemented") } -func (mds *MockDatastore) DeleteVulnerability(namespaceName, name string) error { - if mds.FctDeleteVulnerability != nil { - return mds.FctDeleteVulnerability(namespaceName, name) +func (ms *MockSession) PersistFeatures(features []Feature) error { + if ms.FctPersistFeatures != nil { + return ms.FctPersistFeatures(features) } panic("required mock function not implemented") } -func (mds *MockDatastore) InsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName string, fixes []FeatureVersion) error { - if mds.FctInsertVulnerabilityFixes != nil { - return mds.FctInsertVulnerabilityFixes(vulnerabilityNamespace, vulnerabilityName, fixes) +func (ms *MockSession) PersistNamespacedFeatures(namespacedFeatures []NamespacedFeature) error { + if ms.FctPersistNamespacedFeatures != nil { + return ms.FctPersistNamespacedFeatures(namespacedFeatures) } panic("required mock function not implemented") } -func (mds *MockDatastore) DeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName string) error { - if mds.FctDeleteVulnerabilityFix != nil { - return mds.FctDeleteVulnerabilityFix(vulnerabilityNamespace, vulnerabilityName, featureName) +func (ms *MockSession) CacheAffectedNamespacedFeatures(namespacedFeatures []NamespacedFeature) error { + if ms.FctCacheAffectedNamespacedFeatures != nil { + return ms.FctCacheAffectedNamespacedFeatures(namespacedFeatures) } panic("required mock function not implemented") } -func (mds *MockDatastore) GetAvailableNotification(renotifyInterval time.Duration) (VulnerabilityNotification, error) { - if mds.FctGetAvailableNotification != nil { - return mds.FctGetAvailableNotification(renotifyInterval) +func (ms *MockSession) PersistLayer(layer Layer) error { + if ms.FctPersistLayer != nil { + return ms.FctPersistLayer(layer) } panic("required mock function not implemented") } -func (mds 
*MockDatastore) GetNotification(name string, limit int, page VulnerabilityNotificationPageNumber) (VulnerabilityNotification, VulnerabilityNotificationPageNumber, error) { - if mds.FctGetNotification != nil { - return mds.FctGetNotification(name, limit, page) +func (ms *MockSession) PersistLayerContent(hash string, namespaces []Namespace, features []Feature, processedBy Processors) error { + if ms.FctPersistLayerContent != nil { + return ms.FctPersistLayerContent(hash, namespaces, features, processedBy) } panic("required mock function not implemented") } -func (mds *MockDatastore) SetNotificationNotified(name string) error { - if mds.FctSetNotificationNotified != nil { - return mds.FctSetNotificationNotified(name) +func (ms *MockSession) FindLayer(name string) (Layer, Processors, bool, error) { + if ms.FctFindLayer != nil { + return ms.FctFindLayer(name) } panic("required mock function not implemented") } -func (mds *MockDatastore) DeleteNotification(name string) error { - if mds.FctDeleteNotification != nil { - return mds.FctDeleteNotification(name) +func (ms *MockSession) FindLayerWithContent(name string) (LayerWithContent, bool, error) { + if ms.FctFindLayerWithContent != nil { + return ms.FctFindLayerWithContent(name) } panic("required mock function not implemented") } -func (mds *MockDatastore) InsertKeyValue(key, value string) error { - if mds.FctInsertKeyValue != nil { - return mds.FctInsertKeyValue(key, value) + +func (ms *MockSession) InsertVulnerabilities(vulnerabilities []VulnerabilityWithAffected) error { + if ms.FctInsertVulnerabilities != nil { + return ms.FctInsertVulnerabilities(vulnerabilities) } panic("required mock function not implemented") } -func (mds *MockDatastore) GetKeyValue(key string) (string, error) { - if mds.FctGetKeyValue != nil { - return mds.FctGetKeyValue(key) +func (ms *MockSession) FindVulnerabilities(vulnerabilityIDs []VulnerabilityID) ([]NullableVulnerability, error) { + if ms.FctFindVulnerabilities != nil { + return 
ms.FctFindVulnerabilities(vulnerabilityIDs) } panic("required mock function not implemented") } -func (mds *MockDatastore) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time) { - if mds.FctLock != nil { - return mds.FctLock(name, owner, duration, renew) +func (ms *MockSession) DeleteVulnerabilities(VulnerabilityIDs []VulnerabilityID) error { + if ms.FctDeleteVulnerabilities != nil { + return ms.FctDeleteVulnerabilities(VulnerabilityIDs) } panic("required mock function not implemented") } -func (mds *MockDatastore) Unlock(name, owner string) { - if mds.FctUnlock != nil { - mds.FctUnlock(name, owner) - return +func (ms *MockSession) InsertVulnerabilityNotifications(vulnerabilityNotifications []VulnerabilityNotification) error { + if ms.FctInsertVulnerabilityNotifications != nil { + return ms.FctInsertVulnerabilityNotifications(vulnerabilityNotifications) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindNewNotification(lastNotified time.Time) (NotificationHook, bool, error) { + if ms.FctFindNewNotification != nil { + return ms.FctFindNewNotification(lastNotified) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage PageNumber, newPage PageNumber) ( + VulnerabilityNotificationWithVulnerable, bool, error) { + if ms.FctFindVulnerabilityNotification != nil { + return ms.FctFindVulnerabilityNotification(name, limit, oldPage, newPage) } panic("required mock function not implemented") } -func (mds *MockDatastore) FindLock(name string) (string, time.Time, error) { - if mds.FctFindLock != nil { - return mds.FctFindLock(name) +func (ms *MockSession) MarkNotificationNotified(name string) error { + if ms.FctMarkNotificationNotified != nil { + return ms.FctMarkNotificationNotified(name) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) DeleteNotification(name string) error { + if 
ms.FctDeleteNotification != nil { + return ms.FctDeleteNotification(name) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) UpdateKeyValue(key, value string) error { + if ms.FctUpdateKeyValue != nil { + return ms.FctUpdateKeyValue(key, value) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindKeyValue(key string) (string, bool, error) { + if ms.FctFindKeyValue != nil { + return ms.FctFindKeyValue(key) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) Lock(name string, owner string, duration time.Duration, renew bool) (bool, time.Time, error) { + if ms.FctLock != nil { + return ms.FctLock(name, owner, duration, renew) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) Unlock(name, owner string) error { + if ms.FctUnlock != nil { + return ms.FctUnlock(name, owner) + } + panic("required mock function not implemented") +} + +func (ms *MockSession) FindLock(name string) (string, time.Time, bool, error) { + if ms.FctFindLock != nil { + return ms.FctFindLock(name) + } + panic("required mock function not implemented") +} + +// MockDatastore implements Datastore and enables overriding each available method. +// The default behavior of each method is to simply panic. 
+type MockDatastore struct { + FctBegin func() (Session, error) + FctPing func() bool + FctClose func() +} + +func (mds *MockDatastore) Begin() (Session, error) { + if mds.FctBegin != nil { + return mds.FctBegin() } panic("required mock function not implemented") } diff --git a/ext/featurefmt/apk/apk.go b/ext/featurefmt/apk/apk.go index ff63880d..195c8920 100644 --- a/ext/featurefmt/apk/apk.go +++ b/ext/featurefmt/apk/apk.go @@ -34,17 +34,17 @@ func init() { type lister struct{} -func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, error) { +func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.Feature, error) { file, exists := files["lib/apk/db/installed"] if !exists { - return []database.FeatureVersion{}, nil + return []database.Feature{}, nil } // Iterate over each line in the "installed" file attempting to parse each // package into a feature that will be stored in a set to guarantee // uniqueness. - pkgSet := make(map[string]database.FeatureVersion) - ipkg := database.FeatureVersion{} + pkgSet := make(map[string]database.Feature) + ipkg := database.Feature{} scanner := bufio.NewScanner(bytes.NewBuffer(file)) for scanner.Scan() { line := scanner.Text() @@ -55,7 +55,7 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, // Parse the package name or version. switch { case line[:2] == "P:": - ipkg.Feature.Name = line[2:] + ipkg.Name = line[2:] case line[:2] == "V:": version := string(line[2:]) err := versionfmt.Valid(dpkg.ParserName, version) @@ -67,20 +67,21 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, case line == "": // Restart if the parser reaches another package definition before // creating a valid package. - ipkg = database.FeatureVersion{} + ipkg = database.Feature{} } // If we have a whole feature, store it in the set and try to parse a new // one. 
- if ipkg.Feature.Name != "" && ipkg.Version != "" { - pkgSet[ipkg.Feature.Name+"#"+ipkg.Version] = ipkg - ipkg = database.FeatureVersion{} + if ipkg.Name != "" && ipkg.Version != "" { + pkgSet[ipkg.Name+"#"+ipkg.Version] = ipkg + ipkg = database.Feature{} } } - // Convert the map into a slice. - pkgs := make([]database.FeatureVersion, 0, len(pkgSet)) + // Convert the map into a slice and attach the version format + pkgs := make([]database.Feature, 0, len(pkgSet)) for _, pkg := range pkgSet { + pkg.VersionFormat = dpkg.ParserName pkgs = append(pkgs, pkg) } diff --git a/ext/featurefmt/apk/apk_test.go b/ext/featurefmt/apk/apk_test.go index 6dbde3e6..d8dc0d88 100644 --- a/ext/featurefmt/apk/apk_test.go +++ b/ext/featurefmt/apk/apk_test.go @@ -19,58 +19,32 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/ext/featurefmt" + "github.com/coreos/clair/ext/versionfmt/dpkg" "github.com/coreos/clair/pkg/tarutil" ) func TestAPKFeatureDetection(t *testing.T) { + testFeatures := []database.Feature{ + {Name: "musl", Version: "1.1.14-r10"}, + {Name: "busybox", Version: "1.24.2-r9"}, + {Name: "alpine-baselayout", Version: "3.0.3-r0"}, + {Name: "alpine-keys", Version: "1.1-r0"}, + {Name: "zlib", Version: "1.2.8-r2"}, + {Name: "libcrypto1.0", Version: "1.0.2h-r1"}, + {Name: "libssl1.0", Version: "1.0.2h-r1"}, + {Name: "apk-tools", Version: "2.6.7-r0"}, + {Name: "scanelf", Version: "1.1.6-r0"}, + {Name: "musl-utils", Version: "1.1.14-r10"}, + {Name: "libc-utils", Version: "0.7-r0"}, + } + + for i := range testFeatures { + testFeatures[i].VersionFormat = dpkg.ParserName + } + testData := []featurefmt.TestData{ { - FeatureVersions: []database.FeatureVersion{ - { - Feature: database.Feature{Name: "musl"}, - Version: "1.1.14-r10", - }, - { - Feature: database.Feature{Name: "busybox"}, - Version: "1.24.2-r9", - }, - { - Feature: database.Feature{Name: "alpine-baselayout"}, - Version: "3.0.3-r0", - }, - { - Feature: database.Feature{Name: "alpine-keys"}, - Version: 
"1.1-r0", - }, - { - Feature: database.Feature{Name: "zlib"}, - Version: "1.2.8-r2", - }, - { - Feature: database.Feature{Name: "libcrypto1.0"}, - Version: "1.0.2h-r1", - }, - { - Feature: database.Feature{Name: "libssl1.0"}, - Version: "1.0.2h-r1", - }, - { - Feature: database.Feature{Name: "apk-tools"}, - Version: "2.6.7-r0", - }, - { - Feature: database.Feature{Name: "scanelf"}, - Version: "1.1.6-r0", - }, - { - Feature: database.Feature{Name: "musl-utils"}, - Version: "1.1.14-r10", - }, - { - Feature: database.Feature{Name: "libc-utils"}, - Version: "0.7-r0", - }, - }, + Features: testFeatures, Files: tarutil.FilesMap{ "lib/apk/db/installed": featurefmt.LoadFileForTest("apk/testdata/installed"), }, diff --git a/ext/featurefmt/dpkg/dpkg.go b/ext/featurefmt/dpkg/dpkg.go index a0653580..6b987cf3 100644 --- a/ext/featurefmt/dpkg/dpkg.go +++ b/ext/featurefmt/dpkg/dpkg.go @@ -40,16 +40,16 @@ func init() { featurefmt.RegisterLister("dpkg", dpkg.ParserName, &lister{}) } -func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, error) { +func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.Feature, error) { f, hasFile := files["var/lib/dpkg/status"] if !hasFile { - return []database.FeatureVersion{}, nil + return []database.Feature{}, nil } // Create a map to store packages and ensure their uniqueness - packagesMap := make(map[string]database.FeatureVersion) + packagesMap := make(map[string]database.Feature) - var pkg database.FeatureVersion + var pkg database.Feature var err error scanner := bufio.NewScanner(strings.NewReader(string(f))) for scanner.Scan() { @@ -59,7 +59,7 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, // Package line // Defines the name of the package - pkg.Feature.Name = strings.TrimSpace(strings.TrimPrefix(line, "Package: ")) + pkg.Name = strings.TrimSpace(strings.TrimPrefix(line, "Package: ")) pkg.Version = "" } else if strings.HasPrefix(line, "Source: ") { // Source line 
(Optionnal) @@ -72,7 +72,7 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, md[dpkgSrcCaptureRegexpNames[i]] = strings.TrimSpace(n) } - pkg.Feature.Name = md["name"] + pkg.Name = md["name"] if md["version"] != "" { version := md["version"] err = versionfmt.Valid(dpkg.ParserName, version) @@ -96,21 +96,22 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, pkg.Version = version } } else if line == "" { - pkg.Feature.Name = "" + pkg.Name = "" pkg.Version = "" } // Add the package to the result array if we have all the informations - if pkg.Feature.Name != "" && pkg.Version != "" { - packagesMap[pkg.Feature.Name+"#"+pkg.Version] = pkg - pkg.Feature.Name = "" + if pkg.Name != "" && pkg.Version != "" { + packagesMap[pkg.Name+"#"+pkg.Version] = pkg + pkg.Name = "" pkg.Version = "" } } - // Convert the map to a slice - packages := make([]database.FeatureVersion, 0, len(packagesMap)) + // Convert the map to a slice and add version format. 
+ packages := make([]database.Feature, 0, len(packagesMap)) for _, pkg := range packagesMap { + pkg.VersionFormat = dpkg.ParserName packages = append(packages, pkg) } diff --git a/ext/featurefmt/dpkg/dpkg_test.go b/ext/featurefmt/dpkg/dpkg_test.go index a9c3a8cf..1561be4f 100644 --- a/ext/featurefmt/dpkg/dpkg_test.go +++ b/ext/featurefmt/dpkg/dpkg_test.go @@ -19,28 +19,35 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/ext/featurefmt" + "github.com/coreos/clair/ext/versionfmt/dpkg" "github.com/coreos/clair/pkg/tarutil" ) func TestDpkgFeatureDetection(t *testing.T) { + testFeatures := []database.Feature{ + // Two packages from this source are installed, it should only appear one time + { + Name: "pam", + Version: "1.1.8-3.1ubuntu3", + }, + { + Name: "makedev", // The source name and the package name are equals + Version: "2.3.1-93ubuntu1", // The version comes from the "Version:" line + }, + { + Name: "gcc-5", + Version: "5.1.1-12ubuntu1", // The version comes from the "Source:" line + }, + } + + for i := range testFeatures { + testFeatures[i].VersionFormat = dpkg.ParserName + } + testData := []featurefmt.TestData{ // Test an Ubuntu dpkg status file { - FeatureVersions: []database.FeatureVersion{ - // Two packages from this source are installed, it should only appear one time - { - Feature: database.Feature{Name: "pam"}, - Version: "1.1.8-3.1ubuntu3", - }, - { - Feature: database.Feature{Name: "makedev"}, // The source name and the package name are equals - Version: "2.3.1-93ubuntu1", // The version comes from the "Version:" line - }, - { - Feature: database.Feature{Name: "gcc-5"}, - Version: "5.1.1-12ubuntu1", // The version comes from the "Source:" line - }, - }, + Features: testFeatures, Files: tarutil.FilesMap{ "var/lib/dpkg/status": featurefmt.LoadFileForTest("dpkg/testdata/status"), }, diff --git a/ext/featurefmt/driver.go b/ext/featurefmt/driver.go index 8e8d593d..0f48b0e7 100644 --- a/ext/featurefmt/driver.go +++ 
b/ext/featurefmt/driver.go @@ -38,8 +38,8 @@ var ( // Lister represents an ability to list the features present in an image layer. type Lister interface { - // ListFeatures produces a list of FeatureVersions present in an image layer. - ListFeatures(tarutil.FilesMap) ([]database.FeatureVersion, error) + // ListFeatures produces a list of Features present in an image layer. + ListFeatures(tarutil.FilesMap) ([]database.Feature, error) // RequiredFilenames returns the list of files required to be in the FilesMap // provided to the ListFeatures method. @@ -71,34 +71,24 @@ func RegisterLister(name string, versionfmt string, l Lister) { versionfmtListerName[versionfmt] = append(versionfmtListerName[versionfmt], name) } -// ListFeatures produces the list of FeatureVersions in an image layer using +// ListFeatures produces the list of Features in an image layer using // every registered Lister. -func ListFeatures(files tarutil.FilesMap, namespace *database.Namespace) ([]database.FeatureVersion, error) { +func ListFeatures(files tarutil.FilesMap, listerNames []string) ([]database.Feature, error) { listersM.RLock() defer listersM.RUnlock() - var ( - totalFeatures []database.FeatureVersion - listersName []string - found bool - ) + var totalFeatures []database.Feature - if namespace == nil { - log.Debug("Can't detect features without namespace") - return totalFeatures, nil - } - - if listersName, found = versionfmtListerName[namespace.VersionFormat]; !found { - log.WithFields(log.Fields{"namespace": namespace.Name, "version format": namespace.VersionFormat}).Debug("Unsupported Namespace") - return totalFeatures, nil - } - - for _, listerName := range listersName { - features, err := listers[listerName].ListFeatures(files) - if err != nil { - return totalFeatures, err + for _, name := range listerNames { + if lister, ok := listers[name]; ok { + features, err := lister.ListFeatures(files) + if err != nil { + return []database.Feature{}, err + } + totalFeatures = 
append(totalFeatures, features...) + } else { + log.WithField("Name", name).Warn("Unknown Lister") } - totalFeatures = append(totalFeatures, features...) } return totalFeatures, nil @@ -106,7 +96,7 @@ func ListFeatures(files tarutil.FilesMap, namespace *database.Namespace) ([]data // RequiredFilenames returns the total list of files required for all // registered Listers. -func RequiredFilenames() (files []string) { +func RequiredFilenames(listerNames []string) (files []string) { listersM.RLock() defer listersM.RUnlock() @@ -117,10 +107,19 @@ func RequiredFilenames() (files []string) { return } +// ListListers returns the names of all the registered feature listers. +func ListListers() []string { + r := []string{} + for name := range listers { + r = append(r, name) + } + return r +} + // TestData represents the data used to test an implementation of Lister. type TestData struct { - Files tarutil.FilesMap - FeatureVersions []database.FeatureVersion + Files tarutil.FilesMap + Features []database.Feature } // LoadFileForTest can be used in order to obtain the []byte contents of a file @@ -136,9 +135,9 @@ func LoadFileForTest(name string) []byte { func TestLister(t *testing.T, l Lister, testData []TestData) { for _, td := range testData { featureVersions, err := l.ListFeatures(td.Files) - if assert.Nil(t, err) && assert.Len(t, featureVersions, len(td.FeatureVersions)) { - for _, expectedFeatureVersion := range td.FeatureVersions { - assert.Contains(t, featureVersions, expectedFeatureVersion) + if assert.Nil(t, err) && assert.Len(t, featureVersions, len(td.Features)) { + for _, expectedFeature := range td.Features { + assert.Contains(t, featureVersions, expectedFeature) } } } diff --git a/ext/featurefmt/rpm/rpm.go b/ext/featurefmt/rpm/rpm.go index 9e62f0fc..5a0e1fa1 100644 --- a/ext/featurefmt/rpm/rpm.go +++ b/ext/featurefmt/rpm/rpm.go @@ -38,27 +38,27 @@ func init() { featurefmt.RegisterLister("rpm", rpm.ParserName, &lister{}) } -func (l lister) ListFeatures(files 
tarutil.FilesMap) ([]database.FeatureVersion, error) { +func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.Feature, error) { f, hasFile := files["var/lib/rpm/Packages"] if !hasFile { - return []database.FeatureVersion{}, nil + return []database.Feature{}, nil } // Create a map to store packages and ensure their uniqueness - packagesMap := make(map[string]database.FeatureVersion) + packagesMap := make(map[string]database.Feature) // Write the required "Packages" file to disk tmpDir, err := ioutil.TempDir(os.TempDir(), "rpm") defer os.RemoveAll(tmpDir) if err != nil { log.WithError(err).Error("could not create temporary folder for RPM detection") - return []database.FeatureVersion{}, commonerr.ErrFilesystem + return []database.Feature{}, commonerr.ErrFilesystem } err = ioutil.WriteFile(tmpDir+"/Packages", f, 0700) if err != nil { log.WithError(err).Error("could not create temporary file for RPM detection") - return []database.FeatureVersion{}, commonerr.ErrFilesystem + return []database.Feature{}, commonerr.ErrFilesystem } // Extract binary package names because RHSA refers to binary package names. 
@@ -67,7 +67,7 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, log.WithError(err).WithField("output", string(out)).Error("could not query RPM") // Do not bubble up because we probably won't be able to fix it, // the database must be corrupted - return []database.FeatureVersion{}, nil + return []database.Feature{}, nil } scanner := bufio.NewScanner(strings.NewReader(string(out))) @@ -93,18 +93,17 @@ func (l lister) ListFeatures(files tarutil.FilesMap) ([]database.FeatureVersion, } // Add package - pkg := database.FeatureVersion{ - Feature: database.Feature{ - Name: line[0], - }, + pkg := database.Feature{ + Name: line[0], Version: version, } - packagesMap[pkg.Feature.Name+"#"+pkg.Version] = pkg + packagesMap[pkg.Name+"#"+pkg.Version] = pkg } // Convert the map to a slice - packages := make([]database.FeatureVersion, 0, len(packagesMap)) + packages := make([]database.Feature, 0, len(packagesMap)) for _, pkg := range packagesMap { + pkg.VersionFormat = rpm.ParserName packages = append(packages, pkg) } diff --git a/ext/featurefmt/rpm/rpm_test.go b/ext/featurefmt/rpm/rpm_test.go index 1b6f531c..0b674523 100644 --- a/ext/featurefmt/rpm/rpm_test.go +++ b/ext/featurefmt/rpm/rpm_test.go @@ -19,6 +19,7 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/ext/featurefmt" + "github.com/coreos/clair/ext/versionfmt/rpm" "github.com/coreos/clair/pkg/tarutil" ) @@ -27,16 +28,18 @@ func TestRpmFeatureDetection(t *testing.T) { // Test a CentOS 7 RPM database // Memo: Use the following command on a RPM-based system to shrink a database: rpm -qa --qf "%{NAME}\n" |tail -n +3| xargs rpm -e --justdb { - FeatureVersions: []database.FeatureVersion{ + Features: []database.Feature{ // Two packages from this source are installed, it should only appear once { - Feature: database.Feature{Name: "centos-release"}, - Version: "7-1.1503.el7.centos.2.8", + Name: "centos-release", + Version: "7-1.1503.el7.centos.2.8", + VersionFormat: 
rpm.ParserName, }, // Two packages from this source are installed, it should only appear once { - Feature: database.Feature{Name: "filesystem"}, - Version: "3.2-18.el7", + Name: "filesystem", + Version: "3.2-18.el7", + VersionFormat: rpm.ParserName, }, }, Files: tarutil.FilesMap{ diff --git a/ext/featurens/driver.go b/ext/featurens/driver.go index 754ed8c5..b7e0ad37 100644 --- a/ext/featurens/driver.go +++ b/ext/featurens/driver.go @@ -69,20 +69,24 @@ func RegisterDetector(name string, d Detector) { } // Detect iterators through all registered Detectors and returns all non-nil detected namespaces -func Detect(files tarutil.FilesMap) ([]database.Namespace, error) { +func Detect(files tarutil.FilesMap, detectorNames []string) ([]database.Namespace, error) { detectorsM.RLock() defer detectorsM.RUnlock() namespaces := map[string]*database.Namespace{} - for name, detector := range detectors { - namespace, err := detector.Detect(files) - if err != nil { - log.WithError(err).WithField("name", name).Warning("failed while attempting to detect namespace") - return []database.Namespace{}, err - } - - if namespace != nil { - log.WithFields(log.Fields{"name": name, "namespace": namespace.Name}).Debug("detected namespace") - namespaces[namespace.Name] = namespace + for _, name := range detectorNames { + if detector, ok := detectors[name]; ok { + namespace, err := detector.Detect(files) + if err != nil { + log.WithError(err).WithField("name", name).Warning("failed while attempting to detect namespace") + return nil, err + } + + if namespace != nil { + log.WithFields(log.Fields{"name": name, "namespace": namespace.Name}).Debug("detected namespace") + namespaces[namespace.Name] = namespace + } + } else { + log.WithField("Name", name).Warn("Unknown namespace detector") } } @@ -95,7 +99,7 @@ func Detect(files tarutil.FilesMap) ([]database.Namespace, error) { // RequiredFilenames returns the total list of files required for all // registered Detectors. 
-func RequiredFilenames() (files []string) { +func RequiredFilenames(detectorNames []string) (files []string) { detectorsM.RLock() defer detectorsM.RUnlock() @@ -106,6 +110,15 @@ func RequiredFilenames() (files []string) { return } +// ListDetectors returns the names of all registered namespace detectors. +func ListDetectors() []string { + r := []string{} + for name := range detectors { + r = append(r, name) + } + return r +} + // TestData represents the data used to test an implementation of Detector. type TestData struct { Files tarutil.FilesMap diff --git a/ext/featurens/driver_test.go b/ext/featurens/driver_test.go index e1a47ef6..8493c0cc 100644 --- a/ext/featurens/driver_test.go +++ b/ext/featurens/driver_test.go @@ -8,7 +8,7 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/ext/featurens" "github.com/coreos/clair/pkg/tarutil" - + _ "github.com/coreos/clair/ext/featurens/alpinerelease" _ "github.com/coreos/clair/ext/featurens/aptsources" _ "github.com/coreos/clair/ext/featurens/lsbrelease" @@ -35,7 +35,7 @@ func assertnsNameEqual(t *testing.T, nslist_expected, nslist []database.Namespac func testMultipleNamespace(t *testing.T, testData []MultipleNamespaceTestData) { for _, td := range testData { - nslist, err := featurens.Detect(td.Files) + nslist, err := featurens.Detect(td.Files, featurens.ListDetectors()) assert.Nil(t, err) assertnsNameEqual(t, td.ExpectedNamespaces, nslist) } diff --git a/ext/imagefmt/driver.go b/ext/imagefmt/driver.go index 6997e93b..178de53b 100644 --- a/ext/imagefmt/driver.go +++ b/ext/imagefmt/driver.go @@ -38,7 +38,7 @@ import ( var ( // ErrCouldNotFindLayer is returned when we could not download or open the layer file. 
- ErrCouldNotFindLayer = commonerr.NewBadRequestError("could not find layer") + ErrCouldNotFindLayer = commonerr.NewBadRequestError("could not find layer from given path") // insecureTLS controls whether TLS server's certificate chain and hostname are verified // when pulling layers, verified in default. diff --git a/ext/notification/driver.go b/ext/notification/driver.go index 8b961ae8..2768b7e3 100644 --- a/ext/notification/driver.go +++ b/ext/notification/driver.go @@ -23,8 +23,6 @@ package notification import ( "sync" "time" - - "github.com/coreos/clair/database" ) var ( @@ -47,7 +45,7 @@ type Sender interface { Configure(*Config) (bool, error) // Send informs the existence of the specified notification. - Send(notification database.VulnerabilityNotification) error + Send(notificationName string) error } // RegisterSender makes a Sender available by the provided name. diff --git a/ext/notification/webhook/webhook.go b/ext/notification/webhook/webhook.go index d54b588b..14ef48b2 100644 --- a/ext/notification/webhook/webhook.go +++ b/ext/notification/webhook/webhook.go @@ -29,7 +29,6 @@ import ( "gopkg.in/yaml.v2" - "github.com/coreos/clair/database" "github.com/coreos/clair/ext/notification" ) @@ -112,9 +111,9 @@ type notificationEnvelope struct { } } -func (s *sender) Send(notification database.VulnerabilityNotification) error { +func (s *sender) Send(notificationName string) error { // Marshal notification. 
- jsonNotification, err := json.Marshal(notificationEnvelope{struct{ Name string }{notification.Name}}) + jsonNotification, err := json.Marshal(notificationEnvelope{struct{ Name string }{notificationName}}) if err != nil { return fmt.Errorf("could not marshal: %s", err) } diff --git a/ext/versionfmt/dpkg/parser.go b/ext/versionfmt/dpkg/parser.go index 2d6eefbc..a2c82ec6 100644 --- a/ext/versionfmt/dpkg/parser.go +++ b/ext/versionfmt/dpkg/parser.go @@ -120,6 +120,18 @@ func (p parser) Valid(str string) bool { return err == nil } +func (p parser) InRange(versionA, rangeB string) (bool, error) { + cmp, err := p.Compare(versionA, rangeB) + if err != nil { + return false, err + } + return cmp < 0, nil +} + +func (p parser) GetFixedIn(fixedIn string) (string, error) { + return fixedIn, nil +} + // Compare function compares two Debian-like package version // // The implementation is based on http://man.he.net/man5/deb-version diff --git a/ext/versionfmt/driver.go b/ext/versionfmt/driver.go index 42f6c5b8..03179cd1 100644 --- a/ext/versionfmt/driver.go +++ b/ext/versionfmt/driver.go @@ -19,6 +19,8 @@ package versionfmt import ( "errors" "sync" + + log "github.com/sirupsen/logrus" ) const ( @@ -50,6 +52,18 @@ type Parser interface { // Compare parses two different version strings. // Returns 0 when equal, -1 when a < b, 1 when b < a. Compare(a, b string) (int, error) + + // InRange computes if a is in range of b + // + // NOTE(Sida): For legacy version formats, rangeB is a version and + // always use if versionA < rangeB as threshold. + InRange(versionA, rangeB string) (bool, error) + + // GetFixedIn computes a fixed in version for a certain version range. + // + // NOTE(Sida): For legacy version formats, rangeA is a version and + // is returned directly because it was considered the fixed in version. 
+ GetFixedIn(rangeA string) (string, error) } // RegisterParser provides a way to dynamically register an implementation of a @@ -110,3 +124,28 @@ func Compare(format, versionA, versionB string) (int, error) { return versionParser.Compare(versionA, versionB) } + +// InRange is a helper function that checks if `versionA` is in `rangeB` +func InRange(format, version, versionRange string) (bool, error) { + versionParser, exists := GetParser(format) + if !exists { + return false, ErrUnknownVersionFormat + } + + in, err := versionParser.InRange(version, versionRange) + if err != nil { + log.WithFields(log.Fields{"Format": format, "Version": version, "Range": versionRange}).Error(err) + } + return in, err +} + +// GetFixedIn is a helper function that computes the next fixed in version given +// an affected version range `rangeA`. +func GetFixedIn(format, rangeA string) (string, error) { + versionParser, exists := GetParser(format) + if !exists { + return "", ErrUnknownVersionFormat + } + + return versionParser.GetFixedIn(rangeA) +} diff --git a/ext/versionfmt/rpm/parser.go b/ext/versionfmt/rpm/parser.go index 34fbb9b9..55266ca5 100644 --- a/ext/versionfmt/rpm/parser.go +++ b/ext/versionfmt/rpm/parser.go @@ -121,6 +121,20 @@ func (p parser) Valid(str string) bool { return err == nil } +func (p parser) InRange(versionA, rangeB string) (bool, error) { + cmp, err := p.Compare(versionA, rangeB) + if err != nil { + return false, err + } + return cmp < 0, nil +} + +func (p parser) GetFixedIn(fixedIn string) (string, error) { + // In the old version format parser design, the string to determine fixed in + // version is the fixed in version. 
+ return fixedIn, nil +} + func (p parser) Compare(a, b string) (int, error) { v1, err := newVersion(a) if err != nil { diff --git a/ext/vulnsrc/alpine/alpine.go b/ext/vulnsrc/alpine/alpine.go index 271b6553..5b6f46e1 100644 --- a/ext/vulnsrc/alpine/alpine.go +++ b/ext/vulnsrc/alpine/alpine.go @@ -60,10 +60,20 @@ func (u *updater) Update(db database.Datastore) (resp vulnsrc.UpdateResponse, er // Ask the database for the latest commit we successfully applied. var dbCommit string - dbCommit, err = db.GetKeyValue(updaterFlag) + tx, err := db.Begin() if err != nil { return } + defer tx.Rollback() + + dbCommit, ok, err := tx.FindKeyValue(updaterFlag) + if err != nil { + return + } + + if !ok { + dbCommit = "" + } // Set the updaterFlag to equal the commit processed. resp.FlagName = updaterFlag @@ -84,7 +94,7 @@ func (u *updater) Update(db database.Datastore) (resp vulnsrc.UpdateResponse, er // Append any changed vulnerabilities to the response. for _, namespace := range namespaces { - var vulns []database.Vulnerability + var vulns []database.VulnerabilityWithAffected var note string vulns, note, err = parseVulnsFromNamespace(u.repositoryLocalPath, namespace) if err != nil { @@ -144,7 +154,7 @@ func ls(path string, filter lsFilter) ([]string, error) { return files, nil } -func parseVulnsFromNamespace(repositoryPath, namespace string) (vulns []database.Vulnerability, note string, err error) { +func parseVulnsFromNamespace(repositoryPath, namespace string) (vulns []database.VulnerabilityWithAffected, note string, err error) { nsDir := filepath.Join(repositoryPath, namespace) var dbFilenames []string dbFilenames, err = ls(nsDir, filesOnly) @@ -159,7 +169,7 @@ func parseVulnsFromNamespace(repositoryPath, namespace string) (vulns []database return } - var fileVulns []database.Vulnerability + var fileVulns []database.VulnerabilityWithAffected fileVulns, err = parseYAML(file) if err != nil { return @@ -216,7 +226,7 @@ type secDBFile struct { } `yaml:"packages"` } -func 
parseYAML(r io.Reader) (vulns []database.Vulnerability, err error) { +func parseYAML(r io.Reader) (vulns []database.VulnerabilityWithAffected, err error) { var rBytes []byte rBytes, err = ioutil.ReadAll(r) if err != nil { @@ -239,20 +249,24 @@ func parseYAML(r io.Reader) (vulns []database.Vulnerability, err error) { } for _, vulnStr := range vulnStrs { - var vuln database.Vulnerability + var vuln database.VulnerabilityWithAffected vuln.Severity = database.UnknownSeverity vuln.Name = vulnStr vuln.Link = nvdURLPrefix + vulnStr - vuln.FixedIn = []database.FeatureVersion{ + + var fixedInVersion string + if version != versionfmt.MaxVersion { + fixedInVersion = version + } + vuln.Affected = []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "alpine:" + file.Distro, - VersionFormat: dpkg.ParserName, - }, - Name: pkg.Name, + FeatureName: pkg.Name, + AffectedVersion: version, + FixedInVersion: fixedInVersion, + Namespace: database.Namespace{ + Name: "alpine:" + file.Distro, + VersionFormat: dpkg.ParserName, }, - Version: version, }, } vulns = append(vulns, vuln) diff --git a/ext/vulnsrc/alpine/alpine_test.go b/ext/vulnsrc/alpine/alpine_test.go index ac95f5c5..eddcc759 100644 --- a/ext/vulnsrc/alpine/alpine_test.go +++ b/ext/vulnsrc/alpine/alpine_test.go @@ -36,7 +36,7 @@ func TestYAMLParsing(t *testing.T) { } assert.Equal(t, 105, len(vulns)) assert.Equal(t, "CVE-2016-5387", vulns[0].Name) - assert.Equal(t, "alpine:v3.4", vulns[0].FixedIn[0].Feature.Namespace.Name) - assert.Equal(t, "apache2", vulns[0].FixedIn[0].Feature.Name) + assert.Equal(t, "alpine:v3.4", vulns[0].Affected[0].Namespace.Name) + assert.Equal(t, "apache2", vulns[0].Affected[0].FeatureName) assert.Equal(t, "https://cve.mitre.org/cgi-bin/cvename.cgi?name=CVE-2016-5387", vulns[0].Link) } diff --git a/ext/vulnsrc/debian/debian.go b/ext/vulnsrc/debian/debian.go index 3288e46b..c0efc37e 100644 --- a/ext/vulnsrc/debian/debian.go +++ b/ext/vulnsrc/debian/debian.go @@ 
-62,19 +62,34 @@ func init() { func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateResponse, err error) { log.WithField("package", "Debian").Info("Start fetching vulnerabilities") - // Download JSON. - r, err := http.Get(url) + tx, err := datastore.Begin() if err != nil { - log.WithError(err).Error("could not download Debian's update") - return resp, commonerr.ErrCouldNotDownload + return resp, err } // Get the SHA-1 of the latest update's JSON data - latestHash, err := datastore.GetKeyValue(updaterFlag) + latestHash, ok, err := tx.FindKeyValue(updaterFlag) if err != nil { return resp, err } + // NOTE(sida): The transaction won't mutate the database and I want the + // transaction to be short. + if err := tx.Rollback(); err != nil { + return resp, err + } + + if !ok { + latestHash = "" + } + + // Download JSON. + r, err := http.Get(url) + if err != nil { + log.WithError(err).Error("could not download Debian's update") + return resp, commonerr.ErrCouldNotDownload + } + // Parse the JSON. resp, err = buildResponse(r.Body, latestHash) if err != nil { @@ -131,8 +146,8 @@ func buildResponse(jsonReader io.Reader, latestKnownHash string) (resp vulnsrc.U return resp, nil } -func parseDebianJSON(data *jsonData) (vulnerabilities []database.Vulnerability, unknownReleases map[string]struct{}) { - mvulnerabilities := make(map[string]*database.Vulnerability) +func parseDebianJSON(data *jsonData) (vulnerabilities []database.VulnerabilityWithAffected, unknownReleases map[string]struct{}) { + mvulnerabilities := make(map[string]*database.VulnerabilityWithAffected) unknownReleases = make(map[string]struct{}) for pkgName, pkgNode := range *data { @@ -145,6 +160,7 @@ func parseDebianJSON(data *jsonData) (vulnerabilities []database.Vulnerability, } // Skip if the status is not determined or the vulnerability is a temporary one. + // TODO: maybe add "undetermined" as Unknown severity. 
if !strings.HasPrefix(vulnName, "CVE-") || releaseNode.Status == "undetermined" { continue } @@ -152,11 +168,13 @@ func parseDebianJSON(data *jsonData) (vulnerabilities []database.Vulnerability, // Get or create the vulnerability. vulnerability, vulnerabilityAlreadyExists := mvulnerabilities[vulnName] if !vulnerabilityAlreadyExists { - vulnerability = &database.Vulnerability{ - Name: vulnName, - Link: strings.Join([]string{cveURLPrefix, "/", vulnName}, ""), - Severity: database.UnknownSeverity, - Description: vulnNode.Description, + vulnerability = &database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: vulnName, + Link: strings.Join([]string{cveURLPrefix, "/", vulnName}, ""), + Severity: database.UnknownSeverity, + Description: vulnNode.Description, + }, } } @@ -171,10 +189,7 @@ func parseDebianJSON(data *jsonData) (vulnerabilities []database.Vulnerability, // Determine the version of the package the vulnerability affects. var version string var err error - if releaseNode.FixedVersion == "0" { - // This means that the package is not affected by this vulnerability. - version = versionfmt.MinVersion - } else if releaseNode.Status == "open" { + if releaseNode.Status == "open" { // Open means that the package is currently vulnerable in the latest // version of this Debian release. version = versionfmt.MaxVersion @@ -186,21 +201,34 @@ func parseDebianJSON(data *jsonData) (vulnerabilities []database.Vulnerability, log.WithError(err).WithField("version", version).Warning("could not parse package version. skipping") continue } - version = releaseNode.FixedVersion + + // FixedVersion = "0" means that the vulnerability affecting + // current feature is not that important + if releaseNode.FixedVersion != "0" { + version = releaseNode.FixedVersion + } + } + + if version == "" { + continue + } + + var fixedInVersion string + if version != versionfmt.MaxVersion { + fixedInVersion = version } // Create and add the feature version. 
- pkg := database.FeatureVersion{ - Feature: database.Feature{ - Name: pkgName, - Namespace: database.Namespace{ - Name: "debian:" + database.DebianReleasesMapping[releaseName], - VersionFormat: dpkg.ParserName, - }, + pkg := database.AffectedFeature{ + FeatureName: pkgName, + AffectedVersion: version, + FixedInVersion: fixedInVersion, + Namespace: database.Namespace{ + Name: "debian:" + database.DebianReleasesMapping[releaseName], + VersionFormat: dpkg.ParserName, }, - Version: version, } - vulnerability.FixedIn = append(vulnerability.FixedIn, pkg) + vulnerability.Affected = append(vulnerability.Affected, pkg) // Store the vulnerability. mvulnerabilities[vulnName] = vulnerability @@ -223,30 +251,16 @@ func SeverityFromUrgency(urgency string) database.Severity { case "not yet assigned": return database.UnknownSeverity - case "end-of-life": - fallthrough - case "unimportant": + case "end-of-life", "unimportant": return database.NegligibleSeverity - case "low": - fallthrough - case "low*": - fallthrough - case "low**": + case "low", "low*", "low**": return database.LowSeverity - case "medium": - fallthrough - case "medium*": - fallthrough - case "medium**": + case "medium", "medium*", "medium**": return database.MediumSeverity - case "high": - fallthrough - case "high*": - fallthrough - case "high**": + case "high", "high*", "high**": return database.HighSeverity default: diff --git a/ext/vulnsrc/debian/debian_test.go b/ext/vulnsrc/debian/debian_test.go index 1c62500c..3a6f9ace 100644 --- a/ext/vulnsrc/debian/debian_test.go +++ b/ext/vulnsrc/debian/debian_test.go @@ -32,103 +32,76 @@ func TestDebianParser(t *testing.T) { // Test parsing testdata/fetcher_debian_test.json testFile, _ := os.Open(filepath.Join(filepath.Dir(filename)) + "/testdata/fetcher_debian_test.json") response, err := buildResponse(testFile, "") - if assert.Nil(t, err) && assert.Len(t, response.Vulnerabilities, 3) { + if assert.Nil(t, err) && assert.Len(t, response.Vulnerabilities, 2) { for _, 
vulnerability := range response.Vulnerabilities { if vulnerability.Name == "CVE-2015-1323" { assert.Equal(t, "https://security-tracker.debian.org/tracker/CVE-2015-1323", vulnerability.Link) assert.Equal(t, database.LowSeverity, vulnerability.Severity) assert.Equal(t, "This vulnerability is not very dangerous.", vulnerability.Description) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "debian:8", - VersionFormat: dpkg.ParserName, - }, - Name: "aptdaemon", + Namespace: database.Namespace{ + Name: "debian:8", + VersionFormat: dpkg.ParserName, }, - Version: versionfmt.MaxVersion, + FeatureName: "aptdaemon", + AffectedVersion: versionfmt.MaxVersion, }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "debian:unstable", - VersionFormat: dpkg.ParserName, - }, - Name: "aptdaemon", + Namespace: database.Namespace{ + Name: "debian:unstable", + VersionFormat: dpkg.ParserName, }, - Version: "1.1.1+bzr982-1", + FeatureName: "aptdaemon", + AffectedVersion: "1.1.1+bzr982-1", + FixedInVersion: "1.1.1+bzr982-1", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerability.FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerability.Affected, expectedFeature) } } else if vulnerability.Name == "CVE-2003-0779" { assert.Equal(t, "https://security-tracker.debian.org/tracker/CVE-2003-0779", vulnerability.Link) assert.Equal(t, database.HighSeverity, vulnerability.Severity) assert.Equal(t, "But this one is very dangerous.", vulnerability.Description) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "debian:8", - VersionFormat: dpkg.ParserName, - }, - Name: "aptdaemon", + Namespace: database.Namespace{ + 
Name: "debian:8", + VersionFormat: dpkg.ParserName, }, - Version: "0.7.0", + FeatureName: "aptdaemon", + FixedInVersion: "0.7.0", + AffectedVersion: "0.7.0", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "debian:unstable", - VersionFormat: dpkg.ParserName, - }, - Name: "aptdaemon", + Namespace: database.Namespace{ + Name: "debian:unstable", + VersionFormat: dpkg.ParserName, }, - Version: "0.7.0", + FeatureName: "aptdaemon", + FixedInVersion: "0.7.0", + AffectedVersion: "0.7.0", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "debian:8", - VersionFormat: dpkg.ParserName, - }, - Name: "asterisk", + Namespace: database.Namespace{ + Name: "debian:8", + VersionFormat: dpkg.ParserName, }, - Version: "0.5.56", + FeatureName: "asterisk", + FixedInVersion: "0.5.56", + AffectedVersion: "0.5.56", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerability.FixedIn, expectedFeatureVersion) - } - } else if vulnerability.Name == "CVE-2013-2685" { - assert.Equal(t, "https://security-tracker.debian.org/tracker/CVE-2013-2685", vulnerability.Link) - assert.Equal(t, database.NegligibleSeverity, vulnerability.Severity) - assert.Equal(t, "Un-affected packages.", vulnerability.Description) - - expectedFeatureVersions := []database.FeatureVersion{ - { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "debian:8", - VersionFormat: dpkg.ParserName, - }, - Name: "asterisk", - }, - Version: versionfmt.MinVersion, - }, - } - - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerability.FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerability.Affected, expectedFeature) } } else { - assert.Fail(t, "Wrong vulnerability name: ", vulnerability.ID) + assert.Fail(t, "Wrong vulnerability name: ", vulnerability.Namespace.Name+":"+vulnerability.Name) } } } diff --git 
a/ext/vulnsrc/driver.go b/ext/vulnsrc/driver.go index fd442416..91b28831 100644 --- a/ext/vulnsrc/driver.go +++ b/ext/vulnsrc/driver.go @@ -39,11 +39,10 @@ type UpdateResponse struct { FlagName string FlagValue string Notes []string - Vulnerabilities []database.Vulnerability + Vulnerabilities []database.VulnerabilityWithAffected } -// Updater represents anything that can fetch vulnerabilities and insert them -// into a Clair datastore. +// Updater represents anything that can fetch vulnerabilities. type Updater interface { // Update gets vulnerability updates. Update(database.Datastore) (UpdateResponse, error) @@ -88,3 +87,12 @@ func Updaters() map[string]Updater { return ret } + +// ListUpdaters returns the names of registered vulnerability updaters. +func ListUpdaters() []string { + r := []string{} + for u := range updaters { + r = append(r, u) + } + return r +} diff --git a/ext/vulnsrc/oracle/oracle.go b/ext/vulnsrc/oracle/oracle.go index ee6a8343..40dcd669 100644 --- a/ext/vulnsrc/oracle/oracle.go +++ b/ext/vulnsrc/oracle/oracle.go @@ -118,10 +118,20 @@ func compareELSA(left, right int) int { func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateResponse, err error) { log.WithField("package", "Oracle Linux").Info("Start fetching vulnerabilities") // Get the first ELSA we have to manage. - flagValue, err := datastore.GetKeyValue(updaterFlag) + tx, err := datastore.Begin() if err != nil { return resp, err } + defer tx.Rollback() + + flagValue, ok, err := tx.FindKeyValue(updaterFlag) + if err != nil { + return resp, err + } + + if !ok { + flagValue = "" + } firstELSA, err := strconv.Atoi(flagValue) if firstELSA == 0 || err != nil { @@ -192,7 +202,7 @@ func largest(list []int) (largest int) { func (u *updater) Clean() {} -func parseELSA(ovalReader io.Reader) (vulnerabilities []database.Vulnerability, err error) { +func parseELSA(ovalReader io.Reader) (vulnerabilities []database.VulnerabilityWithAffected, err error) { // Decode the XML. 
var ov oval err = xml.NewDecoder(ovalReader).Decode(&ov) @@ -205,16 +215,18 @@ func parseELSA(ovalReader io.Reader) (vulnerabilities []database.Vulnerability, // Iterate over the definitions and collect any vulnerabilities that affect // at least one package. for _, definition := range ov.Definitions { - pkgs := toFeatureVersions(definition.Criteria) + pkgs := toFeatures(definition.Criteria) if len(pkgs) > 0 { - vulnerability := database.Vulnerability{ - Name: name(definition), - Link: link(definition), - Severity: severity(definition), - Description: description(definition), + vulnerability := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: name(definition), + Link: link(definition), + Severity: severity(definition), + Description: description(definition), + }, } for _, p := range pkgs { - vulnerability.FixedIn = append(vulnerability.FixedIn, p) + vulnerability.Affected = append(vulnerability.Affected, p) } vulnerabilities = append(vulnerabilities, vulnerability) } @@ -298,15 +310,15 @@ func getPossibilities(node criteria) [][]criterion { return possibilities } -func toFeatureVersions(criteria criteria) []database.FeatureVersion { +func toFeatures(criteria criteria) []database.AffectedFeature { // There are duplicates in Oracle .xml files. // This map is for deduplication. 
- featureVersionParameters := make(map[string]database.FeatureVersion) + featureVersionParameters := make(map[string]database.AffectedFeature) possibilities := getPossibilities(criteria) for _, criterions := range possibilities { var ( - featureVersion database.FeatureVersion + featureVersion database.AffectedFeature osVersion int err error ) @@ -321,29 +333,32 @@ func toFeatureVersions(criteria criteria) []database.FeatureVersion { } } else if strings.Contains(c.Comment, " is earlier than ") { const prefixLen = len(" is earlier than ") - featureVersion.Feature.Name = strings.TrimSpace(c.Comment[:strings.Index(c.Comment, " is earlier than ")]) + featureVersion.FeatureName = strings.TrimSpace(c.Comment[:strings.Index(c.Comment, " is earlier than ")]) version := c.Comment[strings.Index(c.Comment, " is earlier than ")+prefixLen:] err := versionfmt.Valid(rpm.ParserName, version) if err != nil { log.WithError(err).WithField("version", version).Warning("could not parse package version. skipping") } else { - featureVersion.Version = version + featureVersion.AffectedVersion = version + if version != versionfmt.MaxVersion { + featureVersion.FixedInVersion = version + } } } } - featureVersion.Feature.Namespace.Name = "oracle" + ":" + strconv.Itoa(osVersion) - featureVersion.Feature.Namespace.VersionFormat = rpm.ParserName + featureVersion.Namespace.Name = "oracle" + ":" + strconv.Itoa(osVersion) + featureVersion.Namespace.VersionFormat = rpm.ParserName - if featureVersion.Feature.Namespace.Name != "" && featureVersion.Feature.Name != "" && featureVersion.Version != "" { - featureVersionParameters[featureVersion.Feature.Namespace.Name+":"+featureVersion.Feature.Name] = featureVersion + if featureVersion.Namespace.Name != "" && featureVersion.FeatureName != "" && featureVersion.AffectedVersion != "" && featureVersion.FixedInVersion != "" { + featureVersionParameters[featureVersion.Namespace.Name+":"+featureVersion.FeatureName] = featureVersion } else { 
log.WithField("criterions", fmt.Sprintf("%v", criterions)).Warning("could not determine a valid package from criterions") } } // Convert the map to slice. - var featureVersionParametersArray []database.FeatureVersion + var featureVersionParametersArray []database.AffectedFeature for _, fv := range featureVersionParameters { featureVersionParametersArray = append(featureVersionParametersArray, fv) } diff --git a/ext/vulnsrc/oracle/oracle_test.go b/ext/vulnsrc/oracle/oracle_test.go index bab98bcc..a9348d48 100644 --- a/ext/vulnsrc/oracle/oracle_test.go +++ b/ext/vulnsrc/oracle/oracle_test.go @@ -40,41 +40,38 @@ func TestOracleParser(t *testing.T) { assert.Equal(t, database.MediumSeverity, vulnerabilities[0].Severity) assert.Equal(t, ` [3.1.1-7] Resolves: rhbz#1217104 CVE-2015-0252 `, vulnerabilities[0].Description) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "oracle:7", - VersionFormat: rpm.ParserName, - }, - Name: "xerces-c", + Namespace: database.Namespace{ + Name: "oracle:7", + VersionFormat: rpm.ParserName, }, - Version: "0:3.1.1-7.el7_1", + FeatureName: "xerces-c", + FixedInVersion: "0:3.1.1-7.el7_1", + AffectedVersion: "0:3.1.1-7.el7_1", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "oracle:7", - VersionFormat: rpm.ParserName, - }, - Name: "xerces-c-devel", + Namespace: database.Namespace{ + Name: "oracle:7", + VersionFormat: rpm.ParserName, }, - Version: "0:3.1.1-7.el7_1", + FeatureName: "xerces-c-devel", + FixedInVersion: "0:3.1.1-7.el7_1", + AffectedVersion: "0:3.1.1-7.el7_1", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "oracle:7", - VersionFormat: rpm.ParserName, - }, - Name: "xerces-c-doc", + Namespace: database.Namespace{ + Name: "oracle:7", + VersionFormat: rpm.ParserName, }, - Version: "0:3.1.1-7.el7_1", + FeatureName: "xerces-c-doc", + FixedInVersion: 
"0:3.1.1-7.el7_1", + AffectedVersion: "0:3.1.1-7.el7_1", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerabilities[0].FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerabilities[0].Affected, expectedFeature) } } @@ -87,31 +84,29 @@ func TestOracleParser(t *testing.T) { assert.Equal(t, "http://linux.oracle.com/errata/ELSA-2015-1207.html", vulnerabilities[0].Link) assert.Equal(t, database.CriticalSeverity, vulnerabilities[0].Severity) assert.Equal(t, ` [38.1.0-1.0.1.el7_1] - Add firefox-oracle-default-prefs.js and remove the corresponding Red Hat file [38.1.0-1] - Update to 38.1.0 ESR [38.0.1-2] - Fixed rhbz#1222807 by removing preun section `, vulnerabilities[0].Description) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "oracle:6", - VersionFormat: rpm.ParserName, - }, - Name: "firefox", + Namespace: database.Namespace{ + Name: "oracle:6", + VersionFormat: rpm.ParserName, }, - Version: "0:38.1.0-1.0.1.el6_6", + FeatureName: "firefox", + FixedInVersion: "0:38.1.0-1.0.1.el6_6", + AffectedVersion: "0:38.1.0-1.0.1.el6_6", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "oracle:7", - VersionFormat: rpm.ParserName, - }, - Name: "firefox", + Namespace: database.Namespace{ + Name: "oracle:7", + VersionFormat: rpm.ParserName, }, - Version: "0:38.1.0-1.0.1.el7_1", + FeatureName: "firefox", + FixedInVersion: "0:38.1.0-1.0.1.el7_1", + AffectedVersion: "0:38.1.0-1.0.1.el7_1", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerabilities[0].FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerabilities[0].Affected, expectedFeature) } } } diff --git a/ext/vulnsrc/rhel/rhel.go b/ext/vulnsrc/rhel/rhel.go index 
bbd48c15..f4cbce8f 100644 --- a/ext/vulnsrc/rhel/rhel.go +++ b/ext/vulnsrc/rhel/rhel.go @@ -90,11 +90,26 @@ func init() { func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateResponse, err error) { log.WithField("package", "RHEL").Info("Start fetching vulnerabilities") + + tx, err := datastore.Begin() + if err != nil { + return resp, err + } + // Get the first RHSA we have to manage. - flagValue, err := datastore.GetKeyValue(updaterFlag) + flagValue, ok, err := tx.FindKeyValue(updaterFlag) if err != nil { return resp, err } + + if err := tx.Rollback(); err != nil { + return resp, err + } + + if !ok { + flagValue = "" + } + firstRHSA, err := strconv.Atoi(flagValue) if firstRHSA == 0 || err != nil { firstRHSA = firstRHEL5RHSA @@ -154,7 +169,7 @@ func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateRespo func (u *updater) Clean() {} -func parseRHSA(ovalReader io.Reader) (vulnerabilities []database.Vulnerability, err error) { +func parseRHSA(ovalReader io.Reader) (vulnerabilities []database.VulnerabilityWithAffected, err error) { // Decode the XML. var ov oval err = xml.NewDecoder(ovalReader).Decode(&ov) @@ -167,16 +182,18 @@ func parseRHSA(ovalReader io.Reader) (vulnerabilities []database.Vulnerability, // Iterate over the definitions and collect any vulnerabilities that affect // at least one package. 
for _, definition := range ov.Definitions { - pkgs := toFeatureVersions(definition.Criteria) + pkgs := toFeatures(definition.Criteria) if len(pkgs) > 0 { - vulnerability := database.Vulnerability{ - Name: name(definition), - Link: link(definition), - Severity: severity(definition), - Description: description(definition), + vulnerability := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: name(definition), + Link: link(definition), + Severity: severity(definition), + Description: description(definition), + }, } for _, p := range pkgs { - vulnerability.FixedIn = append(vulnerability.FixedIn, p) + vulnerability.Affected = append(vulnerability.Affected, p) } vulnerabilities = append(vulnerabilities, vulnerability) } @@ -260,15 +277,15 @@ func getPossibilities(node criteria) [][]criterion { return possibilities } -func toFeatureVersions(criteria criteria) []database.FeatureVersion { +func toFeatures(criteria criteria) []database.AffectedFeature { // There are duplicates in Red Hat .xml files. // This map is for deduplication. 
- featureVersionParameters := make(map[string]database.FeatureVersion) + featureVersionParameters := make(map[string]database.AffectedFeature) possibilities := getPossibilities(criteria) for _, criterions := range possibilities { var ( - featureVersion database.FeatureVersion + featureVersion database.AffectedFeature osVersion int err error ) @@ -283,34 +300,37 @@ func toFeatureVersions(criteria criteria) []database.FeatureVersion { } } else if strings.Contains(c.Comment, " is earlier than ") { const prefixLen = len(" is earlier than ") - featureVersion.Feature.Name = strings.TrimSpace(c.Comment[:strings.Index(c.Comment, " is earlier than ")]) + featureVersion.FeatureName = strings.TrimSpace(c.Comment[:strings.Index(c.Comment, " is earlier than ")]) version := c.Comment[strings.Index(c.Comment, " is earlier than ")+prefixLen:] err := versionfmt.Valid(rpm.ParserName, version) if err != nil { log.WithError(err).WithField("version", version).Warning("could not parse package version. skipping") } else { - featureVersion.Version = version - featureVersion.Feature.Namespace.VersionFormat = rpm.ParserName + featureVersion.AffectedVersion = version + if version != versionfmt.MaxVersion { + featureVersion.FixedInVersion = version + } + featureVersion.Namespace.VersionFormat = rpm.ParserName } } } if osVersion >= firstConsideredRHEL { // TODO(vbatts) this is where features need multiple labels ('centos' and 'rhel') - featureVersion.Feature.Namespace.Name = "centos" + ":" + strconv.Itoa(osVersion) + featureVersion.Namespace.Name = "centos" + ":" + strconv.Itoa(osVersion) } else { continue } - if featureVersion.Feature.Namespace.Name != "" && featureVersion.Feature.Name != "" && featureVersion.Version != "" { - featureVersionParameters[featureVersion.Feature.Namespace.Name+":"+featureVersion.Feature.Name] = featureVersion + if featureVersion.Namespace.Name != "" && featureVersion.FeatureName != "" && featureVersion.AffectedVersion != "" && featureVersion.FixedInVersion != "" { 
+ featureVersionParameters[featureVersion.Namespace.Name+":"+featureVersion.FeatureName] = featureVersion } else { log.WithField("criterions", fmt.Sprintf("%v", criterions)).Warning("could not determine a valid package from criterions") } } // Convert the map to slice. - var featureVersionParametersArray []database.FeatureVersion + var featureVersionParametersArray []database.AffectedFeature for _, fv := range featureVersionParameters { featureVersionParametersArray = append(featureVersionParametersArray, fv) } diff --git a/ext/vulnsrc/rhel/rhel_test.go b/ext/vulnsrc/rhel/rhel_test.go index db762610..e91ec502 100644 --- a/ext/vulnsrc/rhel/rhel_test.go +++ b/ext/vulnsrc/rhel/rhel_test.go @@ -38,41 +38,38 @@ func TestRHELParser(t *testing.T) { assert.Equal(t, database.MediumSeverity, vulnerabilities[0].Severity) assert.Equal(t, `Xerces-C is a validating XML parser written in a portable subset of C++. A flaw was found in the way the Xerces-C XML parser processed certain XML documents. A remote attacker could provide specially crafted XML input that, when parsed by an application using Xerces-C, would cause that application to crash.`, vulnerabilities[0].Description) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "centos:7", - VersionFormat: rpm.ParserName, - }, - Name: "xerces-c", + Namespace: database.Namespace{ + Name: "centos:7", + VersionFormat: rpm.ParserName, }, - Version: "0:3.1.1-7.el7_1", + FeatureName: "xerces-c", + AffectedVersion: "0:3.1.1-7.el7_1", + FixedInVersion: "0:3.1.1-7.el7_1", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "centos:7", - VersionFormat: rpm.ParserName, - }, - Name: "xerces-c-devel", + Namespace: database.Namespace{ + Name: "centos:7", + VersionFormat: rpm.ParserName, }, - Version: "0:3.1.1-7.el7_1", + FeatureName: "xerces-c-devel", + AffectedVersion: "0:3.1.1-7.el7_1", + 
FixedInVersion: "0:3.1.1-7.el7_1", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "centos:7", - VersionFormat: rpm.ParserName, - }, - Name: "xerces-c-doc", + Namespace: database.Namespace{ + Name: "centos:7", + VersionFormat: rpm.ParserName, }, - Version: "0:3.1.1-7.el7_1", + FeatureName: "xerces-c-doc", + AffectedVersion: "0:3.1.1-7.el7_1", + FixedInVersion: "0:3.1.1-7.el7_1", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerabilities[0].FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerabilities[0].Affected, expectedFeature) } } @@ -85,31 +82,29 @@ func TestRHELParser(t *testing.T) { assert.Equal(t, database.CriticalSeverity, vulnerabilities[0].Severity) assert.Equal(t, `Mozilla Firefox is an open source web browser. XULRunner provides the XUL Runtime environment for Mozilla Firefox. Several flaws were found in the processing of malformed web content. 
A web page containing malicious content could cause Firefox to crash or, potentially, execute arbitrary code with the privileges of the user running Firefox.`, vulnerabilities[0].Description) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "centos:6", - VersionFormat: rpm.ParserName, - }, - Name: "firefox", + Namespace: database.Namespace{ + Name: "centos:6", + VersionFormat: rpm.ParserName, }, - Version: "0:38.1.0-1.el6_6", + FeatureName: "firefox", + FixedInVersion: "0:38.1.0-1.el6_6", + AffectedVersion: "0:38.1.0-1.el6_6", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "centos:7", - VersionFormat: rpm.ParserName, - }, - Name: "firefox", + Namespace: database.Namespace{ + Name: "centos:7", + VersionFormat: rpm.ParserName, }, - Version: "0:38.1.0-1.el7_1", + FeatureName: "firefox", + FixedInVersion: "0:38.1.0-1.el7_1", + AffectedVersion: "0:38.1.0-1.el7_1", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerabilities[0].FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerabilities[0].Affected, expectedFeature) } } } diff --git a/ext/vulnsrc/ubuntu/ubuntu.go b/ext/vulnsrc/ubuntu/ubuntu.go index 28803c76..6af0c14c 100644 --- a/ext/vulnsrc/ubuntu/ubuntu.go +++ b/ext/vulnsrc/ubuntu/ubuntu.go @@ -98,12 +98,25 @@ func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateRespo return resp, err } + tx, err := datastore.Begin() + if err != nil { + return resp, err + } + // Get the latest revision number we successfully applied in the database. 
- dbRevisionNumber, err := datastore.GetKeyValue("ubuntuUpdater") + dbRevisionNumber, ok, err := tx.FindKeyValue("ubuntuUpdater") if err != nil { return resp, err } + if err := tx.Rollback(); err != nil { + return resp, err + } + + if !ok { + dbRevisionNumber = "" + } + // Get the list of vulnerabilities that we have to update. modifiedCVE, err := collectModifiedVulnerabilities(revisionNumber, dbRevisionNumber, u.repositoryLocalPath) if err != nil { @@ -278,11 +291,15 @@ func collectModifiedVulnerabilities(revision int, dbRevision, repositoryLocalPat return modifiedCVE, nil } -func parseUbuntuCVE(fileContent io.Reader) (vulnerability database.Vulnerability, unknownReleases map[string]struct{}, err error) { +func parseUbuntuCVE(fileContent io.Reader) (vulnerability database.VulnerabilityWithAffected, unknownReleases map[string]struct{}, err error) { unknownReleases = make(map[string]struct{}) readingDescription := false scanner := bufio.NewScanner(fileContent) + // only unique major releases will be considered. All sub releases' (e.g. + // precise/esm) features are considered belong to major releases. + uniqueRelease := map[string]struct{}{} + for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) @@ -344,7 +361,7 @@ func parseUbuntuCVE(fileContent io.Reader) (vulnerability database.Vulnerability // Only consider the package if its status is needed, active, deferred, not-affected or // released. Ignore DNE (package does not exist), needs-triage, ignored, pending. 
if md["status"] == "needed" || md["status"] == "active" || md["status"] == "deferred" || md["status"] == "released" || md["status"] == "not-affected" { - md["release"] = strings.Split(md["release"], "/")[0] + md["release"] = strings.Split(md["release"], "/")[0] if _, isReleaseIgnored := ubuntuIgnoredReleases[md["release"]]; isReleaseIgnored { continue } @@ -363,8 +380,6 @@ func parseUbuntuCVE(fileContent io.Reader) (vulnerability database.Vulnerability } version = md["note"] } - } else if md["status"] == "not-affected" { - version = versionfmt.MinVersion } else { version = versionfmt.MaxVersion } @@ -372,18 +387,30 @@ func parseUbuntuCVE(fileContent io.Reader) (vulnerability database.Vulnerability continue } + releaseName := "ubuntu:" + database.UbuntuReleasesMapping[md["release"]] + if _, ok := uniqueRelease[releaseName+"_:_"+md["package"]]; ok { + continue + } + + uniqueRelease[releaseName+"_:_"+md["package"]] = struct{}{} + var fixedinVersion string + if version == versionfmt.MaxVersion { + fixedinVersion = "" + } else { + fixedinVersion = version + } + // Create and add the new package. 
- featureVersion := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "ubuntu:" + database.UbuntuReleasesMapping[md["release"]], - VersionFormat: dpkg.ParserName, - }, - Name: md["package"], + featureVersion := database.AffectedFeature{ + Namespace: database.Namespace{ + Name: releaseName, + VersionFormat: dpkg.ParserName, }, - Version: version, + FeatureName: md["package"], + AffectedVersion: version, + FixedInVersion: fixedinVersion, } - vulnerability.FixedIn = append(vulnerability.FixedIn, featureVersion) + vulnerability.Affected = append(vulnerability.Affected, featureVersion) } } } diff --git a/ext/vulnsrc/ubuntu/ubuntu_test.go b/ext/vulnsrc/ubuntu/ubuntu_test.go index 5cdbd9a4..a4bd8afd 100644 --- a/ext/vulnsrc/ubuntu/ubuntu_test.go +++ b/ext/vulnsrc/ubuntu/ubuntu_test.go @@ -44,41 +44,37 @@ func TestUbuntuParser(t *testing.T) { _, hasUnkownRelease := unknownReleases["unknown"] assert.True(t, hasUnkownRelease) - expectedFeatureVersions := []database.FeatureVersion{ + expectedFeatures := []database.AffectedFeature{ { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "ubuntu:14.04", - VersionFormat: dpkg.ParserName, - }, - Name: "libmspack", + Namespace: database.Namespace{ + Name: "ubuntu:14.04", + VersionFormat: dpkg.ParserName, }, - Version: versionfmt.MaxVersion, + FeatureName: "libmspack", + AffectedVersion: versionfmt.MaxVersion, }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "ubuntu:15.04", - VersionFormat: dpkg.ParserName, - }, - Name: "libmspack", + Namespace: database.Namespace{ + Name: "ubuntu:15.04", + VersionFormat: dpkg.ParserName, }, - Version: "0.4-3", + FeatureName: "libmspack", + FixedInVersion: "0.4-3", + AffectedVersion: "0.4-3", }, { - Feature: database.Feature{ - Namespace: database.Namespace{ - Name: "ubuntu:15.10", - VersionFormat: dpkg.ParserName, - }, - Name: "libmspack-anotherpkg", + Namespace: database.Namespace{ + Name: "ubuntu:15.10", + 
VersionFormat: dpkg.ParserName, }, - Version: "0.1", + FeatureName: "libmspack-anotherpkg", + FixedInVersion: "0.1", + AffectedVersion: "0.1", }, } - for _, expectedFeatureVersion := range expectedFeatureVersions { - assert.Contains(t, vulnerability.FixedIn, expectedFeatureVersion) + for _, expectedFeature := range expectedFeatures { + assert.Contains(t, vulnerability.Affected, expectedFeature) } } } diff --git a/notifier.go b/notifier.go index ad3e947c..3b4d5f49 100644 --- a/notifier.go +++ b/notifier.go @@ -24,7 +24,6 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/ext/notification" - "github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/stopper" ) @@ -94,14 +93,16 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop go func() { success, interrupted := handleTask(*notification, stopper, config.Attempts) if success { - datastore.SetNotificationNotified(notification.Name) - + err := markNotificationNotified(datastore, notification.Name) + if err != nil { + log.WithError(err).Error("Failed to mark notification notified") + } promNotifierLatencyMilliseconds.Observe(float64(time.Since(notification.Created).Nanoseconds()) / float64(time.Millisecond)) } if interrupted { running = false } - datastore.Unlock(notification.Name, whoAmI) + unlock(datastore, notification.Name, whoAmI) done <- true }() @@ -112,7 +113,10 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop case <-done: break outer case <-time.After(notifierLockRefreshDuration): - datastore.Lock(notification.Name, whoAmI, notifierLockDuration, true) + lock(datastore, notification.Name, whoAmI, notifierLockDuration, true) + case <-stopper.Chan(): + running = false + break } } } @@ -120,13 +124,11 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop log.Info("notifier service stopped") } -func findTask(datastore database.Datastore, renotifyInterval time.Duration, whoAmI string, 
stopper *stopper.Stopper) *database.VulnerabilityNotification { +func findTask(datastore database.Datastore, renotifyInterval time.Duration, whoAmI string, stopper *stopper.Stopper) *database.NotificationHook { for { - // Find a notification to send. - notification, err := datastore.GetAvailableNotification(renotifyInterval) - if err != nil { - // There is no notification or an error occurred. - if err != commonerr.ErrNotFound { + notification, ok, err := findNewNotification(datastore, renotifyInterval) + if err != nil || !ok { + if !ok { log.WithError(err).Warning("could not get notification to send") } @@ -139,14 +141,14 @@ func findTask(datastore database.Datastore, renotifyInterval time.Duration, whoA } // Lock the notification. - if hasLock, _ := datastore.Lock(notification.Name, whoAmI, notifierLockDuration, false); hasLock { + if hasLock, _ := lock(datastore, notification.Name, whoAmI, notifierLockDuration, false); hasLock { log.WithField(logNotiName, notification.Name).Info("found and locked a notification") return ¬ification } } } -func handleTask(n database.VulnerabilityNotification, st *stopper.Stopper, maxAttempts int) (bool, bool) { +func handleTask(n database.NotificationHook, st *stopper.Stopper, maxAttempts int) (bool, bool) { // Send notification. for senderName, sender := range notification.Senders() { var attempts int @@ -167,7 +169,7 @@ func handleTask(n database.VulnerabilityNotification, st *stopper.Stopper, maxAt } // Send using the current notifier. - if err := sender.Send(n); err != nil { + if err := sender.Send(n.Name); err != nil { // Send failed; increase attempts/backoff and retry. 
promNotifierBackendErrorsTotal.WithLabelValues(senderName).Inc() log.WithError(err).WithFields(log.Fields{logSenderName: senderName, logNotiName: n.Name}).Error("could not send notification via notifier") @@ -184,3 +186,66 @@ func handleTask(n database.VulnerabilityNotification, st *stopper.Stopper, maxAt log.WithField(logNotiName, n.Name).Info("successfully sent notification") return true, false } + +func findNewNotification(datastore database.Datastore, renotifyInterval time.Duration) (database.NotificationHook, bool, error) { + tx, err := datastore.Begin() + if err != nil { + return database.NotificationHook{}, false, err + } + defer tx.Rollback() + return tx.FindNewNotification(time.Now().Add(-renotifyInterval)) +} + +func markNotificationNotified(datastore database.Datastore, name string) error { + tx, err := datastore.Begin() + if err != nil { + log.WithError(err).Error("an error happens when beginning database transaction") + } + defer tx.Rollback() + + if err := tx.MarkNotificationNotified(name); err != nil { + return err + } + return tx.Commit() +} + +// unlock removes a lock with provided name, owner. Internally, it handles +// database transaction and catches error. +func unlock(datastore database.Datastore, name, owner string) { + tx, err := datastore.Begin() + if err != nil { + return + } + + defer tx.Rollback() + + if err := tx.Unlock(name, owner); err != nil { + return + } + if err := tx.Commit(); err != nil { + return + } +} + +func lock(datastore database.Datastore, name string, owner string, duration time.Duration, renew bool) (bool, time.Time) { + // any error will cause the function to catch the error and return false. 
+ tx, err := datastore.Begin() + if err != nil { + return false, time.Time{} + } + + defer tx.Rollback() + + locked, t, err := tx.Lock(name, owner, duration, renew) + if err != nil { + return false, time.Time{} + } + + if locked { + if err := tx.Commit(); err != nil { + return false, time.Time{} + } + } + + return locked, t +} diff --git a/pkg/commonerr/errors.go b/pkg/commonerr/errors.go index 1e690eea..6b268d74 100644 --- a/pkg/commonerr/errors.go +++ b/pkg/commonerr/errors.go @@ -16,7 +16,11 @@ // codebase. package commonerr -import "errors" +import ( + "errors" + "fmt" + "strings" +) var ( // ErrFilesystem occurs when a filesystem interaction fails. @@ -45,3 +49,19 @@ func NewBadRequestError(message string) error { func (e *ErrBadRequest) Error() string { return e.s } + +// CombineErrors merges a slice of errors into one separated by ";". If all +// errors are nil, return nil. +func CombineErrors(errs ...error) error { + errStr := []string{} + for i, err := range errs { + if err != nil { + errStr = append(errStr, fmt.Sprintf("[%d] %s", i, err.Error())) + } + } + + if len(errStr) != 0 { + return errors.New(strings.Join(errStr, ";")) + } + return nil +} diff --git a/updater.go b/updater.go index 2e3aa216..792e068b 100644 --- a/updater.go +++ b/updater.go @@ -15,6 +15,7 @@ package clair import ( + "fmt" "math/rand" "strconv" "sync" @@ -53,6 +54,9 @@ var ( Name: "clair_updater_notes_total", Help: "Number of notes that the vulnerability fetchers generated.", }) + + // EnabledUpdaters contains all updaters to be used for update. + EnabledUpdaters []string ) func init() { @@ -63,7 +67,13 @@ func init() { // UpdaterConfig is the configuration for the Updater service. 
type UpdaterConfig struct { - Interval time.Duration + EnabledUpdaters []string + Interval time.Duration +} + +type vulnerabilityChange struct { + old *database.VulnerabilityWithAffected + new *database.VulnerabilityWithAffected } // RunUpdater begins a process that updates the vulnerability database at @@ -72,7 +82,7 @@ func RunUpdater(config *UpdaterConfig, datastore database.Datastore, st *stopper defer st.End() // Do not run the updater if there is no config or if the interval is 0. - if config == nil || config.Interval == 0 { + if config == nil || config.Interval == 0 || len(config.EnabledUpdaters) == 0 { log.Info("updater service is disabled.") return } @@ -86,11 +96,11 @@ func RunUpdater(config *UpdaterConfig, datastore database.Datastore, st *stopper // Determine if this is the first update and define the next update time. // The next update time is (last update time + interval) or now if this is the first update. nextUpdate := time.Now().UTC() - lastUpdate, firstUpdate, err := getLastUpdate(datastore) + lastUpdate, firstUpdate, err := GetLastUpdateTime(datastore) if err != nil { - log.WithError(err).Error("an error occured while getting the last update time") + log.WithError(err).Error("an error occurred while getting the last update time") nextUpdate = nextUpdate.Add(config.Interval) - } else if firstUpdate == false { + } else if !firstUpdate { nextUpdate = lastUpdate.Add(config.Interval) } @@ -98,7 +108,7 @@ func RunUpdater(config *UpdaterConfig, datastore database.Datastore, st *stopper if nextUpdate.Before(time.Now().UTC()) { // Attempt to get a lock on the the update. log.Debug("attempting to obtain update lock") - hasLock, hasLockUntil := datastore.Lock(updaterLockName, whoAmI, updaterLockDuration, false) + hasLock, hasLockUntil := lock(datastore, updaterLockName, whoAmI, updaterLockDuration, false) if hasLock { // Launch update in a new go routine. 
doneC := make(chan bool, 1) @@ -113,14 +123,14 @@ func RunUpdater(config *UpdaterConfig, datastore database.Datastore, st *stopper done = true case <-time.After(updaterLockRefreshDuration): // Refresh the lock until the update is done. - datastore.Lock(updaterLockName, whoAmI, updaterLockDuration, true) + lock(datastore, updaterLockName, whoAmI, updaterLockDuration, true) case <-st.Chan(): stop = true } } - // Unlock the update. - datastore.Unlock(updaterLockName, whoAmI) + // Unlock the updater. + unlock(datastore, updaterLockName, whoAmI) if stop { break @@ -132,10 +142,9 @@ func RunUpdater(config *UpdaterConfig, datastore database.Datastore, st *stopper break } continue - } else { - lockOwner, lockExpiration, err := datastore.FindLock(updaterLockName) - if err != nil { + lockOwner, lockExpiration, ok, err := findLock(datastore, updaterLockName) + if !ok || err != nil { log.Debug("update lock is already taken") nextUpdate = hasLockUntil } else { @@ -174,40 +183,74 @@ func sleepUpdater(approxWakeup time.Time, st *stopper.Stopper) (stopped bool) { return false } -// update fetches all the vulnerabilities from the registered fetchers, upserts -// them into the database and then sends notifications. +// update fetches all the vulnerabilities from the registered fetchers, updates +// vulnerabilities, and updater flags, and logs notes from updaters. func update(datastore database.Datastore, firstUpdate bool) { defer setUpdaterDuration(time.Now()) log.Info("updating vulnerabilities") // Fetch updates. - status, vulnerabilities, flags, notes := fetch(datastore) + success, vulnerabilities, flags, notes := fetch(datastore) + + // do vulnerability namespacing again to merge potentially duplicated + // vulnerabilities from each updater. + vulnerabilities = doVulnerabilitiesNamespacing(vulnerabilities) + + // deduplicate fetched namespaces and store them into database. 
+ nsMap := map[database.Namespace]struct{}{} + for _, vuln := range vulnerabilities { + nsMap[vuln.Namespace] = struct{}{} + } + + namespaces := make([]database.Namespace, 0, len(nsMap)) + for ns := range nsMap { + namespaces = append(namespaces, ns) + } + + if err := persistNamespaces(datastore, namespaces); err != nil { + log.WithError(err).Error("Unable to insert namespaces") + return + } + + changes, err := updateVulnerabilities(datastore, vulnerabilities) + + defer func() { + if err != nil { + promUpdaterErrorsTotal.Inc() + } + }() - // Insert vulnerabilities. - log.WithField("count", len(vulnerabilities)).Debug("inserting vulnerabilities for update") - err := datastore.InsertVulnerabilities(vulnerabilities, !firstUpdate) if err != nil { - promUpdaterErrorsTotal.Inc() - log.WithError(err).Error("an error occured when inserting vulnerabilities for update") + log.WithError(err).Error("Unable to update vulnerabilities") return } - vulnerabilities = nil - // Update flags. - for flagName, flagValue := range flags { - datastore.InsertKeyValue(flagName, flagValue) + if !firstUpdate { + err = createVulnerabilityNotifications(datastore, changes) + if err != nil { + log.WithError(err).Error("Unable to create notifications") + return + } + } + + err = updateUpdaterFlags(datastore, flags) + if err != nil { + log.WithError(err).Error("Unable to update updater flags") + return } - // Log notes. for _, note := range notes { log.WithField("note", note).Warning("fetcher note") } promUpdaterNotesTotal.Set(float64(len(notes))) - // Update last successful update if every fetchers worked properly. 
- if status { - datastore.InsertKeyValue(updaterLastFlagName, strconv.FormatInt(time.Now().UTC().Unix(), 10)) + if success { + err = setLastUpdateTime(datastore) + if err != nil { + log.WithError(err).Error("Unable to set last update time") + return + } } log.Info("update finished") @@ -218,8 +261,8 @@ func setUpdaterDuration(start time.Time) { } // fetch get data from the registered fetchers, in parallel. -func fetch(datastore database.Datastore) (bool, []database.Vulnerability, map[string]string, []string) { - var vulnerabilities []database.Vulnerability +func fetch(datastore database.Datastore) (bool, []database.VulnerabilityWithAffected, map[string]string, []string) { + var vulnerabilities []database.VulnerabilityWithAffected var notes []string status := true flags := make(map[string]string) @@ -227,12 +270,17 @@ func fetch(datastore database.Datastore) (bool, []database.Vulnerability, map[st // Fetch updates in parallel. log.Info("fetching vulnerability updates") var responseC = make(chan *vulnsrc.UpdateResponse, 0) + numUpdaters := 0 for n, u := range vulnsrc.Updaters() { + if !updaterEnabled(n) { + continue + } + numUpdaters++ go func(name string, u vulnsrc.Updater) { response, err := u.Update(datastore) if err != nil { promUpdaterErrorsTotal.Inc() - log.WithError(err).WithField("updater name", name).Error("an error occured when fetching update") + log.WithError(err).WithField("updater name", name).Error("an error occurred when fetching update") status = false responseC <- nil return @@ -244,7 +292,7 @@ func fetch(datastore database.Datastore) (bool, []database.Vulnerability, map[st } // Collect results of updates. - for i := 0; i < len(vulnsrc.Updaters()); i++ { + for i := 0; i < numUpdaters; i++ { resp := <-responseC if resp != nil { vulnerabilities = append(vulnerabilities, doVulnerabilitiesNamespacing(resp.Vulnerabilities)...) 
@@ -259,9 +307,10 @@ func fetch(datastore database.Datastore) (bool, []database.Vulnerability, map[st return status, addMetadata(datastore, vulnerabilities), flags, notes } -// Add metadata to the specified vulnerabilities using the registered MetadataFetchers, in parallel. -func addMetadata(datastore database.Datastore, vulnerabilities []database.Vulnerability) []database.Vulnerability { - if len(vulnmdsrc.Appenders()) == 0 { +// Add metadata to the specified vulnerabilities using the registered +// MetadataFetchers, in parallel. +func addMetadata(datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) []database.VulnerabilityWithAffected { + if len(vulnmdsrc.Appenders()) == 0 || len(vulnerabilities) == 0 { return vulnerabilities } @@ -272,7 +321,7 @@ func addMetadata(datastore database.Datastore, vulnerabilities []database.Vulner lockableVulnerabilities := make([]*lockableVulnerability, 0, len(vulnerabilities)) for i := 0; i < len(vulnerabilities); i++ { lockableVulnerabilities = append(lockableVulnerabilities, &lockableVulnerability{ - Vulnerability: &vulnerabilities[i], + VulnerabilityWithAffected: &vulnerabilities[i], }) } @@ -286,7 +335,7 @@ func addMetadata(datastore database.Datastore, vulnerabilities []database.Vulner // Build up a metadata cache. 
if err := appender.BuildCache(datastore); err != nil { promUpdaterErrorsTotal.Inc() - log.WithError(err).WithField("appender name", name).Error("an error occured when loading metadata fetcher") + log.WithError(err).WithField("appender name", name).Error("an error occurred when loading metadata fetcher") return } @@ -305,13 +354,21 @@ func addMetadata(datastore database.Datastore, vulnerabilities []database.Vulner return vulnerabilities } -func getLastUpdate(datastore database.Datastore) (time.Time, bool, error) { - lastUpdateTSS, err := datastore.GetKeyValue(updaterLastFlagName) +// GetLastUpdateTime retrieves the latest successful time of update and whether +// or not it's the first update. +func GetLastUpdateTime(datastore database.Datastore) (time.Time, bool, error) { + tx, err := datastore.Begin() + if err != nil { + return time.Time{}, false, err + } + defer tx.Rollback() + + lastUpdateTSS, ok, err := tx.FindKeyValue(updaterLastFlagName) if err != nil { return time.Time{}, false, err } - if lastUpdateTSS == "" { + if !ok { // This is the first update. return time.Time{}, true, nil } @@ -325,7 +382,7 @@ func getLastUpdate(datastore database.Datastore) (time.Time, bool, error) { } type lockableVulnerability struct { - *database.Vulnerability + *database.VulnerabilityWithAffected sync.Mutex } @@ -349,39 +406,293 @@ func (lv *lockableVulnerability) appendFunc(metadataKey string, metadata interfa // doVulnerabilitiesNamespacing takes Vulnerabilities that don't have a // Namespace and split them into multiple vulnerabilities that have a Namespace -// and only contains the FixedIn FeatureVersions corresponding to their +// and only contains the Affected Features corresponding to their // Namespace. // // It helps simplifying the fetchers that share the same metadata about a // Vulnerability regardless of their actual namespace (ie. same vulnerability // information for every version of a distro). 
-func doVulnerabilitiesNamespacing(vulnerabilities []database.Vulnerability) []database.Vulnerability {
-	vulnerabilitiesMap := make(map[string]*database.Vulnerability)
+//
+// It also validates the vulnerabilities fetched from updaters. If any
+// vulnerability is malformed, the updater process will continue but will log a
+// warning.
+func doVulnerabilitiesNamespacing(vulnerabilities []database.VulnerabilityWithAffected) []database.VulnerabilityWithAffected {
+	vulnerabilitiesMap := make(map[string]*database.VulnerabilityWithAffected)
 	for _, v := range vulnerabilities {
-		featureVersions := v.FixedIn
-		v.FixedIn = []database.FeatureVersion{}
-
-		for _, fv := range featureVersions {
-			index := fv.Feature.Namespace.Name + ":" + v.Name
+		namespacedFeatures := v.Affected
+		v.Affected = []database.AffectedFeature{}
+
+		for _, fv := range namespacedFeatures {
+			// validate vulnerabilities, throw out the invalid vulnerabilities
+			if fv.AffectedVersion == "" || fv.FeatureName == "" || fv.Namespace.Name == "" || fv.Namespace.VersionFormat == "" {
+				log.WithFields(log.Fields{
+					"Name":             fv.FeatureName,
+					"Affected Version": fv.AffectedVersion,
+					"Namespace":        fv.Namespace.Name + ":" + fv.Namespace.VersionFormat,
+				}).Warn("Mal-formated affected feature (skipped)")
+				continue
+			}
+			index := fv.Namespace.Name + ":" + v.Name
 			if vulnerability, ok := vulnerabilitiesMap[index]; !ok {
 				newVulnerability := v
-				newVulnerability.Namespace = fv.Feature.Namespace
-				newVulnerability.FixedIn = []database.FeatureVersion{fv}
+				newVulnerability.Namespace = fv.Namespace
+				newVulnerability.Affected = []database.AffectedFeature{fv}
 				vulnerabilitiesMap[index] = &newVulnerability
 			} else {
-				vulnerability.FixedIn = append(vulnerability.FixedIn, fv)
+				vulnerability.Affected = append(vulnerability.Affected, fv)
 			}
 		}
 	}
 
 	// Convert map into a slice.
-	var response []database.Vulnerability
-	for _, vulnerability := range vulnerabilitiesMap {
-		response = append(response, *vulnerability)
+	var response []database.VulnerabilityWithAffected
+	for _, v := range vulnerabilitiesMap {
+		// throw out invalid vulnerabilities.
+		if v.Name == "" || !v.Severity.Valid() || v.Namespace.Name == "" || v.Namespace.VersionFormat == "" {
+			log.WithFields(log.Fields{
+				"Name":      v.Name,
+				"Severity":  v.Severity,
+				"Namespace": v.Namespace.Name + ":" + v.Namespace.VersionFormat,
+			}).Warning("Vulnerability is mal-formatted")
+			continue
+		}
+		response = append(response, *v)
 	}
 
 	return response
 }
+
+func findLock(datastore database.Datastore, updaterLockName string) (string, time.Time, bool, error) {
+	tx, err := datastore.Begin()
+	if err != nil {
+		return "", time.Time{}, false, err
+	}
+	defer tx.Rollback()
+	return tx.FindLock(updaterLockName)
+}
+
+// updateUpdaterFlags updates the flags specified by updaters, every transaction
+// is independent of each other.
+func updateUpdaterFlags(datastore database.Datastore, flags map[string]string) error {
+	for key, value := range flags {
+		tx, err := datastore.Begin()
+		if err != nil {
+			return err
+		}
+		defer tx.Rollback()
+
+		err = tx.UpdateKeyValue(key, value)
+		if err != nil {
+			return err
+		}
+		if err = tx.Commit(); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+// setLastUpdateTime records the last successful date time in database.
+func setLastUpdateTime(datastore database.Datastore) error {
+	tx, err := datastore.Begin()
+	if err != nil {
+		return err
+	}
+	defer tx.Rollback()
+
+	err = tx.UpdateKeyValue(updaterLastFlagName, strconv.FormatInt(time.Now().UTC().Unix(), 10))
+	if err != nil {
+		return err
+	}
+	return tx.Commit()
+}
+
+// isVulnerabilityChanged compares two vulnerabilities by their severity and
+// affected features, and returns true if they are different.
+func isVulnerabilityChanged(a *database.VulnerabilityWithAffected, b *database.VulnerabilityWithAffected) bool { + if a == b { + return false + } else if a != nil && b != nil && a.Severity == b.Severity && len(a.Affected) == len(b.Affected) { + checked := map[string]bool{} + for _, affected := range a.Affected { + checked[affected.Namespace.Name+":"+affected.FeatureName] = false + } + + for _, affected := range b.Affected { + key := affected.Namespace.Name + ":" + affected.FeatureName + if visited, ok := checked[key]; !ok || visited { + return true + } + checked[key] = true + } + return false + } + return true +} + +// findVulnerabilityChanges finds vulnerability changes from old +// vulnerabilities to new vulnerabilities. +// old and new vulnerabilities should be unique. +func findVulnerabilityChanges(old []database.VulnerabilityWithAffected, new []database.VulnerabilityWithAffected) ([]vulnerabilityChange, error) { + changes := map[database.VulnerabilityID]vulnerabilityChange{} + for i, vuln := range old { + key := database.VulnerabilityID{ + Name: vuln.Name, + Namespace: vuln.Namespace.Name, + } + + if _, ok := changes[key]; ok { + return nil, fmt.Errorf("duplicated old vulnerability") + } + changes[key] = vulnerabilityChange{old: &old[i]} + } + + for i, vuln := range new { + key := database.VulnerabilityID{ + Name: vuln.Name, + Namespace: vuln.Namespace.Name, + } + + if change, ok := changes[key]; ok { + if isVulnerabilityChanged(change.old, &vuln) { + change.new = &new[i] + changes[key] = change + } else { + delete(changes, key) + } + } else { + changes[key] = vulnerabilityChange{new: &new[i]} + } + } + + vulnChange := make([]vulnerabilityChange, 0, len(changes)) + for _, change := range changes { + vulnChange = append(vulnChange, change) + } + return vulnChange, nil +} + +// createVulnerabilityNotifications makes notifications out of vulnerability +// changes and insert them into database. 
+func createVulnerabilityNotifications(datastore database.Datastore, changes []vulnerabilityChange) error { + log.WithField("count", len(changes)).Debug("creating vulnerability notifications") + if len(changes) == 0 { + return nil + } + + tx, err := datastore.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + notifications := make([]database.VulnerabilityNotification, 0, len(changes)) + for _, change := range changes { + var oldVuln, newVuln *database.Vulnerability + if change.old != nil { + oldVuln = &change.old.Vulnerability + } + + if change.new != nil { + newVuln = &change.new.Vulnerability + } + + notifications = append(notifications, database.VulnerabilityNotification{ + NotificationHook: database.NotificationHook{ + Name: uuid.New(), + Created: time.Now(), + }, + Old: oldVuln, + New: newVuln, + }) + } + + if err := tx.InsertVulnerabilityNotifications(notifications); err != nil { + return err + } + + return tx.Commit() +} + +// updateVulnerabilities upserts unique vulnerabilities into the database and +// computes vulnerability changes. 
+func updateVulnerabilities(datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) ([]vulnerabilityChange, error) { + log.WithField("count", len(vulnerabilities)).Debug("updating vulnerabilities") + if len(vulnerabilities) == 0 { + return nil, nil + } + + ids := make([]database.VulnerabilityID, 0, len(vulnerabilities)) + for _, vuln := range vulnerabilities { + ids = append(ids, database.VulnerabilityID{ + Name: vuln.Name, + Namespace: vuln.Namespace.Name, + }) + } + + tx, err := datastore.Begin() + if err != nil { + return nil, err + } + + defer tx.Rollback() + oldVulnNullable, err := tx.FindVulnerabilities(ids) + if err != nil { + return nil, err + } + + oldVuln := []database.VulnerabilityWithAffected{} + for _, vuln := range oldVulnNullable { + if vuln.Valid { + oldVuln = append(oldVuln, vuln.VulnerabilityWithAffected) + } + } + + changes, err := findVulnerabilityChanges(oldVuln, vulnerabilities) + if err != nil { + return nil, err + } + + toRemove := []database.VulnerabilityID{} + toAdd := []database.VulnerabilityWithAffected{} + for _, change := range changes { + if change.old != nil { + toRemove = append(toRemove, database.VulnerabilityID{ + Name: change.old.Name, + Namespace: change.old.Namespace.Name, + }) + } + + if change.new != nil { + toAdd = append(toAdd, *change.new) + } + } + + log.WithField("count", len(toRemove)).Debug("marking vulnerabilities as outdated") + if err := tx.DeleteVulnerabilities(toRemove); err != nil { + return nil, err + } + + log.WithField("count", len(toAdd)).Debug("inserting new vulnerabilities") + if err := tx.InsertVulnerabilities(toAdd); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return changes, nil +} + +func updaterEnabled(updaterName string) bool { + for _, u := range EnabledUpdaters { + if u == updaterName { + return true + } + } + return false +} diff --git a/updater_test.go b/updater_test.go index 380ff277..bb2a8e60 100644 --- 
a/updater_test.go +++ b/updater_test.go @@ -15,6 +15,7 @@ package clair import ( + "errors" "fmt" "testing" @@ -23,49 +24,301 @@ import ( "github.com/coreos/clair/database" ) +type mockUpdaterDatastore struct { + database.MockDatastore + + namespaces map[string]database.Namespace + vulnerabilities map[database.VulnerabilityID]database.VulnerabilityWithAffected + vulnNotification map[string]database.VulnerabilityNotification + keyValues map[string]string +} + +type mockUpdaterSession struct { + database.MockSession + + store *mockUpdaterDatastore + copy mockUpdaterDatastore + terminated bool +} + +func copyUpdaterDatastore(md *mockUpdaterDatastore) mockUpdaterDatastore { + namespaces := map[string]database.Namespace{} + for k, n := range md.namespaces { + namespaces[k] = n + } + + vulnerabilities := map[database.VulnerabilityID]database.VulnerabilityWithAffected{} + for key, v := range md.vulnerabilities { + newV := v + affected := []database.AffectedFeature{} + for _, f := range v.Affected { + affected = append(affected, f) + } + newV.Affected = affected + vulnerabilities[key] = newV + } + + vulnNoti := map[string]database.VulnerabilityNotification{} + for key, v := range md.vulnNotification { + vulnNoti[key] = v + } + + kv := map[string]string{} + for key, value := range md.keyValues { + kv[key] = value + } + + return mockUpdaterDatastore{ + namespaces: namespaces, + vulnerabilities: vulnerabilities, + vulnNotification: vulnNoti, + keyValues: kv, + } +} + +func newmockUpdaterDatastore() *mockUpdaterDatastore { + errSessionDone := errors.New("Session Done") + md := &mockUpdaterDatastore{ + namespaces: make(map[string]database.Namespace), + vulnerabilities: make(map[database.VulnerabilityID]database.VulnerabilityWithAffected), + vulnNotification: make(map[string]database.VulnerabilityNotification), + keyValues: make(map[string]string), + } + + md.FctBegin = func() (database.Session, error) { + session := &mockUpdaterSession{ + store: md, + copy: 
copyUpdaterDatastore(md), + terminated: false, + } + + session.FctCommit = func() error { + if session.terminated { + return errSessionDone + } + session.store.namespaces = session.copy.namespaces + session.store.vulnerabilities = session.copy.vulnerabilities + session.store.vulnNotification = session.copy.vulnNotification + session.store.keyValues = session.copy.keyValues + session.terminated = true + return nil + } + + session.FctRollback = func() error { + if session.terminated { + return errSessionDone + } + session.terminated = true + session.copy = mockUpdaterDatastore{} + return nil + } + + session.FctPersistNamespaces = func(ns []database.Namespace) error { + if session.terminated { + return errSessionDone + } + for _, n := range ns { + _, ok := session.copy.namespaces[n.Name] + if !ok { + session.copy.namespaces[n.Name] = n + } + } + return nil + } + + session.FctFindVulnerabilities = func(ids []database.VulnerabilityID) ([]database.NullableVulnerability, error) { + r := []database.NullableVulnerability{} + for _, id := range ids { + vuln, ok := session.copy.vulnerabilities[id] + r = append(r, database.NullableVulnerability{ + VulnerabilityWithAffected: vuln, + Valid: ok, + }) + } + return r, nil + } + + session.FctDeleteVulnerabilities = func(ids []database.VulnerabilityID) error { + for _, id := range ids { + delete(session.copy.vulnerabilities, id) + } + return nil + } + + session.FctInsertVulnerabilities = func(vulnerabilities []database.VulnerabilityWithAffected) error { + for _, vuln := range vulnerabilities { + id := database.VulnerabilityID{ + Name: vuln.Name, + Namespace: vuln.Namespace.Name, + } + if _, ok := session.copy.vulnerabilities[id]; ok { + return errors.New("Vulnerability already exists") + } + session.copy.vulnerabilities[id] = vuln + } + return nil + } + + session.FctUpdateKeyValue = func(key, value string) error { + session.copy.keyValues[key] = value + return nil + } + + session.FctFindKeyValue = func(key string) (string, bool, 
error) { + s, b := session.copy.keyValues[key] + return s, b, nil + } + + session.FctInsertVulnerabilityNotifications = func(notifications []database.VulnerabilityNotification) error { + for _, noti := range notifications { + session.copy.vulnNotification[noti.Name] = noti + } + return nil + } + + return session, nil + } + return md +} + func TestDoVulnerabilitiesNamespacing(t *testing.T) { - fv1 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{Name: "Namespace1"}, - Name: "Feature1", - }, - Version: "0.1", + fv1 := database.AffectedFeature{ + Namespace: database.Namespace{Name: "Namespace1"}, + FeatureName: "Feature1", + FixedInVersion: "0.1", + AffectedVersion: "0.1", } - fv2 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{Name: "Namespace2"}, - Name: "Feature1", - }, - Version: "0.2", + fv2 := database.AffectedFeature{ + Namespace: database.Namespace{Name: "Namespace2"}, + FeatureName: "Feature1", + FixedInVersion: "0.2", + AffectedVersion: "0.2", } - fv3 := database.FeatureVersion{ - Feature: database.Feature{ - Namespace: database.Namespace{Name: "Namespace2"}, - Name: "Feature2", - }, - Version: "0.3", + fv3 := database.AffectedFeature{ + + Namespace: database.Namespace{Name: "Namespace2"}, + FeatureName: "Feature2", + FixedInVersion: "0.3", + AffectedVersion: "0.3", } - vulnerability := database.Vulnerability{ - Name: "DoVulnerabilityNamespacing", - FixedIn: []database.FeatureVersion{fv1, fv2, fv3}, + vulnerability := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "DoVulnerabilityNamespacing", + }, + Affected: []database.AffectedFeature{fv1, fv2, fv3}, } - vulnerabilities := doVulnerabilitiesNamespacing([]database.Vulnerability{vulnerability}) + vulnerabilities := doVulnerabilitiesNamespacing([]database.VulnerabilityWithAffected{vulnerability}) for _, vulnerability := range vulnerabilities { switch vulnerability.Namespace.Name { - case 
fv1.Feature.Namespace.Name: - assert.Len(t, vulnerability.FixedIn, 1) - assert.Contains(t, vulnerability.FixedIn, fv1) - case fv2.Feature.Namespace.Name: - assert.Len(t, vulnerability.FixedIn, 2) - assert.Contains(t, vulnerability.FixedIn, fv2) - assert.Contains(t, vulnerability.FixedIn, fv3) + case fv1.Namespace.Name: + assert.Len(t, vulnerability.Affected, 1) + assert.Contains(t, vulnerability.Affected, fv1) + case fv2.Namespace.Name: + assert.Len(t, vulnerability.Affected, 2) + assert.Contains(t, vulnerability.Affected, fv2) + assert.Contains(t, vulnerability.Affected, fv3) default: t.Errorf("Should not have a Vulnerability with '%s' as its Namespace.", vulnerability.Namespace.Name) fmt.Printf("%#v\n", vulnerability) } } } + +func TestCreatVulnerabilityNotification(t *testing.T) { + vf1 := "VersionFormat1" + ns1 := database.Namespace{ + Name: "namespace 1", + VersionFormat: vf1, + } + af1 := database.AffectedFeature{ + Namespace: ns1, + FeatureName: "feature 1", + } + + v1 := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "vulnerability 1", + Namespace: ns1, + Severity: database.UnknownSeverity, + }, + } + + // severity change + v2 := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "vulnerability 1", + Namespace: ns1, + Severity: database.LowSeverity, + }, + } + + // affected versions change + v3 := database.VulnerabilityWithAffected{ + Vulnerability: database.Vulnerability{ + Name: "vulnerability 1", + Namespace: ns1, + Severity: database.UnknownSeverity, + }, + Affected: []database.AffectedFeature{af1}, + } + + datastore := newmockUpdaterDatastore() + change, err := updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{}) + assert.Nil(t, err) + assert.Len(t, change, 0) + + change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v1}) + assert.Nil(t, err) + assert.Len(t, change, 1) + assert.Nil(t, change[0].old) + assertVulnerability(t, 
*change[0].new, v1) + + change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v1}) + assert.Nil(t, err) + assert.Len(t, change, 0) + + change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v2}) + assert.Nil(t, err) + assert.Len(t, change, 1) + assertVulnerability(t, *change[0].new, v2) + assertVulnerability(t, *change[0].old, v1) + + change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v3}) + assert.Nil(t, err) + assert.Len(t, change, 1) + assertVulnerability(t, *change[0].new, v3) + assertVulnerability(t, *change[0].old, v2) + + err = createVulnerabilityNotifications(datastore, change) + assert.Nil(t, err) + assert.Len(t, datastore.vulnNotification, 1) + for _, noti := range datastore.vulnNotification { + assert.Equal(t, *noti.New, v3.Vulnerability) + assert.Equal(t, *noti.Old, v2.Vulnerability) + } +} + +func assertVulnerability(t *testing.T, expected database.VulnerabilityWithAffected, actual database.VulnerabilityWithAffected) bool { + expectedAF := expected.Affected + actualAF := actual.Affected + expected.Affected, actual.Affected = nil, nil + + assert.Equal(t, expected, actual) + assert.Len(t, actualAF, len(expectedAF)) + + mapAF := map[database.AffectedFeature]bool{} + for _, af := range expectedAF { + mapAF[af] = false + } + + for _, af := range actualAF { + if visited, ok := mapAF[af]; !ok || visited { + return false + } + } + return true +} diff --git a/worker.go b/worker.go index d407253f..a9c82762 100644 --- a/worker.go +++ b/worker.go @@ -15,7 +15,9 @@ package clair import ( + "errors" "regexp" + "sync" log "github.com/sirupsen/logrus" @@ -24,13 +26,10 @@ import ( "github.com/coreos/clair/ext/featurens" "github.com/coreos/clair/ext/imagefmt" "github.com/coreos/clair/pkg/commonerr" - "github.com/coreos/clair/pkg/tarutil" + "github.com/coreos/clair/pkg/strutil" ) const ( - // Version (integer) represents the worker version. 
- // Increased each time the engine changes. - Version = 3 logLayerName = "layer" ) @@ -44,177 +43,525 @@ var ( ErrParentUnknown = commonerr.NewBadRequestError("worker: parent layer is unknown, it must be processed first") urlParametersRegexp = regexp.MustCompile(`(\?|\&)([^=]+)\=([^ &]+)`) + + // Processors contain the names of namespace detectors and feature listers + // enabled in this instance of Clair. + // + // Processors are initialized during booting and configured in the + // configuration file. + Processors database.Processors ) +type WorkerConfig struct { + EnabledDetectors []string `yaml:"namespace_detectors"` + EnabledListers []string `yaml:"feature_listers"` +} + +// LayerRequest represents all information necessary to download and process a +// layer. +type LayerRequest struct { + Hash string + Path string + Headers map[string]string +} + +// partialLayer stores layer's content detected by `processedBy` processors. +type partialLayer struct { + hash string + processedBy database.Processors + namespaces []database.Namespace + features []database.Feature + + err error +} + +// processRequest stores parameters used for processing layers. +type processRequest struct { + request LayerRequest + // notProcessedBy represents a set of processors used to process the + // request. + notProcessedBy database.Processors +} + // cleanURL removes all parameters from an URL. func cleanURL(str string) string { return urlParametersRegexp.ReplaceAllString(str, "") } -// ProcessLayer detects the Namespace of a layer, the features it adds/removes, -// and then stores everything in the database. -// -// TODO(Quentin-M): We could have a goroutine that looks for layers that have -// been analyzed with an older engine version and that processes them. -func ProcessLayer(datastore database.Datastore, imageFormat, name, parentName, path string, headers map[string]string) error { - // Verify parameters. 
- if name == "" { - return commonerr.NewBadRequestError("could not process a layer which does not have a name") +// processLayers in parallel processes a set of requests for unique set of layers +// and returns sets of unique namespaces, features and layers to be inserted +// into the database. +func processRequests(imageFormat string, toDetect []processRequest) ([]database.Namespace, []database.Feature, map[string]partialLayer, error) { + wg := &sync.WaitGroup{} + wg.Add(len(toDetect)) + results := make([]partialLayer, len(toDetect)) + for i := range toDetect { + go func(req *processRequest, res *partialLayer) { + res.hash = req.request.Hash + res.processedBy = req.notProcessedBy + res.namespaces, res.features, res.err = detectContent(imageFormat, req.request.Hash, req.request.Path, req.request.Headers, req.notProcessedBy) + wg.Done() + }(&toDetect[i], &results[i]) } + wg.Wait() + distinctNS := map[database.Namespace]struct{}{} + distinctF := map[database.Feature]struct{}{} - if path == "" { - return commonerr.NewBadRequestError("could not process a layer which does not have a path") + errs := []error{} + for _, r := range results { + errs = append(errs, r.err) } - if imageFormat == "" { - return commonerr.NewBadRequestError("could not process a layer which does not have a format") + if err := commonerr.CombineErrors(errs...); err != nil { + return nil, nil, nil, err } - log.WithFields(log.Fields{logLayerName: name, "path": cleanURL(path), "engine version": Version, "parent layer": parentName, "format": imageFormat}).Debug("processing layer") + updates := map[string]partialLayer{} + for _, r := range results { + for _, ns := range r.namespaces { + distinctNS[ns] = struct{}{} + } - // Check to see if the layer is already in the database. 
- layer, err := datastore.FindLayer(name, false, false) - if err != nil && err != commonerr.ErrNotFound { - return err + for _, f := range r.features { + distinctF[f] = struct{}{} + } + + if _, ok := updates[r.hash]; !ok { + updates[r.hash] = r + } else { + return nil, nil, nil, errors.New("Duplicated updates is not allowed") + } } - if err == commonerr.ErrNotFound { - // New layer case. - layer = database.Layer{Name: name, EngineVersion: Version} + namespaces := make([]database.Namespace, 0, len(distinctNS)) + features := make([]database.Feature, 0, len(distinctF)) - // Retrieve the parent if it has one. - // We need to get it with its Features in order to diff them. - if parentName != "" { - parent, err := datastore.FindLayer(parentName, true, false) - if err != nil && err != commonerr.ErrNotFound { - return err - } - if err == commonerr.ErrNotFound { - log.WithFields(log.Fields{logLayerName: name, "parent layer": parentName}).Warning("the parent layer is unknown. it must be processed first") - return ErrParentUnknown - } - layer.Parent = &parent + for ns := range distinctNS { + namespaces = append(namespaces, ns) + } + + for f := range distinctF { + features = append(features, f) + } + return namespaces, features, updates, nil +} + +func getLayer(datastore database.Datastore, req LayerRequest) (layer database.LayerWithContent, preq *processRequest, err error) { + var ok bool + tx, err := datastore.Begin() + if err != nil { + return + } + defer tx.Rollback() + + layer, ok, err = tx.FindLayerWithContent(req.Hash) + if err != nil { + return + } + + if !ok { + l := database.Layer{Hash: req.Hash} + err = tx.PersistLayer(l) + if err != nil { + return + } + + if err = tx.Commit(); err != nil { + return + } + + layer = database.LayerWithContent{Layer: l} + preq = &processRequest{ + request: req, + notProcessedBy: Processors, } } else { - // The layer is already in the database, check if we need to update it. 
- if layer.EngineVersion >= Version { - log.WithFields(log.Fields{logLayerName: name, "past engine version": layer.EngineVersion, "current engine version": Version}).Debug("layer content has already been processed in the past with older engine. skipping analysis") - return nil + notProcessed := getNotProcessedBy(layer.ProcessedBy) + if !(len(notProcessed.Detectors) == 0 && len(notProcessed.Listers) == 0 && ok) { + preq = &processRequest{ + request: req, + notProcessedBy: notProcessed, + } } - log.WithFields(log.Fields{logLayerName: name, "past engine version": layer.EngineVersion, "current engine version": Version}).Debug("layer content has already been processed in the past with older engine. analyzing again") } + return +} - // Analyze the content. - layer.Namespaces, layer.Features, err = detectContent(imageFormat, name, path, headers, layer.Parent) +// processLayers processes a set of post layer requests, stores layers and +// returns an ordered list of processed layers with detected features and +// namespaces. +func processLayers(datastore database.Datastore, imageFormat string, requests []LayerRequest) ([]database.LayerWithContent, error) { + toDetect := []processRequest{} + layers := map[string]database.LayerWithContent{} + for _, req := range requests { + if _, ok := layers[req.Hash]; ok { + continue + } + layer, preq, err := getLayer(datastore, req) + if err != nil { + return nil, err + } + layers[req.Hash] = layer + if preq != nil { + toDetect = append(toDetect, *preq) + } + } + + namespaces, features, partialRes, err := processRequests(imageFormat, toDetect) + if err != nil { + return nil, err + } + + // Store partial results. 
+ if err := persistNamespaces(datastore, namespaces); err != nil { + return nil, err + } + + if err := persistFeatures(datastore, features); err != nil { + return nil, err + } + + for _, res := range partialRes { + if err := persistPartialLayer(datastore, res); err != nil { + return nil, err + } + } + + // NOTE(Sida): The full layers are computed using partially + // processed layers in current database session. If any other instances of + // Clair are changing some layers in this set of layers, it might generate + // different results especially when the other Clair is with different + // processors. + completeLayers := []database.LayerWithContent{} + for _, req := range requests { + if partialLayer, ok := partialRes[req.Hash]; ok { + completeLayers = append(completeLayers, combineLayers(layers[req.Hash], partialLayer)) + } else { + completeLayers = append(completeLayers, layers[req.Hash]) + } + } + + return completeLayers, nil +} + +func persistPartialLayer(datastore database.Datastore, layer partialLayer) error { + tx, err := datastore.Begin() if err != nil { return err } + defer tx.Rollback() - return datastore.InsertLayer(layer) + if err := tx.PersistLayerContent(layer.hash, layer.namespaces, layer.features, layer.processedBy); err != nil { + return err + } + return tx.Commit() } -// detectContent downloads a layer's archive and extracts its Namespace and -// Features. -func detectContent(imageFormat, name, path string, headers map[string]string, parent *database.Layer) (namespaces []database.Namespace, featureVersions []database.FeatureVersion, err error) { - totalRequiredFiles := append(featurefmt.RequiredFilenames(), featurens.RequiredFilenames()...) 
- files, err := imagefmt.Extract(imageFormat, path, headers, totalRequiredFiles) +func persistFeatures(datastore database.Datastore, features []database.Feature) error { + tx, err := datastore.Begin() if err != nil { - log.WithError(err).WithFields(log.Fields{logLayerName: name, "path": cleanURL(path)}).Error("failed to extract data from path") - return + return err + } + defer tx.Rollback() + + if err := tx.PersistFeatures(features); err != nil { + return err } + return tx.Commit() +} - namespaces, err = detectNamespaces(name, files, parent) +func persistNamespaces(datastore database.Datastore, namespaces []database.Namespace) error { + tx, err := datastore.Begin() if err != nil { - return + return err + } + defer tx.Rollback() + + if err := tx.PersistNamespaces(namespaces); err != nil { + return err } - featureVersions, err = detectFeatureVersions(name, files, namespaces, parent) + return tx.Commit() +} + +// combineLayers merges `layer` and `partial` without duplicated content. +func combineLayers(layer database.LayerWithContent, partial partialLayer) database.LayerWithContent { + mapF := map[database.Feature]struct{}{} + mapNS := map[database.Namespace]struct{}{} + for _, f := range layer.Features { + mapF[f] = struct{}{} + } + for _, ns := range layer.Namespaces { + mapNS[ns] = struct{}{} + } + for _, f := range partial.features { + mapF[f] = struct{}{} + } + for _, ns := range partial.namespaces { + mapNS[ns] = struct{}{} + } + features := make([]database.Feature, 0, len(mapF)) + namespaces := make([]database.Namespace, 0, len(mapNS)) + for f := range mapF { + features = append(features, f) + } + for ns := range mapNS { + namespaces = append(namespaces, ns) + } + + layer.ProcessedBy.Detectors = append(layer.ProcessedBy.Detectors, strutil.CompareStringLists(partial.processedBy.Detectors, layer.ProcessedBy.Detectors)...) 
+ layer.ProcessedBy.Listers = append(layer.ProcessedBy.Listers, strutil.CompareStringLists(partial.processedBy.Listers, layer.ProcessedBy.Listers)...) + return database.LayerWithContent{ + Layer: database.Layer{ + Hash: layer.Hash, + }, + ProcessedBy: layer.ProcessedBy, + Features: features, + Namespaces: namespaces, + } +} + +func isAncestryProcessed(datastore database.Datastore, name string) (bool, error) { + tx, err := datastore.Begin() if err != nil { - return + return false, err + } + defer tx.Rollback() + _, processed, ok, err := tx.FindAncestry(name) + if err != nil { + return false, err + } + if !ok { + return false, nil } - if len(featureVersions) > 0 { - log.WithFields(log.Fields{logLayerName: name, "feature count": len(featureVersions)}).Debug("detected features") + notProcessed := getNotProcessedBy(processed) + return len(notProcessed.Detectors) == 0 && len(notProcessed.Listers) == 0, nil +} + +// ProcessAncestry downloads and scans an ancestry if it's not scanned by all +// enabled processors in this instance of Clair. +func ProcessAncestry(datastore database.Datastore, imageFormat, name string, layerRequest []LayerRequest) error { + var err error + if name == "" { + return commonerr.NewBadRequestError("could not process a layer which does not have a name") } - return + if imageFormat == "" { + return commonerr.NewBadRequestError("could not process a layer which does not have a format") + } + + if ok, err := isAncestryProcessed(datastore, name); ok && err == nil { + log.WithField("ancestry", name).Debug("Ancestry is processed") + return nil + } else if err != nil { + return err + } + + layers, err := processLayers(datastore, imageFormat, layerRequest) + if err != nil { + return err + } + + if !validateProcessors(layers) { + // This error might be triggered because of multiple workers are + // processing the same instance with different processors. 
+ return errors.New("ancestry layers are scanned with different listers and detectors") + } + + return processAncestry(datastore, name, layers) } -// detectNamespaces returns a list of unique namespaces detected in a layer and its ancestry. -func detectNamespaces(name string, files tarutil.FilesMap, parent *database.Layer) (namespaces []database.Namespace, err error) { - nsSet := map[string]*database.Namespace{} - nsCurrent, err := featurens.Detect(files) +func processAncestry(datastore database.Datastore, name string, layers []database.LayerWithContent) error { + ancestryFeatures, err := computeAncestryFeatures(layers) if err != nil { - return + return err } - if parent != nil { - for _, ns := range parent.Namespaces { - // Under assumption that one version format corresponds to one type - // of namespace. - nsSet[ns.VersionFormat] = &ns - log.WithFields(log.Fields{logLayerName: name, "detected namespace": ns.Name, "version format": ns.VersionFormat}).Debug("detected namespace (from parent)") - } + ancestryLayers := make([]database.Layer, 0, len(layers)) + for _, layer := range layers { + ancestryLayers = append(ancestryLayers, layer.Layer) } - for _, ns := range nsCurrent { - nsSet[ns.VersionFormat] = &ns - log.WithFields(log.Fields{logLayerName: name, "detected namespace": ns.Name, "version format": ns.VersionFormat}).Debug("detected namespace") + log.WithFields(log.Fields{ + "ancestry": name, + "number of features": len(ancestryFeatures), + "processed by": Processors, + "number of layers": len(ancestryLayers), + }).Debug("compute ancestry features") + + if err := persistNamespacedFeatures(datastore, ancestryFeatures); err != nil { + return err } - for _, ns := range nsSet { - namespaces = append(namespaces, *ns) + tx, err := datastore.Begin() + if err != nil { + return err } - return + + err = tx.UpsertAncestry(database.Ancestry{Name: name, Layers: ancestryLayers}, ancestryFeatures, Processors) + if err != nil { + tx.Rollback() + return err + } + + err = 
tx.Commit() + if err != nil { + return err + } + return nil +} + +func persistNamespacedFeatures(datastore database.Datastore, features []database.NamespacedFeature) error { + tx, err := datastore.Begin() + if err != nil { + return err + } + + if err := tx.PersistNamespacedFeatures(features); err != nil { + tx.Rollback() + return err + } + + if err := tx.Commit(); err != nil { + return err + } + + tx, err = datastore.Begin() + if err != nil { + return err + } + + if err := tx.CacheAffectedNamespacedFeatures(features); err != nil { + tx.Rollback() + return err + } + + return tx.Commit() } -func detectFeatureVersions(name string, files tarutil.FilesMap, namespaces []database.Namespace, parent *database.Layer) (features []database.FeatureVersion, err error) { - // Build a map of the namespaces for each FeatureVersion in our parent layer. - parentFeatureNamespaces := make(map[string]database.Namespace) - if parent != nil { - for _, parentFeature := range parent.Features { - parentFeatureNamespaces[parentFeature.Feature.Name+":"+parentFeature.Version] = parentFeature.Feature.Namespace +// validateProcessors checks if the layers processed by same set of processors. +func validateProcessors(layers []database.LayerWithContent) bool { + if len(layers) == 0 { + return true + } + detectors := layers[0].ProcessedBy.Detectors + listers := layers[0].ProcessedBy.Listers + + for _, l := range layers[1:] { + if len(strutil.CompareStringLists(detectors, l.ProcessedBy.Detectors)) != 0 || + len(strutil.CompareStringLists(listers, l.ProcessedBy.Listers)) != 0 { + return false } } + return true +} - for _, ns := range namespaces { - // TODO(Quentin-M): We need to pass the parent image to DetectFeatures because it's possible that - // some detectors would need it in order to produce the entire feature list (if they can only - // detect a diff). Also, we should probably pass the detected namespace so detectors could - // make their own decision. 
- detectedFeatures, err := featurefmt.ListFeatures(files, &ns) - if err != nil { - return features, err +// computeAncestryFeatures computes the features in an ancestry based on all +// layers. +func computeAncestryFeatures(ancestryLayers []database.LayerWithContent) ([]database.NamespacedFeature, error) { + // version format -> namespace + namespaces := map[string]database.Namespace{} + // version format -> feature ID -> feature + features := map[string]map[string]database.NamespacedFeature{} + for _, layer := range ancestryLayers { + // At start of the loop, namespaces and features always contain the + // previous layer's result. + for _, ns := range layer.Namespaces { + namespaces[ns.VersionFormat] = ns } - // Ensure that each FeatureVersion has an associated Namespace. - for i, feature := range detectedFeatures { - if feature.Feature.Namespace.Name != "" { - // There is a Namespace associated. - continue + // version format -> feature ID -> feature + currentFeatures := map[string]map[string]database.NamespacedFeature{} + for _, f := range layer.Features { + if ns, ok := namespaces[f.VersionFormat]; ok { + var currentMap map[string]database.NamespacedFeature + if currentMap, ok = currentFeatures[f.VersionFormat]; !ok { + currentFeatures[f.VersionFormat] = make(map[string]database.NamespacedFeature) + currentMap = currentFeatures[f.VersionFormat] + } + + inherited := false + if mapF, ok := features[f.VersionFormat]; ok { + if parentFeature, ok := mapF[f.Name+":"+f.Version]; ok { + currentMap[f.Name+":"+f.Version] = parentFeature + inherited = true + } + } + + if !inherited { + currentMap[f.Name+":"+f.Version] = database.NamespacedFeature{ + Feature: f, + Namespace: ns, + } + } + + } else { + return nil, errors.New("No corresponding version format") } + } - if parentFeatureNamespace, ok := parentFeatureNamespaces[feature.Feature.Name+":"+feature.Version]; ok { - // The FeatureVersion is present in the parent layer; associate - // with their Namespace. 
- // This might cause problem because a package with same feature - // name and version could be different in parent layer's - // namespace and current layer's namespace - detectedFeatures[i].Feature.Namespace = parentFeatureNamespace - continue - } + // NOTE(Sida): we update the feature map in some version format + // only if there's at least one feature with that version format. This + // approach won't differentiate feature file removed vs all detectable + // features removed from that file vs feature file not changed. + // + // One way to differentiate (feature file removed or not changed) vs + // all detectable features removed is to pass in the file status. + for vf, mapF := range currentFeatures { + features[vf] = mapF + } + } - detectedFeatures[i].Feature.Namespace = ns + ancestryFeatures := []database.NamespacedFeature{} + for _, featureMap := range features { + for _, feature := range featureMap { + ancestryFeatures = append(ancestryFeatures, feature) } - features = append(features, detectedFeatures...) + } + return ancestryFeatures, nil +} + +// getNotProcessedBy returns a processors, which contains the detectors and +// listers not in `processedBy` but implemented in the current clair instance. +func getNotProcessedBy(processedBy database.Processors) database.Processors { + notProcessedLister := strutil.CompareStringLists(Processors.Listers, processedBy.Listers) + notProcessedDetector := strutil.CompareStringLists(Processors.Detectors, processedBy.Detectors) + return database.Processors{ + Listers: notProcessedLister, + Detectors: notProcessedDetector, + } +} + +// detectContent downloads a layer and detects all features and namespaces. 
+func detectContent(imageFormat, name, path string, headers map[string]string, toProcess database.Processors) (namespaces []database.Namespace, featureVersions []database.Feature, err error) { + log.WithFields(log.Fields{"Hash": name}).Debug("Process Layer") + totalRequiredFiles := append(featurefmt.RequiredFilenames(toProcess.Listers), featurens.RequiredFilenames(toProcess.Detectors)...) + files, err := imagefmt.Extract(imageFormat, path, headers, totalRequiredFiles) + if err != nil { + log.WithError(err).WithFields(log.Fields{ + logLayerName: name, + "path": cleanURL(path), + }).Error("failed to extract data from path") + return + } + + namespaces, err = featurens.Detect(files, toProcess.Detectors) + if err != nil { + return + } + + if len(featureVersions) > 0 { + log.WithFields(log.Fields{logLayerName: name, "count": len(namespaces)}).Debug("detected layer namespaces") } - // If there are no FeatureVersions, use parent's FeatureVersions if possible. - // TODO(Quentin-M): We eventually want to give the choice to each detectors to use none/some of - // their parent's FeatureVersions. It would be useful for detectors that can't find their entire - // result using one Layer. 
- if len(features) == 0 && parent != nil { - features = parent.Features + featureVersions, err = featurefmt.ListFeatures(files, toProcess.Listers) + if err != nil { + return + } + + if len(featureVersions) > 0 { + log.WithFields(log.Fields{logLayerName: name, "count": len(featureVersions)}).Debug("detected layer features") } return diff --git a/worker_test.go b/worker_test.go index 950c7689..5f6e0ff4 100644 --- a/worker_test.go +++ b/worker_test.go @@ -15,18 +15,23 @@ package clair import ( + "errors" "path/filepath" "runtime" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/coreos/clair/database" + "github.com/coreos/clair/ext/featurefmt" + "github.com/coreos/clair/ext/featurens" "github.com/coreos/clair/ext/versionfmt/dpkg" - "github.com/coreos/clair/pkg/commonerr" + "github.com/coreos/clair/pkg/strutil" // Register the required detectors. _ "github.com/coreos/clair/ext/featurefmt/dpkg" + _ "github.com/coreos/clair/ext/featurefmt/rpm" _ "github.com/coreos/clair/ext/featurens/aptsources" _ "github.com/coreos/clair/ext/featurens/osrelease" _ "github.com/coreos/clair/ext/imagefmt/docker" @@ -34,42 +39,306 @@ import ( type mockDatastore struct { database.MockDatastore - layers map[string]database.Layer + + layers map[string]database.LayerWithContent + ancestry map[string]database.AncestryWithFeatures + namespaces map[string]database.Namespace + features map[string]database.Feature + namespacedFeatures map[string]database.NamespacedFeature } -func newMockDatastore() *mockDatastore { - return &mockDatastore{ - layers: make(map[string]database.Layer), - } +type mockSession struct { + database.MockSession + + store *mockDatastore + copy mockDatastore + terminated bool } -func TestProcessWithDistUpgrade(t *testing.T) { - _, f, _, _ := runtime.Caller(0) - testDataPath := filepath.Join(filepath.Dir(f)) + "/testdata/DistUpgrade/" +func copyDatastore(md *mockDatastore) mockDatastore { + layers := map[string]database.LayerWithContent{} + for k, l := range 
md.layers { + features := append([]database.Feature(nil), l.Features...) + namespaces := append([]database.Namespace(nil), l.Namespaces...) + listers := append([]string(nil), l.ProcessedBy.Listers...) + detectors := append([]string(nil), l.ProcessedBy.Detectors...) + layers[k] = database.LayerWithContent{ + Layer: database.Layer{ + Hash: l.Hash, + }, + ProcessedBy: database.Processors{ + Listers: listers, + Detectors: detectors, + }, + Features: features, + Namespaces: namespaces, + } + } - // Create a mock datastore. - datastore := newMockDatastore() - datastore.FctInsertLayer = func(layer database.Layer) error { - datastore.layers[layer.Name] = layer - return nil + ancestry := map[string]database.AncestryWithFeatures{} + for k, a := range md.ancestry { + nf := append([]database.NamespacedFeature(nil), a.Features...) + l := append([]database.Layer(nil), a.Layers...) + listers := append([]string(nil), a.ProcessedBy.Listers...) + detectors := append([]string(nil), a.ProcessedBy.Detectors...) 
+ ancestry[k] = database.AncestryWithFeatures{ + Ancestry: database.Ancestry{ + Name: a.Name, + Layers: l, + }, + ProcessedBy: database.Processors{ + Detectors: detectors, + Listers: listers, + }, + Features: nf, + } + } + + namespaces := map[string]database.Namespace{} + for k, n := range md.namespaces { + namespaces[k] = n + } + + features := map[string]database.Feature{} + for k, f := range md.features { + features[k] = f + } + + namespacedFeatures := map[string]database.NamespacedFeature{} + for k, f := range md.namespacedFeatures { + namespacedFeatures[k] = f + } + return mockDatastore{ + layers: layers, + ancestry: ancestry, + namespaces: namespaces, + namespacedFeatures: namespacedFeatures, + features: features, + } +} + +func newMockDatastore() *mockDatastore { + errSessionDone := errors.New("Session Done") + md := &mockDatastore{ + layers: make(map[string]database.LayerWithContent), + ancestry: make(map[string]database.AncestryWithFeatures), + namespaces: make(map[string]database.Namespace), + features: make(map[string]database.Feature), + namespacedFeatures: make(map[string]database.NamespacedFeature), } - datastore.FctFindLayer = func(name string, withFeatures, withVulnerabilities bool) (database.Layer, error) { - if layer, exists := datastore.layers[name]; exists { - return layer, nil + + md.FctBegin = func() (database.Session, error) { + session := &mockSession{ + store: md, + copy: copyDatastore(md), + terminated: false, + } + + session.FctCommit = func() error { + if session.terminated { + return nil + } + session.store.layers = session.copy.layers + session.store.ancestry = session.copy.ancestry + session.store.namespaces = session.copy.namespaces + session.store.features = session.copy.features + session.store.namespacedFeatures = session.copy.namespacedFeatures + session.terminated = true + return nil + } + + session.FctRollback = func() error { + if session.terminated { + return nil + } + session.terminated = true + session.copy = mockDatastore{} 
+ return nil + } + + session.FctFindAncestry = func(name string) (database.Ancestry, database.Processors, bool, error) { + processors := database.Processors{} + if session.terminated { + return database.Ancestry{}, processors, false, errSessionDone + } + ancestry, ok := session.copy.ancestry[name] + return ancestry.Ancestry, ancestry.ProcessedBy, ok, nil + } + + session.FctFindLayer = func(name string) (database.Layer, database.Processors, bool, error) { + processors := database.Processors{} + if session.terminated { + return database.Layer{}, processors, false, errSessionDone + } + layer, ok := session.copy.layers[name] + return layer.Layer, layer.ProcessedBy, ok, nil + } + + session.FctFindLayerWithContent = func(name string) (database.LayerWithContent, bool, error) { + if session.terminated { + return database.LayerWithContent{}, false, errSessionDone + } + layer, ok := session.copy.layers[name] + return layer, ok, nil + } + + session.FctPersistLayer = func(layer database.Layer) error { + if session.terminated { + return errSessionDone + } + if _, ok := session.copy.layers[layer.Hash]; !ok { + session.copy.layers[layer.Hash] = database.LayerWithContent{Layer: layer} + } + return nil + } + + session.FctPersistNamespaces = func(ns []database.Namespace) error { + if session.terminated { + return errSessionDone + } + for _, n := range ns { + _, ok := session.copy.namespaces[n.Name] + if !ok { + session.copy.namespaces[n.Name] = n + } + } + return nil + } + + session.FctPersistFeatures = func(fs []database.Feature) error { + if session.terminated { + return errSessionDone + } + for _, f := range fs { + key := FeatureKey(&f) + _, ok := session.copy.features[key] + if !ok { + session.copy.features[key] = f + } + } + return nil } - return database.Layer{}, commonerr.ErrNotFound + + session.FctPersistLayerContent = func(hash string, namespaces []database.Namespace, features []database.Feature, processedBy database.Processors) error { + if session.terminated { + return 
errSessionDone + } + + // update the layer + layer, ok := session.copy.layers[hash] + if !ok { + return errors.New("layer not found") + } + + layerFeatures := map[string]database.Feature{} + layerNamespaces := map[string]database.Namespace{} + for _, f := range layer.Features { + layerFeatures[FeatureKey(&f)] = f + } + for _, n := range layer.Namespaces { + layerNamespaces[n.Name] = n + } + + // ensure that all the namespaces, features are in the database + for _, ns := range namespaces { + if _, ok := session.copy.namespaces[ns.Name]; !ok { + return errors.New("Namespaces should be in the database") + } + if _, ok := layerNamespaces[ns.Name]; !ok { + layer.Namespaces = append(layer.Namespaces, ns) + layerNamespaces[ns.Name] = ns + } + } + + for _, f := range features { + if _, ok := session.copy.features[FeatureKey(&f)]; !ok { + return errors.New("Namespaces should be in the database") + } + if _, ok := layerFeatures[FeatureKey(&f)]; !ok { + layer.Features = append(layer.Features, f) + layerFeatures[FeatureKey(&f)] = f + } + } + + layer.ProcessedBy.Detectors = append(layer.ProcessedBy.Detectors, strutil.CompareStringLists(processedBy.Detectors, layer.ProcessedBy.Detectors)...) + layer.ProcessedBy.Listers = append(layer.ProcessedBy.Listers, strutil.CompareStringLists(processedBy.Listers, layer.ProcessedBy.Listers)...) 
+ + session.copy.layers[hash] = layer + return nil + } + + session.FctUpsertAncestry = func(ancestry database.Ancestry, features []database.NamespacedFeature, processors database.Processors) error { + if session.terminated { + return errSessionDone + } + + // ensure features are in the database + for _, f := range features { + if _, ok := session.copy.namespacedFeatures[NamespacedFeatureKey(&f)]; !ok { + return errors.New("namepsaced feature not in db") + } + } + + ancestryWFeature := database.AncestryWithFeatures{ + Ancestry: ancestry, + Features: features, + ProcessedBy: processors, + } + + session.copy.ancestry[ancestry.Name] = ancestryWFeature + return nil + } + + session.FctPersistNamespacedFeatures = func(namespacedFeatures []database.NamespacedFeature) error { + for i, f := range namespacedFeatures { + session.copy.namespacedFeatures[NamespacedFeatureKey(&f)] = namespacedFeatures[i] + } + return nil + } + + session.FctCacheAffectedNamespacedFeatures = func(namespacedFeatures []database.NamespacedFeature) error { + // The function does nothing because we don't care about the vulnerability cache in worker_test. + return nil + } + + return session, nil } + return md +} - // Create the list of FeatureVersions that should not been upgraded from one layer to another. 
- nonUpgradedFeatureVersions := []database.FeatureVersion{ - {Feature: database.Feature{Name: "libtext-wrapi18n-perl"}, Version: "0.06-7"}, - {Feature: database.Feature{Name: "libtext-charwidth-perl"}, Version: "0.04-7"}, - {Feature: database.Feature{Name: "libtext-iconv-perl"}, Version: "1.7-5"}, - {Feature: database.Feature{Name: "mawk"}, Version: "1.3.3-17"}, - {Feature: database.Feature{Name: "insserv"}, Version: "1.14.0-5"}, - {Feature: database.Feature{Name: "db"}, Version: "5.1.29-5"}, - {Feature: database.Feature{Name: "ustr"}, Version: "1.0.4-3"}, - {Feature: database.Feature{Name: "xz-utils"}, Version: "5.1.1alpha+20120614-2"}, +func TestMain(m *testing.M) { + Processors = database.Processors{ + Listers: featurefmt.ListListers(), + Detectors: featurens.ListDetectors(), + } + m.Run() +} + +func FeatureKey(f *database.Feature) string { + return strings.Join([]string{f.Name, f.VersionFormat, f.Version}, "__") +} + +func NamespacedFeatureKey(f *database.NamespacedFeature) string { + return strings.Join([]string{f.Name, f.Namespace.Name}, "__") +} + +func TestProcessAncestryWithDistUpgrade(t *testing.T) { + // Create the list of Features that should not been upgraded from one layer to another. + nonUpgradedFeatures := []database.Feature{ + {Name: "libtext-wrapi18n-perl", Version: "0.06-7"}, + {Name: "libtext-charwidth-perl", Version: "0.04-7"}, + {Name: "libtext-iconv-perl", Version: "1.7-5"}, + {Name: "mawk", Version: "1.3.3-17"}, + {Name: "insserv", Version: "1.14.0-5"}, + {Name: "db", Version: "5.1.29-5"}, + {Name: "ustr", Version: "1.0.4-3"}, + {Name: "xz-utils", Version: "5.1.1alpha+20120614-2"}, + } + + nonUpgradedMap := map[database.Feature]struct{}{} + for _, f := range nonUpgradedFeatures { + f.VersionFormat = "dpkg" + nonUpgradedMap[f] = struct{}{} } // Process test layers. 
@@ -78,42 +347,294 @@ func TestProcessWithDistUpgrade(t *testing.T) { // wheezy.tar: FROM debian:wheezy // jessie.tar: RUN sed -i "s/precise/trusty/" /etc/apt/sources.list && apt-get update && // apt-get -y dist-upgrade - assert.Nil(t, ProcessLayer(datastore, "Docker", "blank", "", testDataPath+"blank.tar.gz", nil)) - assert.Nil(t, ProcessLayer(datastore, "Docker", "wheezy", "blank", testDataPath+"wheezy.tar.gz", nil)) - assert.Nil(t, ProcessLayer(datastore, "Docker", "jessie", "wheezy", testDataPath+"jessie.tar.gz", nil)) + _, f, _, _ := runtime.Caller(0) + testDataPath := filepath.Join(filepath.Dir(f)) + "/testdata/DistUpgrade/" - // Ensure that the 'wheezy' layer has the expected namespace and features. - wheezy, ok := datastore.layers["wheezy"] - if assert.True(t, ok, "layer 'wheezy' not processed") { - if !assert.Len(t, wheezy.Namespaces, 1) { - return - } - assert.Equal(t, "debian:7", wheezy.Namespaces[0].Name) - assert.Len(t, wheezy.Features, 52) + datastore := newMockDatastore() - for _, nufv := range nonUpgradedFeatureVersions { - nufv.Feature.Namespace.Name = "debian:7" - nufv.Feature.Namespace.VersionFormat = dpkg.ParserName - assert.Contains(t, wheezy.Features, nufv) + layers := []LayerRequest{ + {Hash: "blank", Path: testDataPath + "blank.tar.gz"}, + {Hash: "wheezy", Path: testDataPath + "wheezy.tar.gz"}, + {Hash: "jessie", Path: testDataPath + "jessie.tar.gz"}, + } + + assert.Nil(t, ProcessAncestry(datastore, "Docker", "Mock", layers)) + // check the ancestry features + assert.Len(t, datastore.ancestry["Mock"].Features, 74) + for _, f := range datastore.ancestry["Mock"].Features { + if _, ok := nonUpgradedMap[f.Feature]; ok { + assert.Equal(t, "debian:7", f.Namespace.Name) + } else { + assert.Equal(t, "debian:8", f.Namespace.Name) } } - // Ensure that the 'wheezy' layer has the expected namespace and non-upgraded features. 
- jessie, ok := datastore.layers["jessie"] - if assert.True(t, ok, "layer 'jessie' not processed") { - assert.Len(t, jessie.Namespaces, 1) - assert.Equal(t, "debian:8", jessie.Namespaces[0].Name) + assert.Equal(t, []database.Layer{ + {Hash: "blank"}, + {Hash: "wheezy"}, + {Hash: "jessie"}, + }, datastore.ancestry["Mock"].Layers) +} + +func TestProcessLayers(t *testing.T) { + _, f, _, _ := runtime.Caller(0) + testDataPath := filepath.Join(filepath.Dir(f)) + "/testdata/DistUpgrade/" + + datastore := newMockDatastore() + + layers := []LayerRequest{ + {Hash: "blank", Path: testDataPath + "blank.tar.gz"}, + {Hash: "wheezy", Path: testDataPath + "wheezy.tar.gz"}, + {Hash: "jessie", Path: testDataPath + "jessie.tar.gz"}, + } + + processedLayers, err := processLayers(datastore, "Docker", layers) + assert.Nil(t, err) + assert.Len(t, processedLayers, 3) + // ensure resubmit won't break the stuff + processedLayers, err = processLayers(datastore, "Docker", layers) + assert.Nil(t, err) + assert.Len(t, processedLayers, 3) + // Ensure each processed layer is correct + assert.Len(t, processedLayers[0].Namespaces, 0) + assert.Len(t, processedLayers[1].Namespaces, 1) + assert.Len(t, processedLayers[2].Namespaces, 1) + assert.Len(t, processedLayers[0].Features, 0) + assert.Len(t, processedLayers[1].Features, 52) + assert.Len(t, processedLayers[2].Features, 74) + + // Ensure each layer has expected namespaces and features detected + if blank, ok := datastore.layers["blank"]; ok { + assert.Equal(t, blank.ProcessedBy.Detectors, Processors.Detectors) + assert.Equal(t, blank.ProcessedBy.Listers, Processors.Listers) + assert.Len(t, blank.Namespaces, 0) + assert.Len(t, blank.Features, 0) + } else { + assert.Fail(t, "blank is not stored") + return + } + + if wheezy, ok := datastore.layers["wheezy"]; ok { + assert.Equal(t, wheezy.ProcessedBy.Detectors, Processors.Detectors) + assert.Equal(t, wheezy.ProcessedBy.Listers, Processors.Listers) + assert.Equal(t, wheezy.Namespaces, 
[]database.Namespace{{Name: "debian:7", VersionFormat: dpkg.ParserName}}) + assert.Len(t, wheezy.Features, 52) + } else { + assert.Fail(t, "wheezy is not stored") + return + } + + if jessie, ok := datastore.layers["jessie"]; ok { + assert.Equal(t, jessie.ProcessedBy.Detectors, Processors.Detectors) + assert.Equal(t, jessie.ProcessedBy.Listers, Processors.Listers) + assert.Equal(t, jessie.Namespaces, []database.Namespace{{Name: "debian:8", VersionFormat: dpkg.ParserName}}) assert.Len(t, jessie.Features, 74) + } else { + assert.Fail(t, "jessie is not stored") + return + } +} + +// TestUpgradeClair checks if a clair is upgraded and certain ancestry's +// features should not change. We assume that Clair should only upgrade +func TestClairUpgrade(t *testing.T) { + _, f, _, _ := runtime.Caller(0) + testDataPath := filepath.Join(filepath.Dir(f)) + "/testdata/DistUpgrade/" + + datastore := newMockDatastore() + + // suppose there are two ancestries. + layers := []LayerRequest{ + {Hash: "blank", Path: testDataPath + "blank.tar.gz"}, + {Hash: "wheezy", Path: testDataPath + "wheezy.tar.gz"}, + {Hash: "jessie", Path: testDataPath + "jessie.tar.gz"}, + } + + layers2 := []LayerRequest{ + {Hash: "blank", Path: testDataPath + "blank.tar.gz"}, + {Hash: "wheezy", Path: testDataPath + "wheezy.tar.gz"}, + } + + // Suppose user scan an ancestry with an old instance of Clair. + Processors = database.Processors{ + Detectors: []string{"os-release"}, + Listers: []string{"rpm"}, + } + + assert.Nil(t, ProcessAncestry(datastore, "Docker", "Mock", layers)) + assert.Len(t, datastore.ancestry["Mock"].Features, 0) + + assert.Nil(t, ProcessAncestry(datastore, "Docker", "Mock2", layers2)) + assert.Len(t, datastore.ancestry["Mock2"].Features, 0) + + // Clair is upgraded to use a new namespace detector. The expected + // behavior is that all layers will be rescanned with "apt-sources" and + // the ancestry's features are recalculated. 
+ Processors = database.Processors{ + Detectors: []string{"os-release", "apt-sources"}, + Listers: []string{"rpm"}, + } + + // Even though Clair processors are upgraded, the ancestry's features should + // not be upgraded without posting the ancestry to Clair again. + assert.Nil(t, ProcessAncestry(datastore, "Docker", "Mock", layers)) + assert.Len(t, datastore.ancestry["Mock"].Features, 0) + + // Clair is upgraded to use a new feature lister. The expected behavior is + // that all layers will be rescanned with "dpkg" and the ancestry's features + // are invalidated and recalculated. + Processors = database.Processors{ + Detectors: []string{"os-release", "apt-sources"}, + Listers: []string{"rpm", "dpkg"}, + } + + assert.Nil(t, ProcessAncestry(datastore, "Docker", "Mock", layers)) + assert.Len(t, datastore.ancestry["Mock"].Features, 74) + assert.Nil(t, ProcessAncestry(datastore, "Docker", "Mock2", layers2)) + assert.Len(t, datastore.ancestry["Mock2"].Features, 52) + + // check the namespaces are correct + for _, f := range datastore.ancestry["Mock"].Features { + if !assert.NotEqual(t, database.Namespace{}, f.Namespace) { + assert.Fail(t, "Every feature should have a namespace attached") + } + } - for _, nufv := range nonUpgradedFeatureVersions { - nufv.Feature.Namespace.Name = "debian:7" - nufv.Feature.Namespace.VersionFormat = dpkg.ParserName - assert.Contains(t, jessie.Features, nufv) + for _, f := range datastore.ancestry["Mock2"].Features { + if !assert.NotEqual(t, database.Namespace{}, f.Namespace) { + assert.Fail(t, "Every feature should have a namespace attached") } - for _, nufv := range nonUpgradedFeatureVersions { - nufv.Feature.Namespace.Name = "debian:8" - nufv.Feature.Namespace.VersionFormat = dpkg.ParserName - assert.NotContains(t, jessie.Features, nufv) + } +} + +// TestMultipleNamespaces tests computing ancestry features +func TestComputeAncestryFeatures(t *testing.T) { + vf1 := "format 1" + vf2 := "format 2" + + ns1a := database.Namespace{ + Name: 
"namespace 1:a", + VersionFormat: vf1, + } + + ns1b := database.Namespace{ + Name: "namespace 1:b", + VersionFormat: vf1, + } + + ns2a := database.Namespace{ + Name: "namespace 2:a", + VersionFormat: vf2, + } + + ns2b := database.Namespace{ + Name: "namespace 2:b", + VersionFormat: vf2, + } + + f1 := database.Feature{ + Name: "feature 1", + Version: "0.1", + VersionFormat: vf1, + } + + f2 := database.Feature{ + Name: "feature 2", + Version: "0.2", + VersionFormat: vf1, + } + + f3 := database.Feature{ + Name: "feature 1", + Version: "0.3", + VersionFormat: vf2, + } + + f4 := database.Feature{ + Name: "feature 2", + Version: "0.3", + VersionFormat: vf2, + } + + // Suppose Clair is watching two files for namespaces one containing ns1 + // changes e.g. os-release and the other one containing ns2 changes e.g. + // node. + blank := database.LayerWithContent{Layer: database.Layer{Hash: "blank"}} + initNS1a := database.LayerWithContent{ + Layer: database.Layer{Hash: "init ns1a"}, + Namespaces: []database.Namespace{ns1a}, + Features: []database.Feature{f1, f2}, + } + + upgradeNS2b := database.LayerWithContent{ + Layer: database.Layer{Hash: "upgrade ns2b"}, + Namespaces: []database.Namespace{ns2b}, + } + + upgradeNS1b := database.LayerWithContent{ + Layer: database.Layer{Hash: "upgrade ns1b"}, + Namespaces: []database.Namespace{ns1b}, + Features: []database.Feature{f1, f2}, + } + + initNS2a := database.LayerWithContent{ + Layer: database.Layer{Hash: "init ns2a"}, + Namespaces: []database.Namespace{ns2a}, + Features: []database.Feature{f3, f4}, + } + + removeF2 := database.LayerWithContent{ + Layer: database.Layer{Hash: "remove f2"}, + Features: []database.Feature{f1}, + } + + // blank -> ns1:a, f1 f2 (init) + // -> f1 (feature change) + // -> ns2:a, f3, f4 (init ns2a) + // -> ns2:b (ns2 upgrade without changing features) + // -> blank (empty) + // -> ns1:b, f1 f2 (ns1 upgrade and add f2) + // -> f1 (remove f2) + // -> blank (empty) + + layers := []database.LayerWithContent{ 
+ blank, + initNS1a, + removeF2, + initNS2a, + upgradeNS2b, + blank, + upgradeNS1b, + removeF2, + blank, + } + + expected := map[database.NamespacedFeature]bool{ + { + Feature: f1, + Namespace: ns1a, + }: false, + { + Feature: f3, + Namespace: ns2a, + }: false, + { + Feature: f4, + Namespace: ns2a, + }: false, + } + + features, err := computeAncestryFeatures(layers) + assert.Nil(t, err) + for _, f := range features { + if assert.Contains(t, expected, f) { + if assert.False(t, expected[f]) { + expected[f] = true + } } } + + for f, visited := range expected { + assert.True(t, visited, "expected feature is missing : "+f.Namespace.Name+":"+f.Name) + } }