updater: reimplement fetch() with errgroup

This adds context support to a few more functions in the update process.
This makes progress towards to goal of having cancellable updates.
This commit is contained in:
Jimmy Zelinskie 2019-01-09 16:04:51 -05:00
parent 6c5be7e1c6
commit 0d41968acd
2 changed files with 95 additions and 72 deletions

View File

@ -255,7 +255,6 @@ var errReceivedStopSignal = errors.New("stopped")
func updateWhileRenewingLock(datastore database.Datastore, whoAmI string, isFirstUpdate bool, st *stopper.Stopper) (sleepDuration time.Duration, err error) { func updateWhileRenewingLock(datastore database.Datastore, whoAmI string, isFirstUpdate bool, st *stopper.Stopper) (sleepDuration time.Duration, err error) {
g, ctx := errgroup.WithContext(context.Background()) g, ctx := errgroup.WithContext(context.Background())
g.Go(func() error { g.Go(func() error {
// todo handle ctx
return update(ctx, datastore, isFirstUpdate) return update(ctx, datastore, isFirstUpdate)
}) })
@ -271,7 +270,7 @@ func updateWhileRenewingLock(datastore database.Datastore, whoAmI string, isFirs
refreshDuration = time.Until(lockExpiration) refreshDuration = time.Until(lockExpiration)
case <-ctx.Done(): case <-ctx.Done():
database.ReleaseLock(datastore, updaterLockName, whoAmI) database.ReleaseLock(datastore, updaterLockName, whoAmI)
return nil return ctx.Err()
} }
} }
}) })
@ -281,7 +280,7 @@ func updateWhileRenewingLock(datastore database.Datastore, whoAmI string, isFirs
case <-st.Chan(): case <-st.Chan():
return errReceivedStopSignal return errReceivedStopSignal
case <-ctx.Done(): case <-ctx.Done():
return nil return ctx.Err()
} }
}) })
@ -297,7 +296,7 @@ func update(ctx context.Context, datastore database.Datastore, firstUpdate bool)
log.Info("updating vulnerabilities") log.Info("updating vulnerabilities")
// Fetch updates. // Fetch updates.
success, vulnerabilities, flags, notes := fetch(datastore) success, vulnerabilities, flags, notes := fetchUpdates(ctx, datastore)
// do vulnerability namespacing again to merge potentially duplicated // do vulnerability namespacing again to merge potentially duplicated
// vulnerabilities from each updater. // vulnerabilities from each updater.
@ -319,7 +318,7 @@ func update(ctx context.Context, datastore database.Datastore, firstUpdate bool)
return err return err
} }
changes, err := updateVulnerabilities(datastore, vulnerabilities) changes, err := updateVulnerabilities(ctx, datastore, vulnerabilities)
defer func() { defer func() {
if err != nil { if err != nil {
@ -367,56 +366,60 @@ func setUpdaterDuration(start time.Time) {
promUpdaterDurationSeconds.Set(time.Since(start).Seconds()) promUpdaterDurationSeconds.Set(time.Since(start).Seconds())
} }
// fetch get data from the registered fetchers, in parallel. // fetchUpdates asynchronously runs all of the enabled Updaters, aggregates
func fetch(datastore database.Datastore) (bool, []database.VulnerabilityWithAffected, map[string]string, []string) { // their results, and appends metadata to the vulnerabilities found.
var vulnerabilities []database.VulnerabilityWithAffected func fetchUpdates(ctx context.Context, datastore database.Datastore) (status bool, vulns []database.VulnerabilityWithAffected, flags map[string]string, notes []string) {
var notes []string flags = make(map[string]string)
status := true
flags := make(map[string]string)
// Fetch updates in parallel.
log.Info("fetching vulnerability updates") log.Info("fetching vulnerability updates")
var responseC = make(chan *vulnsrc.UpdateResponse, 0)
numUpdaters := 0 var mu sync.RWMutex
for n, u := range vulnsrc.Updaters() { g, ctx := errgroup.WithContext(ctx)
if !updaterEnabled(n) { for updaterName, updater := range vulnsrc.Updaters() {
continue // Shadow the loop variables to avoid closing over the wrong thing.
} // See: https://golang.org/doc/faq#closures_and_goroutines
numUpdaters++ updaterName := updaterName
go func(name string, u vulnsrc.Updater) { updater := updater
response, err := u.Update(datastore)
g.Go(func() error {
if !updaterEnabled(updaterName) {
return nil
}
// TODO(jzelinskie): add context to Update()
response, err := updater.Update(datastore)
if err != nil { if err != nil {
promUpdaterErrorsTotal.Inc() promUpdaterErrorsTotal.Inc()
log.WithError(err).WithField("updater name", name).Error("an error occurred when fetching update") log.WithError(err).WithField("updater", updaterName).Error("an error occurred when fetching an update")
status = false return err
responseC <- nil
return
} }
responseC <- &response namespacedVulns := doVulnerabilitiesNamespacing(response.Vulnerabilities)
log.WithField("updater name", name).Info("finished fetching")
}(n, u)
}
// Collect results of updates. mu.Lock()
for i := 0; i < numUpdaters; i++ { vulns = append(vulns, namespacedVulns...)
resp := <-responseC notes = append(notes, response.Notes...)
if resp != nil { if response.FlagName != "" && response.FlagValue != "" {
vulnerabilities = append(vulnerabilities, doVulnerabilitiesNamespacing(resp.Vulnerabilities)...) flags[response.FlagName] = response.FlagValue
notes = append(notes, resp.Notes...)
if resp.FlagName != "" && resp.FlagValue != "" {
flags[resp.FlagName] = resp.FlagValue
} }
} mu.Unlock()
return nil
})
} }
close(responseC) if err := g.Wait(); err == nil {
return status, addMetadata(datastore, vulnerabilities), flags, notes status = true
}
vulns = addMetadata(ctx, datastore, vulns)
return
} }
// Add metadata to the specified vulnerabilities using the registered // addMetadata asynchronously updates a list of vulnerabilities with metadata
// MetadataFetchers, in parallel. // from the vulnerability metadata sources.
func addMetadata(datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) []database.VulnerabilityWithAffected { func addMetadata(ctx context.Context, datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) []database.VulnerabilityWithAffected {
if len(vulnmdsrc.Appenders()) == 0 || len(vulnerabilities) == 0 { if len(vulnmdsrc.Appenders()) == 0 || len(vulnerabilities) == 0 {
return vulnerabilities return vulnerabilities
} }
@ -432,31 +435,39 @@ func addMetadata(datastore database.Datastore, vulnerabilities []database.Vulner
}) })
} }
var wg sync.WaitGroup g, ctx := errgroup.WithContext(ctx)
wg.Add(len(vulnmdsrc.Appenders())) for name, metadataAppender := range vulnmdsrc.Appenders() {
// Shadow the loop variables to avoid closing over the wrong thing.
// See: https://golang.org/doc/faq#closures_and_goroutines
name := name
metadataAppender := metadataAppender
for n, a := range vulnmdsrc.Appenders() { g.Go(func() error {
go func(name string, appender vulnmdsrc.Appender) { // TODO(jzelinskie): add ctx to BuildCache()
defer wg.Done() if err := metadataAppender.BuildCache(datastore); err != nil {
// Build up a metadata cache.
if err := appender.BuildCache(datastore); err != nil {
promUpdaterErrorsTotal.Inc() promUpdaterErrorsTotal.Inc()
log.WithError(err).WithField("appender name", name).Error("an error occurred when loading metadata fetcher") log.WithError(err).WithField("appender", name).Error("an error occurred when fetching vulnerability metadata")
return return err
}
defer metadataAppender.PurgeCache()
for i, vulnerability := range lockableVulnerabilities {
metadataAppender.Append(vulnerability.Name, vulnerability.appendFunc)
if i%10 == 0 {
select {
case <-ctx.Done():
return nil
default:
}
}
} }
// Append vulnerability metadata to each vulnerability. return nil
for _, vulnerability := range lockableVulnerabilities { })
appender.Append(vulnerability.Name, vulnerability.appendFunc)
}
// Purge the metadata cache.
appender.PurgeCache()
}(n, a)
} }
wg.Wait() g.Wait()
return vulnerabilities return vulnerabilities
} }
@ -717,7 +728,7 @@ func createVulnerabilityNotifications(datastore database.Datastore, changes []vu
// updateVulnerabilities upserts unique vulnerabilities into the database and // updateVulnerabilities upserts unique vulnerabilities into the database and
// computes vulnerability changes. // computes vulnerability changes.
func updateVulnerabilities(datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) ([]vulnerabilityChange, error) { func updateVulnerabilities(ctx context.Context, datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) ([]vulnerabilityChange, error) {
log.WithField("count", len(vulnerabilities)).Debug("updating vulnerabilities") log.WithField("count", len(vulnerabilities)).Debug("updating vulnerabilities")
if len(vulnerabilities) == 0 { if len(vulnerabilities) == 0 {
return nil, nil return nil, nil
@ -735,13 +746,19 @@ func updateVulnerabilities(datastore database.Datastore, vulnerabilities []datab
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
oldVulnNullable, err := tx.FindVulnerabilities(ids) oldVulnNullable, err := tx.FindVulnerabilities(ids)
if err != nil { if err != nil {
return nil, err return nil, err
} }
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
oldVuln := []database.VulnerabilityWithAffected{} oldVuln := []database.VulnerabilityWithAffected{}
for _, vuln := range oldVulnNullable { for _, vuln := range oldVulnNullable {
if vuln.Valid { if vuln.Valid {
@ -754,6 +771,12 @@ func updateVulnerabilities(datastore database.Datastore, vulnerabilities []datab
return nil, err return nil, err
} }
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
toRemove := []database.VulnerabilityID{} toRemove := []database.VulnerabilityID{}
toAdd := []database.VulnerabilityWithAffected{} toAdd := []database.VulnerabilityWithAffected{}
for _, change := range changes { for _, change := range changes {

View File

@ -1,4 +1,4 @@
// Copyright 2017 clair authors // Copyright 2019 clair authors
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -15,13 +15,13 @@
package clair package clair
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/stretchr/testify/assert"
) )
type mockUpdaterDatastore struct { type mockUpdaterDatastore struct {
@ -270,27 +270,27 @@ func TestCreatVulnerabilityNotification(t *testing.T) {
} }
datastore := newmockUpdaterDatastore() datastore := newmockUpdaterDatastore()
change, err := updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{}) change, err := updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, change, 0) assert.Len(t, change, 0)
change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v1}) change, err = updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{v1})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, change, 1) assert.Len(t, change, 1)
assert.Nil(t, change[0].old) assert.Nil(t, change[0].old)
assertVulnerability(t, *change[0].new, v1) assertVulnerability(t, *change[0].new, v1)
change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v1}) change, err = updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{v1})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, change, 0) assert.Len(t, change, 0)
change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v2}) change, err = updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{v2})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, change, 1) assert.Len(t, change, 1)
assertVulnerability(t, *change[0].new, v2) assertVulnerability(t, *change[0].new, v2)
assertVulnerability(t, *change[0].old, v1) assertVulnerability(t, *change[0].old, v1)
change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v3}) change, err = updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{v3})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, change, 1) assert.Len(t, change, 1)
assertVulnerability(t, *change[0].new, v3) assertVulnerability(t, *change[0].new, v3)