database: Move db logic to dbutil

Move all transaction related logic to dbutil to simplify and later unify
the db interface.
This commit is contained in:
Sida Chen 2019-03-04 18:55:14 -05:00
parent 4fa03d1c78
commit 1b9ed99646
7 changed files with 152 additions and 146 deletions

View File

@ -15,8 +15,6 @@
package v3 package v3
import ( import (
"fmt"
"golang.org/x/net/context" "golang.org/x/net/context"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
@ -25,7 +23,6 @@ import (
pb "github.com/coreos/clair/api/v3/clairpb" pb "github.com/coreos/clair/api/v3/clairpb"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/ext/imagefmt" "github.com/coreos/clair/ext/imagefmt"
"github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination" "github.com/coreos/clair/pkg/pagination"
) )
@ -128,20 +125,13 @@ func (s *AncestryServer) GetAncestry(ctx context.Context, req *pb.GetAncestryReq
return nil, status.Errorf(codes.InvalidArgument, "ancestry name should not be empty") return nil, status.Errorf(codes.InvalidArgument, "ancestry name should not be empty")
} }
tx, err := s.Store.Begin() ancestry, ok, err := database.FindAncestryAndRollback(s.Store, name)
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, err.Error()) return nil, newRPCErrorWithClairError(codes.Internal, err)
}
defer tx.Rollback()
ancestry, ok, err := tx.FindAncestry(name)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
} }
if !ok { if !ok {
return nil, status.Error(codes.NotFound, fmt.Sprintf("requested ancestry '%s' is not found", req.GetAncestryName())) return nil, status.Errorf(codes.NotFound, "requested ancestry '%s' is not found", req.GetAncestryName())
} }
pbAncestry := &pb.GetAncestryResponse_Ancestry{ pbAncestry := &pb.GetAncestryResponse_Ancestry{
@ -150,7 +140,7 @@ func (s *AncestryServer) GetAncestry(ctx context.Context, req *pb.GetAncestryReq
} }
for _, layer := range ancestry.Layers { for _, layer := range ancestry.Layers {
pbLayer, err := GetPbAncestryLayer(tx, layer) pbLayer, err := s.GetPbAncestryLayer(layer)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -180,13 +170,8 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot
return nil, status.Error(codes.InvalidArgument, "notification page limit should not be empty or less than 1") return nil, status.Error(codes.InvalidArgument, "notification page limit should not be empty or less than 1")
} }
tx, err := s.Store.Begin() dbNotification, ok, err := database.FindVulnerabilityNotificationAndRollback(
if err != nil { s.Store,
return nil, status.Error(codes.Internal, err.Error())
}
defer tx.Rollback()
dbNotification, ok, err := tx.FindVulnerabilityNotification(
req.GetName(), req.GetName(),
int(req.GetLimit()), int(req.GetLimit()),
pagination.Token(req.GetOldVulnerabilityPage()), pagination.Token(req.GetOldVulnerabilityPage()),
@ -194,11 +179,11 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot
) )
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, err.Error()) return nil, newRPCErrorWithClairError(codes.Internal, err)
} }
if !ok { if !ok {
return nil, status.Error(codes.NotFound, fmt.Sprintf("requested notification '%s' is not found", req.GetName())) return nil, status.Errorf(codes.NotFound, "requested notification '%s' is not found", req.GetName())
} }
notification, err := pb.NotificationFromDatabaseModel(dbNotification) notification, err := pb.NotificationFromDatabaseModel(dbNotification)
@ -216,21 +201,13 @@ func (s *NotificationServer) MarkNotificationAsRead(ctx context.Context, req *pb
return nil, status.Error(codes.InvalidArgument, "notification name should not be empty") return nil, status.Error(codes.InvalidArgument, "notification name should not be empty")
} }
tx, err := s.Store.Begin() found, err := database.MarkNotificationAsReadAndCommit(s.Store, req.GetName())
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, err.Error()) return nil, newRPCErrorWithClairError(codes.Internal, err)
} }
defer tx.Rollback() if !found {
err = tx.DeleteNotification(req.GetName()) return nil, status.Errorf(codes.NotFound, "requested notification '%s' is not found", req.GetName())
if err == commonerr.ErrNotFound {
return nil, status.Error(codes.NotFound, "requested notification \""+req.GetName()+"\" is not found")
} else if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
if err := tx.Commit(); err != nil {
return nil, status.Error(codes.Internal, err.Error())
} }
return &pb.MarkNotificationAsReadResponse{}, nil return &pb.MarkNotificationAsReadResponse{}, nil

View File

