database: Move db logic to dbutil

Move all transaction related logic to dbutil to simplify and later unify
the db interface.
This commit is contained in:
Sida Chen 2019-03-04 18:55:14 -05:00
parent 4fa03d1c78
commit 1b9ed99646
7 changed files with 152 additions and 146 deletions

View File

@ -15,8 +15,6 @@
package v3
import (
"fmt"
"golang.org/x/net/context"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@ -25,7 +23,6 @@ import (
pb "github.com/coreos/clair/api/v3/clairpb"
"github.com/coreos/clair/database"
"github.com/coreos/clair/ext/imagefmt"
"github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination"
)
@ -128,20 +125,13 @@ func (s *AncestryServer) GetAncestry(ctx context.Context, req *pb.GetAncestryReq
return nil, status.Errorf(codes.InvalidArgument, "ancestry name should not be empty")
}
tx, err := s.Store.Begin()
ancestry, ok, err := database.FindAncestryAndRollback(s.Store, name)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
defer tx.Rollback()
ancestry, ok, err := tx.FindAncestry(name)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
return nil, newRPCErrorWithClairError(codes.Internal, err)
}
if !ok {
return nil, status.Error(codes.NotFound, fmt.Sprintf("requested ancestry '%s' is not found", req.GetAncestryName()))
return nil, status.Errorf(codes.NotFound, "requested ancestry '%s' is not found", req.GetAncestryName())
}
pbAncestry := &pb.GetAncestryResponse_Ancestry{
@ -150,7 +140,7 @@ func (s *AncestryServer) GetAncestry(ctx context.Context, req *pb.GetAncestryReq
}
for _, layer := range ancestry.Layers {
pbLayer, err := GetPbAncestryLayer(tx, layer)
pbLayer, err := s.GetPbAncestryLayer(layer)
if err != nil {
return nil, err
}
@ -180,13 +170,8 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot
return nil, status.Error(codes.InvalidArgument, "notification page limit should not be empty or less than 1")
}
tx, err := s.Store.Begin()
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
defer tx.Rollback()
dbNotification, ok, err := tx.FindVulnerabilityNotification(
dbNotification, ok, err := database.FindVulnerabilityNotificationAndRollback(
s.Store,
req.GetName(),
int(req.GetLimit()),
pagination.Token(req.GetOldVulnerabilityPage()),
@ -194,11 +179,11 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot
)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
return nil, newRPCErrorWithClairError(codes.Internal, err)
}
if !ok {
return nil, status.Error(codes.NotFound, fmt.Sprintf("requested notification '%s' is not found", req.GetName()))
return nil, status.Errorf(codes.NotFound, "requested notification '%s' is not found", req.GetName())
}
notification, err := pb.NotificationFromDatabaseModel(dbNotification)
@ -216,21 +201,13 @@ func (s *NotificationServer) MarkNotificationAsRead(ctx context.Context, req *pb
return nil, status.Error(codes.InvalidArgument, "notification name should not be empty")
}
tx, err := s.Store.Begin()
found, err := database.MarkNotificationAsReadAndCommit(s.Store, req.GetName())
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
return nil, newRPCErrorWithClairError(codes.Internal, err)
}
defer tx.Rollback()
err = tx.DeleteNotification(req.GetName())
if err == commonerr.ErrNotFound {
return nil, status.Error(codes.NotFound, "requested notification \""+req.GetName()+"\" is not found")
} else if err != nil {
return nil, status.Error(codes.Internal, err.Error())
}
if err := tx.Commit(); err != nil {
return nil, status.Error(codes.Internal, err.Error())
if !found {
return nil, status.Errorf(codes.NotFound, "requested notification '%s' is not found", req.GetName())
}
return &pb.MarkNotificationAsReadResponse{}, nil

View File

@ -33,7 +33,7 @@ func GetClairStatus(store database.Datastore) (*pb.ClairStatus, error) {
// GetPbAncestryLayer retrieves an ancestry layer with vulnerabilities and
// features in an ancestry based on the provided database layer.
func GetPbAncestryLayer(tx database.Session, layer database.AncestryLayer) (*pb.GetAncestryResponse_AncestryLayer, error) {
func (s *AncestryServer) GetPbAncestryLayer(layer database.AncestryLayer) (*pb.GetAncestryResponse_AncestryLayer, error) {
pbLayer := &pb.GetAncestryResponse_AncestryLayer{
Layer: &pb.Layer{
Hash: layer.Hash,
@ -41,18 +41,14 @@ func GetPbAncestryLayer(tx database.Session, layer database.AncestryLayer) (*pb.
}
features := layer.GetFeatures()
affectedFeatures, err := tx.FindAffectedNamespacedFeatures(features)
affectedFeatures, err := database.FindAffectedNamespacedFeaturesAndRollback(s.Store, features)
if err != nil {
return nil, status.Error(codes.Internal, err.Error())
return nil, newRPCErrorWithClairError(codes.Internal, err)
}
// NOTE(sidac): It's quite inefficient, but the easiest way to implement
// this feature for now, we should refactor the implementation if there's
// any performance issue. It's expected that the number of features is less
// than 1000.
for _, feature := range affectedFeatures {
if !feature.Valid {
return nil, status.Error(codes.Internal, "ancestry feature is not found")
panic("feature is missing in the database, it indicates the database is corrupted.")
}
for _, detectedFeature := range layer.Features {

View File

@ -20,6 +20,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination"
"github.com/deckarep/golang-set"
)
@ -400,3 +402,125 @@ func PersistDetectorsAndCommit(store Datastore, detectors []Detector) error {
return nil
}
// MarkNotificationAsReadAndCommit marks a notification as read.
func MarkNotificationAsReadAndCommit(store Datastore, name string) (bool, error) {
tx, err := store.Begin()
if err != nil {
return false, err
}
defer tx.Rollback()
err = tx.DeleteNotification(name)
if err == commonerr.ErrNotFound {
return false, nil
} else if err != nil {
return false, err
}
if err := tx.Commit(); err != nil {
return false, err
}
return true, nil
}
// FindAffectedNamespacedFeaturesAndRollback finds the vulnerabilities on each
// feature.
func FindAffectedNamespacedFeaturesAndRollback(store Datastore, features []NamespacedFeature) ([]NullableAffectedNamespacedFeature, error) {
tx, err := store.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
nullableFeatures, err := tx.FindAffectedNamespacedFeatures(features)
if err != nil {
return nil, err
}
return nullableFeatures, nil
}
// FindVulnerabilityNotificationAndRollback finds the vulnerability notification
// and rollback.
func FindVulnerabilityNotificationAndRollback(store Datastore, name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (VulnerabilityNotificationWithVulnerable, bool, error) {
tx, err := store.Begin()
if err != nil {
return VulnerabilityNotificationWithVulnerable{}, false, err
}
defer tx.Rollback()
return tx.FindVulnerabilityNotification(name, limit, oldVulnerabilityPage, newVulnerabilityPage)
}
// FindNewNotification finds notifications either never notified or notified
// before the given time.
func FindNewNotification(store Datastore, notifiedBefore time.Time) (NotificationHook, bool, error) {
tx, err := store.Begin()
if err != nil {
return NotificationHook{}, false, err
}
defer tx.Rollback()
return tx.FindNewNotification(notifiedBefore)
}
// UpdateKeyValueAndCommit stores the key value to storage.
func UpdateKeyValueAndCommit(store Datastore, key, value string) error {
tx, err := store.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err = tx.UpdateKeyValue(key, value); err != nil {
return err
}
return tx.Commit()
}
// InsertVulnerabilityNotificationsAndCommit inserts the notifications into db
// and commit.
func InsertVulnerabilityNotificationsAndCommit(store Datastore, notifications []VulnerabilityNotification) error {
tx, err := store.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if err := tx.InsertVulnerabilityNotifications(notifications); err != nil {
return err
}
return tx.Commit()
}
// FindVulnerabilitiesAndRollback finds the vulnerabilities based on given ids.
func FindVulnerabilitiesAndRollback(store Datastore, ids []VulnerabilityID) ([]NullableVulnerability, error) {
tx, err := store.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
return tx.FindVulnerabilities(ids)
}
func UpdateVulnerabilitiesAndCommit(store Datastore, toRemove []VulnerabilityID, toAdd []VulnerabilityWithAffected) error {
tx, err := store.Begin()
if err != nil {
return err
}
if err := tx.DeleteVulnerabilities(toRemove); err != nil {
return err
}
if err := tx.InsertVulnerabilities(toAdd); err != nil {
return err
}
return tx.Commit()
}

View File

@ -127,16 +127,10 @@ func init() {
func (u *updater) Update(datastore database.Datastore) (resp vulnsrc.UpdateResponse, err error) {
log.WithField("package", u.Name).Info("Start fetching vulnerabilities")
tx, err := datastore.Begin()
if err != nil {
return resp, err
}
defer tx.Rollback()
// openSUSE and SUSE have one single xml file for all the products, there are no incremental
// xml files. We store into the database the value of the generation timestamp
// of the latest file we parsed.
flagValue, ok, err := tx.FindKeyValue(u.UpdaterFlag)
flagValue, ok, err := database.FindKeyValueAndRollback(datastore, u.UpdaterFlag)
if err != nil {
return resp, err
}

View File

@ -94,13 +94,6 @@ func (u *updater) Update(db database.Datastore) (resp vulnsrc.UpdateResponse, er
return resp, err
}
// Open a database transaction.
tx, err := db.Begin()
if err != nil {
return resp, err
}
defer tx.Rollback()
// Ask the database for the latest commit we successfully applied.
dbCommit, ok, err := database.FindKeyValueAndRollback(db, updaterFlag)
if err != nil {

View File

@ -93,7 +93,7 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop
go func() {
success, interrupted := handleTask(*notification, stopper, config.Attempts)
if success {
err := markNotificationAsRead(datastore, notification.Name)
_, err := database.MarkNotificationAsReadAndCommit(datastore, notification.Name)
if err != nil {
log.WithError(err).Error("Failed to mark notification notified")
}
@ -126,7 +126,7 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop
func findTask(datastore database.Datastore, renotifyInterval time.Duration, whoAmI string, stopper *stopper.Stopper) *database.NotificationHook {
for {
notification, ok, err := findNewNotification(datastore, renotifyInterval)
notification, ok, err := database.FindNewNotification(datastore, time.Now().Add(-renotifyInterval))
if err != nil || !ok {
if !ok {
log.WithError(err).Warning("could not get notification to send")
@ -186,25 +186,3 @@ func handleTask(n database.NotificationHook, st *stopper.Stopper, maxAttempts in
log.WithField(logNotiName, n.Name).Info("successfully sent notification")
return true, false
}
func findNewNotification(datastore database.Datastore, renotifyInterval time.Duration) (database.NotificationHook, bool, error) {
tx, err := datastore.Begin()
if err != nil {
return database.NotificationHook{}, false, err
}
defer tx.Rollback()
return tx.FindNewNotification(time.Now().Add(-renotifyInterval))
}
func markNotificationAsRead(datastore database.Datastore, name string) error {
tx, err := datastore.Begin()
if err != nil {
log.WithError(err).Error("an error happens when beginning database transaction")
}
defer tx.Rollback()
if err := tx.MarkNotificationAsRead(name); err != nil {
return err
}
return tx.Commit()
}

View File

@ -431,13 +431,7 @@ func addMetadata(ctx context.Context, datastore database.Datastore, vulnerabilit
// GetLastUpdateTime retrieves the latest successful time of update and whether
// or not it's the first update.
func GetLastUpdateTime(datastore database.Datastore) (time.Time, bool, error) {
tx, err := datastore.Begin()
if err != nil {
return time.Time{}, false, err
}
defer tx.Rollback()
lastUpdateTSS, ok, err := tx.FindKeyValue(updaterLastFlagName)
lastUpdateTSS, ok, err := database.FindKeyValueAndRollback(datastore, updaterLastFlagName)
if err != nil {
return time.Time{}, false, err
}
@ -449,7 +443,7 @@ func GetLastUpdateTime(datastore database.Datastore) (time.Time, bool, error) {
lastUpdateTS, err := strconv.ParseInt(lastUpdateTSS, 10, 64)
if err != nil {
return time.Time{}, false, err
panic(err)
}
return time.Unix(lastUpdateTS, 0).UTC(), false, nil
@ -539,40 +533,19 @@ func doVulnerabilitiesNamespacing(vulnerabilities []database.VulnerabilityWithAf
return response
}
// updateUpdaterFlags updates the flags specified by updaters, every transaction
// is independent of each other.
func updateUpdaterFlags(datastore database.Datastore, flags map[string]string) error {
for key, value := range flags {
tx, err := datastore.Begin()
if err != nil {
if err := database.UpdateKeyValueAndCommit(datastore, key, value); err != nil {
return err
}
defer tx.Rollback()
}
err = tx.UpdateKeyValue(key, value)
if err != nil {
return err
}
if err = tx.Commit(); err != nil {
return err
}
}
return nil
}
// setLastUpdateTime records the last successful date time in database.
func setLastUpdateTime(datastore database.Datastore) error {
tx, err := datastore.Begin()
if err != nil {
return err
}
defer tx.Rollback()
err = tx.UpdateKeyValue(updaterLastFlagName, strconv.FormatInt(time.Now().UTC().Unix(), 10))
if err != nil {
return err
}
return tx.Commit()
return database.UpdateKeyValueAndCommit(datastore, updaterLastFlagName, strconv.FormatInt(time.Now().UTC().Unix(), 10))
}
// isVulnerabilityChange compares two vulnerabilities by their severity and
@ -648,12 +621,6 @@ func createVulnerabilityNotifications(datastore database.Datastore, changes []vu
return nil
}
tx, err := datastore.Begin()
if err != nil {
return err
}
defer tx.Rollback()
notifications := make([]database.VulnerabilityNotification, 0, len(changes))
for _, change := range changes {
var oldVuln, newVuln *database.Vulnerability
@ -675,11 +642,7 @@ func createVulnerabilityNotifications(datastore database.Datastore, changes []vu
})
}
if err := tx.InsertVulnerabilityNotifications(notifications); err != nil {
return err
}
return tx.Commit()
return database.InsertVulnerabilityNotificationsAndCommit(datastore, notifications)
}
// updateVulnerabilities upserts unique vulnerabilities into the database and
@ -698,13 +661,7 @@ func updateVulnerabilities(ctx context.Context, datastore database.Datastore, vu
})
}
tx, err := datastore.Begin()
if err != nil {
return nil, err
}
defer tx.Rollback()
oldVulnNullable, err := tx.FindVulnerabilities(ids)
oldVulnNullable, err := database.FindVulnerabilitiesAndRollback(datastore, ids)
if err != nil {
return nil, err
}
@ -748,21 +705,8 @@ func updateVulnerabilities(ctx context.Context, datastore database.Datastore, vu
}
}
log.WithField("count", len(toRemove)).Debug("marking vulnerabilities as outdated")
if err := tx.DeleteVulnerabilities(toRemove); err != nil {
return nil, err
}
log.WithField("count", len(toAdd)).Debug("inserting new vulnerabilities")
if err := tx.InsertVulnerabilities(toAdd); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return changes, nil
log.Debugf("there are %d vulnerability changes", len(changes))
return changes, database.UpdateVulnerabilitiesAndCommit(datastore, toRemove, toAdd)
}
func updaterEnabled(updaterName string) bool {