diff --git a/updater.go b/updater.go index 23e46ddd..9ef710b2 100644 --- a/updater.go +++ b/updater.go @@ -255,7 +255,6 @@ var errReceivedStopSignal = errors.New("stopped") func updateWhileRenewingLock(datastore database.Datastore, whoAmI string, isFirstUpdate bool, st *stopper.Stopper) (sleepDuration time.Duration, err error) { g, ctx := errgroup.WithContext(context.Background()) g.Go(func() error { - // todo handle ctx return update(ctx, datastore, isFirstUpdate) }) @@ -271,7 +270,7 @@ func updateWhileRenewingLock(datastore database.Datastore, whoAmI string, isFirs refreshDuration = time.Until(lockExpiration) case <-ctx.Done(): database.ReleaseLock(datastore, updaterLockName, whoAmI) - return nil + return ctx.Err() } } }) @@ -281,7 +280,7 @@ func updateWhileRenewingLock(datastore database.Datastore, whoAmI string, isFirs case <-st.Chan(): return errReceivedStopSignal case <-ctx.Done(): - return nil + return ctx.Err() } }) @@ -297,7 +296,7 @@ func update(ctx context.Context, datastore database.Datastore, firstUpdate bool) log.Info("updating vulnerabilities") // Fetch updates. - success, vulnerabilities, flags, notes := fetch(datastore) + success, vulnerabilities, flags, notes := fetchUpdates(ctx, datastore) // do vulnerability namespacing again to merge potentially duplicated // vulnerabilities from each updater. @@ -319,7 +318,7 @@ func update(ctx context.Context, datastore database.Datastore, firstUpdate bool) return err } - changes, err := updateVulnerabilities(datastore, vulnerabilities) + changes, err := updateVulnerabilities(ctx, datastore, vulnerabilities) defer func() { if err != nil { @@ -367,56 +366,60 @@ func setUpdaterDuration(start time.Time) { promUpdaterDurationSeconds.Set(time.Since(start).Seconds()) } -// fetch get data from the registered fetchers, in parallel. -func fetch(datastore database.Datastore) (bool, []database.VulnerabilityWithAffected, map[string]string, []string) { - var vulnerabilities []database.VulnerabilityWithAffected - var notes []string - status := true - flags := make(map[string]string) +// fetchUpdates asynchronously runs all of the enabled Updaters, aggregates +// their results, and appends metadata to the vulnerabilities found. +func fetchUpdates(ctx context.Context, datastore database.Datastore) (status bool, vulns []database.VulnerabilityWithAffected, flags map[string]string, notes []string) { + flags = make(map[string]string) - // Fetch updates in parallel. log.Info("fetching vulnerability updates") - var responseC = make(chan *vulnsrc.UpdateResponse, 0) - numUpdaters := 0 - for n, u := range vulnsrc.Updaters() { - if !updaterEnabled(n) { - continue - } - numUpdaters++ - go func(name string, u vulnsrc.Updater) { - response, err := u.Update(datastore) + + var mu sync.RWMutex + g, ctx := errgroup.WithContext(ctx) + for updaterName, updater := range vulnsrc.Updaters() { + // Shadow the loop variables to avoid closing over the wrong thing. + // See: https://golang.org/doc/faq#closures_and_goroutines + updaterName := updaterName + updater := updater + + g.Go(func() error { + if !updaterEnabled(updaterName) { + return nil + } + + // TODO(jzelinskie): add context to Update() + response, err := updater.Update(datastore) if err != nil { promUpdaterErrorsTotal.Inc() - log.WithError(err).WithField("updater name", name).Error("an error occurred when fetching update") - status = false - responseC <- nil - return + log.WithError(err).WithField("updater", updaterName).Error("an error occurred when fetching an update") + return err } - responseC <- &response - log.WithField("updater name", name).Info("finished fetching") - }(n, u) - } + namespacedVulns := doVulnerabilitiesNamespacing(response.Vulnerabilities) - // Collect results of updates. - for i := 0; i < numUpdaters; i++ { - resp := <-responseC - if resp != nil { - vulnerabilities = append(vulnerabilities, doVulnerabilitiesNamespacing(resp.Vulnerabilities)...) - notes = append(notes, resp.Notes...) - if resp.FlagName != "" && resp.FlagValue != "" { - flags[resp.FlagName] = resp.FlagValue + mu.Lock() + vulns = append(vulns, namespacedVulns...) + notes = append(notes, response.Notes...) + if response.FlagName != "" && response.FlagValue != "" { + flags[response.FlagName] = response.FlagValue } - } + mu.Unlock() + + return nil + }) + } + + if err := g.Wait(); err == nil { + status = true } - close(responseC) - return status, addMetadata(datastore, vulnerabilities), flags, notes + vulns = addMetadata(ctx, datastore, vulns) + + return } -// Add metadata to the specified vulnerabilities using the registered -// MetadataFetchers, in parallel. -func addMetadata(datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) []database.VulnerabilityWithAffected { +// addMetadata asynchronously updates a list of vulnerabilities with metadata +// from the vulnerability metadata sources. +func addMetadata(ctx context.Context, datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) []database.VulnerabilityWithAffected { if len(vulnmdsrc.Appenders()) == 0 || len(vulnerabilities) == 0 { return vulnerabilities } @@ -432,31 +435,39 @@ func addMetadata(datastore database.Datastore, vulnerabilities []database.Vulner }) } - var wg sync.WaitGroup - wg.Add(len(vulnmdsrc.Appenders())) + g, ctx := errgroup.WithContext(ctx) + for name, metadataAppender := range vulnmdsrc.Appenders() { + // Shadow the loop variables to avoid closing over the wrong thing. + // See: https://golang.org/doc/faq#closures_and_goroutines + name := name + metadataAppender := metadataAppender - for n, a := range vulnmdsrc.Appenders() { - go func(name string, appender vulnmdsrc.Appender) { - defer wg.Done() - - // Build up a metadata cache. - if err := appender.BuildCache(datastore); err != nil { + g.Go(func() error { + // TODO(jzelinskie): add ctx to BuildCache() + if err := metadataAppender.BuildCache(datastore); err != nil { promUpdaterErrorsTotal.Inc() - log.WithError(err).WithField("appender name", name).Error("an error occurred when loading metadata fetcher") - return + log.WithError(err).WithField("appender", name).Error("an error occurred when fetching vulnerability metadata") + return err } + defer metadataAppender.PurgeCache() + + for i, vulnerability := range lockableVulnerabilities { + metadataAppender.Append(vulnerability.Name, vulnerability.appendFunc) - // Append vulnerability metadata to each vulnerability. - for _, vulnerability := range lockableVulnerabilities { - appender.Append(vulnerability.Name, vulnerability.appendFunc) + if i%10 == 0 { + select { + case <-ctx.Done(): + return nil + default: + } + } } - // Purge the metadata cache. - appender.PurgeCache() - }(n, a) + return nil + }) } - wg.Wait() + g.Wait() return vulnerabilities } @@ -717,7 +728,7 @@ func createVulnerabilityNotifications(datastore database.Datastore, changes []vu // updateVulnerabilities upserts unique vulnerabilities into the database and // computes vulnerability changes. -func updateVulnerabilities(datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) ([]vulnerabilityChange, error) { +func updateVulnerabilities(ctx context.Context, datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) ([]vulnerabilityChange, error) { log.WithField("count", len(vulnerabilities)).Debug("updating vulnerabilities") if len(vulnerabilities) == 0 { return nil, nil @@ -735,13 +746,19 @@ func updateVulnerabilities(datastore database.Datastore, vulnerabilities []datab if err != nil { return nil, err } - defer tx.Rollback() + oldVulnNullable, err := tx.FindVulnerabilities(ids) if err != nil { return nil, err } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + oldVuln := []database.VulnerabilityWithAffected{} for _, vuln := range oldVulnNullable { if vuln.Valid { @@ -754,6 +771,12 @@ func updateVulnerabilities(datastore database.Datastore, vulnerabilities []datab return nil, err } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + toRemove := []database.VulnerabilityID{} toAdd := []database.VulnerabilityWithAffected{} for _, change := range changes { diff --git a/updater_test.go b/updater_test.go index 93ad2ecd..c3f58dd1 100644 --- a/updater_test.go +++ b/updater_test.go @@ -1,4 +1,4 @@ -// Copyright 2017 clair authors +// Copyright 2019 clair authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -15,13 +15,13 @@ package clair import ( + "context" "errors" "fmt" "testing" - "github.com/stretchr/testify/assert" - "github.com/coreos/clair/database" + "github.com/stretchr/testify/assert" ) type mockUpdaterDatastore struct { @@ -270,27 +270,27 @@ func TestCreatVulnerabilityNotification(t *testing.T) { } datastore := newmockUpdaterDatastore() - change, err := updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{}) + change, err := updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{}) assert.Nil(t, err) assert.Len(t, change, 0) - change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v1}) + change, err = updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{v1}) assert.Nil(t, err) assert.Len(t, change, 1) assert.Nil(t, change[0].old) assertVulnerability(t, *change[0].new, v1) - change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v1}) + change, err = updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{v1}) assert.Nil(t, err) assert.Len(t, change, 0) - change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v2}) + change, err = updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{v2}) assert.Nil(t, err) assert.Len(t, change, 1) assertVulnerability(t, *change[0].new, v2) assertVulnerability(t, *change[0].old, v1) - change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v3}) + change, err = updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{v3}) assert.Nil(t, err) assert.Len(t, change, 1) assertVulnerability(t, *change[0].new, v3)