@ -33,7 +33,7 @@ func GetClairStatus(store database.Datastore) (*pb.ClairStatus, error) {
// GetPbAncestryLayer retrieves an ancestry layer with vulnerabilities and // GetPbAncestryLayer retrieves an ancestry layer with vulnerabilities and
// features in an ancestry based on the provided database layer. // features in an ancestry based on the provided database layer.
func GetPbAncestryLayer(tx database.Session, layer database.AncestryLayer) (*pb.GetAncestryResponse_AncestryLayer, error) { func (s *AncestryServer) GetPbAncestryLayer(layer database.AncestryLayer) (*pb.GetAncestryResponse_AncestryLayer, error) {
pbLayer := &pb.GetAncestryResponse_AncestryLayer{ pbLayer := &pb.GetAncestryResponse_AncestryLayer{
Layer: &pb.Layer{ Layer: &pb.Layer{
Hash: layer.Hash, Hash: layer.Hash,
@ -41,18 +41,14 @@ func GetPbAncestryLayer(tx database.Session, layer database.AncestryLayer) (*pb.
} }
features := layer.GetFeatures() features := layer.GetFeatures()
affectedFeatures, err := tx.FindAffectedNamespacedFeatures(features) affectedFeatures, err := database.FindAffectedNamespacedFeaturesAndRollback(s.Store, features)
if err != nil { if err != nil {
return nil, status.Error(codes.Internal, err.Error()) return nil, newRPCErrorWithClairError(codes.Internal, err)
} }
// NOTE(sidac): It's quite inefficient, but the easiest way to implement
// this feature for now, we should refactor the implementation if there's
// any performance issue. It's expected that the number of features is less
// than 1000.
for _, feature := range affectedFeatures { for _, feature := range affectedFeatures {
if !feature.Valid { if !feature.Valid {
return nil, status.Error(codes.Internal, "ancestry feature is not found") panic("feature is missing in the database, it indicates the database is corrupted.")
} }
for _, detectedFeature := range layer.Features { for _, detectedFeature := range layer.Features {

View File

@ -20,6 +20,8 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination"
"github.com/deckarep/golang-set" "github.com/deckarep/golang-set"
) )
@ -400,3 +402,125 @@ func PersistDetectorsAndCommit(store Datastore, detectors []Detector) error {
return nil return nil
} }
// MarkNotificationAsReadAndCommit marks a notification as read.
func MarkNotificationAsReadAndCommit(store Datastore, name string) (bool, error) {
tx, err := store.Begin()
if err != nil {
return false, err
}
defer tx.Rollback()
err = tx.DeleteNotification(name)
if err == commonerr.ErrNotFound {
return false, nil
} else if err != nil {
return false, err
}
if err := tx.Commit(); err != nil {
return false, err
}
return true, nil
}
// FindAffectedNamespacedFeaturesAndRollback finds the vulnerabilities on each
// feature.
func FindAffectedNamespacedFeaturesAndRollback(store Datastore, features []NamespacedFeature) ([]NullableAffectedNamespacedFeature, error) {
tx, err := store.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
nullableFeatures, err := tx.FindAffectedNamespacedFeatures(features)
if err != nil {
return nil, err
}
return nullableFeatures, nil
}
// FindVulnerabilityNotificationAndRollback finds the vulnerability notification
// and rollback.
func FindVulnerabilityNotificationAndRollback(store Datastore, name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (VulnerabilityNotificationWithVulnerable, bool, error) {
tx, err := store.Begin()
if err != nil {
return VulnerabilityNotificationWithVulnerable{}, false, err
}
defer tx.Rollback()
return tx.FindVulnerabilityNotification(name, limit, oldVulnerabilityPage, newVulnerabilityPage)
}
// FindNewNotification finds notifications either never notified or notified
// before the given time.
func FindNewNotification(store Datastore, notifiedBefore time.Time) (NotificationHook, bool, error) {
tx, err := store.Begin()
if err != nil {
return NotificationHook{}, false, err
}
defer tx.Rollback()
return tx.FindNewNotification(notifiedBefore)
}
// UpdateKeyValueAndCommit stores the key value to storage.
func UpdateKeyValueAndCommit(store Datastore, key, value string) error {
tx, err := store.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err = tx.UpdateKeyValue(key, value); err != nil {
return err
}
return tx.Commit()
}
// InsertVulnerabilityNotificationsAndCommit inserts the notifications into db
// and commit.
func InsertVulnerabilityNotificationsAndCommit(store Datastore, notifications []VulnerabilityNotification) error {
tx, err := store.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := tx.InsertVulnerabilityNotifications(notifications); err != nil {
return err
}
return tx.Commit()
}
// FindVulnerabilitiesAndRollback finds the vulnerabilities based on given ids.
func FindVulnerabilitiesAndRollback(store Datastore, ids []VulnerabilityID) ([]NullableVulnerability, error) {
tx, err := store.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
return tx.FindVulnerabilities(ids)
}
func UpdateVulnerabilitiesAndCommit(store Datastore, toRemove []VulnerabilityID, toAdd []VulnerabilityWithAffected) error {
tx, err := store.Begin()
if err != nil {
return err
}
if err := tx.DeleteVulnerabilities(toRemove); err != nil {
return err
}
if err := tx.InsertVulnerabilities(toAdd); err != nil {
return err
}
return tx.Commit()
}

View File

@ -127,16 +127,10 @@ func init() {
func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateResponse, err error) { func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateResponse, err error) {
log.WithField("package", u.Name).Info("Start fetching vulnerabilities") log.WithField("package", u.Name).Info("Start fetching vulnerabilities")
tx, err := datastore.Begin()
if err != nil {
return resp, err
}
defer tx.Rollback()
// openSUSE and SUSE have one single xml file for all the products, there are no incremental // openSUSE and SUSE have one single xml file for all the products, there are no incremental
// xml files. We store into the database the value of the generation timestamp // xml files. We store into the database the value of the generation timestamp
// of the latest file we parsed. // of the latest file we parsed.
flagValue, ok, err := tx.FindKeyValue(u.UpdaterFlag) flagValue, ok, err := database.FindKeyValueAndRollback(datastore, u.UpdaterFlag)
if err != nil { if err != nil {
return resp, err return resp, err
} }

View File

@ -94,13 +94,6 @@ func (u *updater) Update(db database.Datastore) (resp vulnsrc.UpdateResponse, er
return resp, err return resp, err
} }
// Open a database transaction.
tx, err := db.Begin()
if err != nil {
return resp, err
}
defer tx.Rollback()
// Ask the database for the latest commit we successfully applied. // Ask the database for the latest commit we successfully applied.
dbCommit, ok, err := database.FindKeyValueAndRollback(db, updaterFlag) dbCommit, ok, err := database.FindKeyValueAndRollback(db, updaterFlag)
if err != nil { if err != nil {

View File

@ -93,7 +93,7 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop
go func() { go func() {
success, interrupted := handleTask(*notification, stopper, config.Attempts) success, interrupted := handleTask(*notification, stopper, config.Attempts)
if success { if success {
err := markNotificationAsRead(datastore, notification.Name) _, err := database.MarkNotificationAsReadAndCommit(datastore, notification.Name)
if err != nil { if err != nil {
log.WithError(err).Error("Failed to mark notification notified") log.WithError(err).Error("Failed to mark notification notified")
} }
@ -126,7 +126,7 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop
func findTask(datastore database.Datastore, renotifyInterval time.Duration, whoAmI string, stopper *stopper.Stopper) *database.NotificationHook { func findTask(datastore database.Datastore, renotifyInterval time.Duration, whoAmI string, stopper *stopper.Stopper) *database.NotificationHook {
for { for {
notification, ok, err := findNewNotification(datastore, renotifyInterval) notification, ok, err := database.FindNewNotification(datastore, time.Now().Add(-renotifyInterval))
if err != nil || !ok { if err != nil || !ok {
if !ok { if !ok {
log.WithError(err).Warning("could not get notification to send") log.WithError(err).Warning("could not get notification to send")
@ -186,25 +186,3 @@ func handleTask(n database.NotificationHook, st *stopper.Stopper, maxAttempts in
log.WithField(logNotiName, n.Name).Info("successfully sent notification") log.WithField(logNotiName, n.Name).Info("successfully sent notification")
return true, false return true, false
} }
func findNewNotification(datastore database.Datastore, renotifyInterval time.Duration) (database.NotificationHook, bool, error) {
tx, err := datastore.Begin()
if err != nil {
return database.NotificationHook{}, false, err
}
defer tx.Rollback()
return tx.FindNewNotification(time.Now().Add(-renotifyInterval))
}
func markNotificationAsRead(datastore database.Datastore, name string) error {
tx, err := datastore.Begin()
if err != nil {
log.WithError(err).Error("an error happens when beginning database transaction")
}
defer tx.Rollback()
if err := tx.MarkNotificationAsRead(name); err != nil {
return err
}
return tx.Commit()
}

View File

@ -431,13 +431,7 @@ func addMetadata(ctx context.Context, datastore database.Datastore, vulnerabilit
// GetLastUpdateTime retrieves the latest successful time of update and whether // GetLastUpdateTime retrieves the latest successful time of update and whether
// or not it's the first update. // or not it's the first update.
func GetLastUpdateTime(datastore database.Datastore) (time.Time, bool, error) { func GetLastUpdateTime(datastore database.Datastore) (time.Time, bool, error) {
tx, err := datastore.Begin() lastUpdateTSS, ok, err := database.FindKeyValueAndRollback(datastore, updaterLastFlagName)
if err != nil {
return time.Time{}, false, err
}
defer tx.Rollback()
lastUpdateTSS, ok, err := tx.FindKeyValue(updaterLastFlagName)
if err != nil { if err != nil {
return time.Time{}, false, err return time.Time{}, false, err
} }
@ -449,7 +443,7 @@ func GetLastUpdateTime(datastore database.Datastore) (time.Time, bool, error) {
lastUpdateTS, err := strconv.ParseInt(lastUpdateTSS, 10, 64) lastUpdateTS, err := strconv.ParseInt(lastUpdateTSS, 10, 64)
if err != nil { if err != nil {
return time.Time{}, false, err panic(err)
} }
return time.Unix(lastUpdateTS, 0).UTC(), false, nil return time.Unix(lastUpdateTS, 0).UTC(), false, nil
@ -539,40 +533,19 @@ func doVulnerabilitiesNamespacing(vulnerabilities []database.VulnerabilityWithAf
return response return response
} }
// updateUpdaterFlags updates the flags specified by updaters, every transaction
// is independent of each other.
func updateUpdaterFlags(datastore database.Datastore, flags map[string]string) error { func updateUpdaterFlags(datastore database.Datastore, flags map[string]string) error {
for key, value := range flags { for key, value := range flags {
tx, err := datastore.Begin() if err := database.UpdateKeyValueAndCommit(datastore, key, value); err != nil {
if err != nil {
return err
}
defer tx.Rollback()
err = tx.UpdateKeyValue(key, value)
if err != nil {
return err
}
if err = tx.Commit(); err != nil {
return err return err
} }
} }
return nil return nil
} }
// setLastUpdateTime records the last successful date time in database. // setLastUpdateTime records the last successful date time in database.
func setLastUpdateTime(datastore database.Datastore) error { func setLastUpdateTime(datastore database.Datastore) error {
tx, err := datastore.Begin() return database.UpdateKeyValueAndCommit(datastore, updaterLastFlagName, strconv.FormatInt(time.Now().UTC().Unix(), 10))
if err != nil {
return err
}
defer tx.Rollback()
err = tx.UpdateKeyValue(updaterLastFlagName, strconv.FormatInt(time.Now().UTC().Unix(), 10))
if err != nil {
return err
}
return tx.Commit()
} }
// isVulnerabilityChange compares two vulnerabilities by their severity and // isVulnerabilityChange compares two vulnerabilities by their severity and
@ -648,12 +621,6 @@ func createVulnerabilityNotifications(datastore database.Datastore, changes []vu
return nil return nil
} }
tx, err := datastore.Begin()
if err != nil {
return err
}
defer tx.Rollback()
notifications := make([]database.VulnerabilityNotification, 0, len(changes)) notifications := make([]database.VulnerabilityNotification, 0, len(changes))
for _, change := range changes { for _, change := range changes {
var oldVuln, newVuln *database.Vulnerability var oldVuln, newVuln *database.Vulnerability
@ -675,11 +642,7 @@ func createVulnerabilityNotifications(datastore database.Datastore, changes []vu
}) })
} }
if err := tx.InsertVulnerabilityNotifications(notifications); err != nil { return database.InsertVulnerabilityNotificationsAndCommit(datastore, notifications)
return err
}
return tx.Commit()
} }
// updateVulnerabilities upserts unique vulnerabilities into the database and // updateVulnerabilities upserts unique vulnerabilities into the database and
@ -698,13 +661,7 @@ func updateVulnerabilities(ctx context.Context, datastore database.Datastore, vu
}) })
} }
tx, err := datastore.Begin() oldVulnNullable, err := database.FindVulnerabilitiesAndRollback(datastore, ids)
if err != nil {
return nil, err
}
defer tx.Rollback()
oldVulnNullable, err := tx.FindVulnerabilities(ids)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -748,21 +705,8 @@ func updateVulnerabilities(ctx context.Context, datastore database.Datastore, vu
} }
} }
log.WithField("count", len(toRemove)).Debug("marking vulnerabilities as outdated") log.Debugf("there are %d vulnerability changes", len(changes))
if err := tx.DeleteVulnerabilities(toRemove); err != nil { return changes, database.UpdateVulnerabilitiesAndCommit(datastore, toRemove, toAdd)
return nil, err
}
log.WithField("count", len(toAdd)).Debug("inserting new vulnerabilities")
if err := tx.InsertVulnerabilities(toAdd); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return changes, nil
} }
func updaterEnabled(updaterName string) bool { func updaterEnabled(updaterName string) bool {