diff --git a/api/v3/rpc.go b/api/v3/rpc.go index 27abe352..9a7cae74 100644 --- a/api/v3/rpc.go +++ b/api/v3/rpc.go @@ -15,8 +15,6 @@ package v3 import ( - "fmt" - "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -25,7 +23,6 @@ import ( pb "github.com/coreos/clair/api/v3/clairpb" "github.com/coreos/clair/database" "github.com/coreos/clair/ext/imagefmt" - "github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/pagination" ) @@ -128,20 +125,13 @@ func (s *AncestryServer) GetAncestry(ctx context.Context, req *pb.GetAncestryReq return nil, status.Errorf(codes.InvalidArgument, "ancestry name should not be empty") } - tx, err := s.Store.Begin() - if err != nil { - return nil, status.Error(codes.Internal, err.Error()) - } - - defer tx.Rollback() - - ancestry, ok, err := tx.FindAncestry(name) + ancestry, ok, err := database.FindAncestryAndRollback(s.Store, name) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, newRPCErrorWithClairError(codes.Internal, err) } if !ok { - return nil, status.Error(codes.NotFound, fmt.Sprintf("requested ancestry '%s' is not found", req.GetAncestryName())) + return nil, status.Errorf(codes.NotFound, "requested ancestry '%s' is not found", req.GetAncestryName()) } pbAncestry := &pb.GetAncestryResponse_Ancestry{ @@ -150,7 +140,7 @@ func (s *AncestryServer) GetAncestry(ctx context.Context, req *pb.GetAncestryReq } for _, layer := range ancestry.Layers { - pbLayer, err := GetPbAncestryLayer(tx, layer) + pbLayer, err := s.GetPbAncestryLayer(layer) if err != nil { return nil, err } @@ -180,13 +170,8 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot return nil, status.Error(codes.InvalidArgument, "notification page limit should not be empty or less than 1") } - tx, err := s.Store.Begin() - if err != nil { - return nil, status.Error(codes.Internal, err.Error()) - } - defer tx.Rollback() - - dbNotification, ok, err := tx.FindVulnerabilityNotification( + dbNotification, ok, err := database.FindVulnerabilityNotificationAndRollback( + s.Store, req.GetName(), int(req.GetLimit()), pagination.Token(req.GetOldVulnerabilityPage()), @@ -194,11 +179,11 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot ) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, newRPCErrorWithClairError(codes.Internal, err) } if !ok { - return nil, status.Error(codes.NotFound, fmt.Sprintf("requested notification '%s' is not found", req.GetName())) + return nil, status.Errorf(codes.NotFound, "requested notification '%s' is not found", req.GetName()) } notification, err := pb.NotificationFromDatabaseModel(dbNotification) @@ -216,21 +201,13 @@ func (s *NotificationServer) MarkNotificationAsRead(ctx context.Context, req *pb return nil, status.Error(codes.InvalidArgument, "notification name should not be empty") } - tx, err := s.Store.Begin() + found, err := database.MarkNotificationAsReadAndCommit(s.Store, req.GetName()) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) - } - - defer tx.Rollback() - err = tx.DeleteNotification(req.GetName()) - if err == commonerr.ErrNotFound { - return nil, status.Error(codes.NotFound, "requested notification \""+req.GetName()+"\" is not found") - } else if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, newRPCErrorWithClairError(codes.Internal, err) } - if err := tx.Commit(); err != nil { - return nil, status.Error(codes.Internal, err.Error()) + if !found { + return nil, status.Errorf(codes.NotFound, "requested notification '%s' is not found", req.GetName()) } return &pb.MarkNotificationAsReadResponse{}, nil diff --git a/api/v3/util.go b/api/v3/util.go index fa0ff3bc..4e7446b5 100644 --- a/api/v3/util.go +++ b/api/v3/util.go @@ -33,7 +33,7 @@ func GetClairStatus(store database.Datastore) (*pb.ClairStatus, error) { // GetPbAncestryLayer retrieves an ancestry layer with vulnerabilities and // features in an ancestry based on the provided database layer. -func GetPbAncestryLayer(tx database.Session, layer database.AncestryLayer) (*pb.GetAncestryResponse_AncestryLayer, error) { +func (s *AncestryServer) GetPbAncestryLayer(layer database.AncestryLayer) (*pb.GetAncestryResponse_AncestryLayer, error) { pbLayer := &pb.GetAncestryResponse_AncestryLayer{ Layer: &pb.Layer{ Hash: layer.Hash, @@ -41,18 +41,14 @@ func GetPbAncestryLayer(tx database.Session, layer database.AncestryLayer) (*pb. } features := layer.GetFeatures() - affectedFeatures, err := tx.FindAffectedNamespacedFeatures(features) + affectedFeatures, err := database.FindAffectedNamespacedFeaturesAndRollback(s.Store, features) if err != nil { - return nil, status.Error(codes.Internal, err.Error()) + return nil, newRPCErrorWithClairError(codes.Internal, err) } - // NOTE(sidac): It's quite inefficient, but the easiest way to implement - // this feature for now, we should refactor the implementation if there's - // any performance issue. It's expected that the number of features is less - // than 1000. for _, feature := range affectedFeatures { if !feature.Valid { - return nil, status.Error(codes.Internal, "ancestry feature is not found") + panic("feature is missing in the database, it indicates the database is corrupted.") } for _, detectedFeature := range layer.Features { diff --git a/database/dbutil.go b/database/dbutil.go index e7c85c8f..e8e1f912 100644 --- a/database/dbutil.go +++ b/database/dbutil.go @@ -20,6 +20,8 @@ import ( log "github.com/sirupsen/logrus" + "github.com/coreos/clair/pkg/commonerr" + "github.com/coreos/clair/pkg/pagination" "github.com/deckarep/golang-set" ) @@ -400,3 +402,125 @@ func PersistDetectorsAndCommit(store Datastore, detectors []Detector) error { return nil } + +// MarkNotificationAsReadAndCommit marks a notification as read. +func MarkNotificationAsReadAndCommit(store Datastore, name string) (bool, error) { + tx, err := store.Begin() + if err != nil { + return false, err + } + + defer tx.Rollback() + err = tx.DeleteNotification(name) + if err == commonerr.ErrNotFound { + return false, nil + } else if err != nil { + return false, err + } + + if err := tx.Commit(); err != nil { + return false, err + } + + return true, nil +} + +// FindAffectedNamespacedFeaturesAndRollback finds the vulnerabilities on each +// feature. +func FindAffectedNamespacedFeaturesAndRollback(store Datastore, features []NamespacedFeature) ([]NullableAffectedNamespacedFeature, error) { + tx, err := store.Begin() + if err != nil { + return nil, err + } + + defer tx.Rollback() + nullableFeatures, err := tx.FindAffectedNamespacedFeatures(features) + if err != nil { + return nil, err + } + + return nullableFeatures, nil +} + +// FindVulnerabilityNotificationAndRollback finds the vulnerability notification +// and rollback. +func FindVulnerabilityNotificationAndRollback(store Datastore, name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (VulnerabilityNotificationWithVulnerable, bool, error) { + tx, err := store.Begin() + if err != nil { + return VulnerabilityNotificationWithVulnerable{}, false, err + } + + defer tx.Rollback() + return tx.FindVulnerabilityNotification(name, limit, oldVulnerabilityPage, newVulnerabilityPage) +} + +// FindNewNotification finds notifications either never notified or notified +// before the given time. +func FindNewNotification(store Datastore, notifiedBefore time.Time) (NotificationHook, bool, error) { + tx, err := store.Begin() + if err != nil { + return NotificationHook{}, false, err + } + + defer tx.Rollback() + return tx.FindNewNotification(notifiedBefore) +} + +// UpdateKeyValueAndCommit stores the key value to storage. +func UpdateKeyValueAndCommit(store Datastore, key, value string) error { + tx, err := store.Begin() + if err != nil { + return err + } + + defer tx.Rollback() + if err = tx.UpdateKeyValue(key, value); err != nil { + return err + } + + return tx.Commit() +} + +// InsertVulnerabilityNotificationsAndCommit inserts the notifications into db +// and commit. +func InsertVulnerabilityNotificationsAndCommit(store Datastore, notifications []VulnerabilityNotification) error { + tx, err := store.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if err := tx.InsertVulnerabilityNotifications(notifications); err != nil { + return err + } + + return tx.Commit() +} + +// FindVulnerabilitiesAndRollback finds the vulnerabilities based on given ids. +func FindVulnerabilitiesAndRollback(store Datastore, ids []VulnerabilityID) ([]NullableVulnerability, error) { + tx, err := store.Begin() + if err != nil { + return nil, err + } + + defer tx.Rollback() + return tx.FindVulnerabilities(ids) +} + +func UpdateVulnerabilitiesAndCommit(store Datastore, toRemove []VulnerabilityID, toAdd []VulnerabilityWithAffected) error { + tx, err := store.Begin() + if err != nil { + return err + } + + if err := tx.DeleteVulnerabilities(toRemove); err != nil { + return err + } + + if err := tx.InsertVulnerabilities(toAdd); err != nil { + return err + } + + return tx.Commit() +} diff --git a/ext/vulnsrc/suse/suse.go b/ext/vulnsrc/suse/suse.go index 815668ea..e7203960 100644 --- a/ext/vulnsrc/suse/suse.go +++ b/ext/vulnsrc/suse/suse.go @@ -127,16 +127,10 @@ func init() { func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateResponse, err error) { log.WithField("package", u.Name).Info("Start fetching vulnerabilities") - tx, err := datastore.Begin() - if err != nil { - return resp, err - } - defer tx.Rollback() - // openSUSE and SUSE have one single xml file for all the products, there are no incremental // xml files. We store into the database the value of the generation timestamp // of the latest file we parsed. - flagValue, ok, err := tx.FindKeyValue(u.UpdaterFlag) + flagValue, ok, err := database.FindKeyValueAndRollback(datastore, u.UpdaterFlag) if err != nil { return resp, err } diff --git a/ext/vulnsrc/ubuntu/ubuntu.go b/ext/vulnsrc/ubuntu/ubuntu.go index 1fef9e34..c4cd3f1b 100644 --- a/ext/vulnsrc/ubuntu/ubuntu.go +++ b/ext/vulnsrc/ubuntu/ubuntu.go @@ -94,13 +94,6 @@ func (u *updater) Update(db database.Datastore) (resp vulnsrc.UpdateResponse, er return resp, err } - // Open a database transaction. - tx, err := db.Begin() - if err != nil { - return resp, err - } - defer tx.Rollback() - // Ask the database for the latest commit we successfully applied. dbCommit, ok, err := database.FindKeyValueAndRollback(db, updaterFlag) if err != nil { diff --git a/notifier.go b/notifier.go index 69eb6058..07c5004c 100644 --- a/notifier.go +++ b/notifier.go @@ -93,7 +93,7 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop go func() { success, interrupted := handleTask(*notification, stopper, config.Attempts) if success { - err := markNotificationAsRead(datastore, notification.Name) + _, err := database.MarkNotificationAsReadAndCommit(datastore, notification.Name) if err != nil { log.WithError(err).Error("Failed to mark notification notified") } @@ -126,7 +126,7 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop func findTask(datastore database.Datastore, renotifyInterval time.Duration, whoAmI string, stopper *stopper.Stopper) *database.NotificationHook { for { - notification, ok, err := findNewNotification(datastore, renotifyInterval) + notification, ok, err := database.FindNewNotification(datastore, time.Now().Add(-renotifyInterval)) if err != nil || !ok { if !ok { log.WithError(err).Warning("could not get notification to send") @@ -186,25 +186,3 @@ func handleTask(n database.NotificationHook, st *stopper.Stopper, maxAttempts in log.WithField(logNotiName, n.Name).Info("successfully sent notification") return true, false } - -func findNewNotification(datastore database.Datastore, renotifyInterval time.Duration) (database.NotificationHook, bool, error) { - tx, err := datastore.Begin() - if err != nil { - return database.NotificationHook{}, false, err - } - defer tx.Rollback() - return tx.FindNewNotification(time.Now().Add(-renotifyInterval)) -} - -func markNotificationAsRead(datastore database.Datastore, name string) error { - tx, err := datastore.Begin() - if err != nil { - log.WithError(err).Error("an error happens when beginning database transaction") - } - defer tx.Rollback() - - if err := tx.MarkNotificationAsRead(name); err != nil { - return err - } - return tx.Commit() -} diff --git a/updater.go b/updater.go index e0d29ad9..31bf55f1 100644 --- a/updater.go +++ b/updater.go @@ -431,13 +431,7 @@ func addMetadata(ctx context.Context, datastore database.Datastore, vulnerabilit // GetLastUpdateTime retrieves the latest successful time of update and whether // or not it's the first update. func GetLastUpdateTime(datastore database.Datastore) (time.Time, bool, error) { - tx, err := datastore.Begin() - if err != nil { - return time.Time{}, false, err - } - defer tx.Rollback() - - lastUpdateTSS, ok, err := tx.FindKeyValue(updaterLastFlagName) + lastUpdateTSS, ok, err := database.FindKeyValueAndRollback(datastore, updaterLastFlagName) if err != nil { return time.Time{}, false, err } @@ -449,7 +443,7 @@ func GetLastUpdateTime(datastore database.Datastore) (time.Time, bool, error) { lastUpdateTS, err := strconv.ParseInt(lastUpdateTSS, 10, 64) if err != nil { - return time.Time{}, false, err + panic(err) } return time.Unix(lastUpdateTS, 0).UTC(), false, nil @@ -539,40 +533,19 @@ func doVulnerabilitiesNamespacing(vulnerabilities []database.VulnerabilityWithAf return response } -// updateUpdaterFlags updates the flags specified by updaters, every transaction -// is independent of each other. func updateUpdaterFlags(datastore database.Datastore, flags map[string]string) error { for key, value := range flags { - tx, err := datastore.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - err = tx.UpdateKeyValue(key, value) - if err != nil { - return err - } - if err = tx.Commit(); err != nil { + if err := database.UpdateKeyValueAndCommit(datastore, key, value); err != nil { return err } } + return nil } // setLastUpdateTime records the last successful date time in database. func setLastUpdateTime(datastore database.Datastore) error { - tx, err := datastore.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - err = tx.UpdateKeyValue(updaterLastFlagName, strconv.FormatInt(time.Now().UTC().Unix(), 10)) - if err != nil { - return err - } - return tx.Commit() + return database.UpdateKeyValueAndCommit(datastore, updaterLastFlagName, strconv.FormatInt(time.Now().UTC().Unix(), 10)) } // isVulnerabilityChange compares two vulnerabilities by their severity and @@ -648,12 +621,6 @@ func createVulnerabilityNotifications(datastore database.Datastore, changes []vu return nil } - tx, err := datastore.Begin() - if err != nil { - return err - } - defer tx.Rollback() - notifications := make([]database.VulnerabilityNotification, 0, len(changes)) for _, change := range changes { var oldVuln, newVuln *database.Vulnerability @@ -675,11 +642,7 @@ func createVulnerabilityNotifications(datastore database.Datastore, changes []vu }) } - if err := tx.InsertVulnerabilityNotifications(notifications); err != nil { - return err - } - - return tx.Commit() + return database.InsertVulnerabilityNotificationsAndCommit(datastore, notifications) } // updateVulnerabilities upserts unique vulnerabilities into the database and @@ -698,13 +661,7 @@ func updateVulnerabilities(ctx context.Context, datastore database.Datastore, vu }) } - tx, err := datastore.Begin() - if err != nil { - return nil, err - } - defer tx.Rollback() - - oldVulnNullable, err := tx.FindVulnerabilities(ids) + oldVulnNullable, err := database.FindVulnerabilitiesAndRollback(datastore, ids) if err != nil { return nil, err } @@ -748,21 +705,8 @@ func updateVulnerabilities(ctx context.Context, datastore database.Datastore, vu } } - log.WithField("count", len(toRemove)).Debug("marking vulnerabilities as outdated") - if err := tx.DeleteVulnerabilities(toRemove); err != nil { - return nil, err - } - - log.WithField("count", len(toAdd)).Debug("inserting new vulnerabilities") - if err := tx.InsertVulnerabilities(toAdd); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return changes, nil + log.Debugf("there are %d vulnerability changes", len(changes)) + return changes, database.UpdateVulnerabilitiesAndCommit(datastore, toRemove, toAdd) } func updaterEnabled(updaterName string) bool {