From dfa07f6d860c59ba2b2cc4909d38f650e9d3969b Mon Sep 17 00:00:00 2001 From: Sida Chen Date: Wed, 6 Mar 2019 16:35:35 -0500 Subject: [PATCH] pgsql: Move notification to its module --- .../{ => notification}/notification_test.go | 111 ++++++++------- .../vulnerability_notification.go} | 134 +++++------------- 2 files changed, 94 insertions(+), 151 deletions(-) rename database/pgsql/{ => notification}/notification_test.go (61%) rename database/pgsql/{notification.go => notification/vulnerability_notification.go} (60%) diff --git a/database/pgsql/notification_test.go b/database/pgsql/notification/notification_test.go similarity index 61% rename from database/pgsql/notification_test.go rename to database/pgsql/notification/notification_test.go index da3b3248..bc9d2acc 100644 --- a/database/pgsql/notification_test.go +++ b/database/pgsql/notification/notification_test.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package notification import ( "testing" @@ -22,6 +22,8 @@ import ( "github.com/stretchr/testify/require" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/page" + "github.com/coreos/clair/database/pgsql/testutil" "github.com/coreos/clair/pkg/pagination" ) @@ -38,6 +40,8 @@ type findVulnerabilityNotificationOut struct { err string } +var testPaginationKey = pagination.Must(pagination.NewKey()) + var findVulnerabilityNotificationTests = []struct { title string in findVulnerabilityNotificationIn @@ -77,21 +81,21 @@ var findVulnerabilityNotificationTests = []struct { }, out: findVulnerabilityNotificationOut{ &database.VulnerabilityNotificationWithVulnerable{ - NotificationHook: realNotification[1].NotificationHook, + NotificationHook: testutil.RealNotification[1].NotificationHook, Old: &database.PagedVulnerableAncestries{ - Vulnerability: realVulnerability[2], + Vulnerability: testutil.RealVulnerability[2], Limit: 1, Affected: make(map[int]string), - Current: mustMarshalToken(testPaginationKey, Page{0}), - Next: mustMarshalToken(testPaginationKey, Page{0}), + Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}), + Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}), End: true, }, New: &database.PagedVulnerableAncestries{ - Vulnerability: realVulnerability[1], + Vulnerability: testutil.RealVulnerability[1], Limit: 1, Affected: map[int]string{3: "ancestry-3"}, - Current: mustMarshalToken(testPaginationKey, Page{0}), - Next: mustMarshalToken(testPaginationKey, Page{4}), + Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}), + Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}), End: false, }, }, @@ -100,32 +104,31 @@ var findVulnerabilityNotificationTests = []struct { "", }, }, - { title: "find existing notification of second page of new affected ancestry", in: findVulnerabilityNotificationIn{ notificationName: "test", pageSize: 1, oldAffectedAncestryPage: pagination.FirstPageToken, - newAffectedAncestryPage: mustMarshalToken(testPaginationKey, Page{4}), + newAffectedAncestryPage: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}), }, out: findVulnerabilityNotificationOut{ &database.VulnerabilityNotificationWithVulnerable{ - NotificationHook: realNotification[1].NotificationHook, + NotificationHook: testutil.RealNotification[1].NotificationHook, Old: &database.PagedVulnerableAncestries{ - Vulnerability: realVulnerability[2], + Vulnerability: testutil.RealVulnerability[2], Limit: 1, Affected: make(map[int]string), - Current: mustMarshalToken(testPaginationKey, Page{0}), - Next: mustMarshalToken(testPaginationKey, Page{0}), + Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}), + Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}), End: true, }, New: &database.PagedVulnerableAncestries{ - Vulnerability: realVulnerability[1], + Vulnerability: testutil.RealVulnerability[1], Limit: 1, Affected: map[int]string{4: "ancestry-4"}, - Current: mustMarshalToken(testPaginationKey, Page{4}), - Next: mustMarshalToken(testPaginationKey, Page{0}), + Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}), + Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}), End: true, }, }, @@ -137,12 +140,12 @@ var findVulnerabilityNotificationTests = []struct { } func TestFindVulnerabilityNotification(t *testing.T) { - datastore, tx := openSessionForTest(t, "pagination", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "pagination") + defer cleanup() for _, test := range findVulnerabilityNotificationTests { t.Run(test.title, func(t *testing.T) { - notification, ok, err := tx.FindVulnerabilityNotification(test.in.notificationName, test.in.pageSize, test.in.oldAffectedAncestryPage, test.in.newAffectedAncestryPage) + notification, ok, err := FindVulnerabilityNotification(tx, test.in.notificationName, test.in.pageSize, test.in.oldAffectedAncestryPage, test.in.newAffectedAncestryPage, testutil.TestPaginationKey) if test.out.err != "" { require.EqualError(t, err, test.out.err) return @@ -155,13 +158,14 @@ func TestFindVulnerabilityNotification(t *testing.T) { } require.True(t, ok) - assertVulnerabilityNotificationWithVulnerableEqual(t, testPaginationKey, test.out.notification, ¬ification) + testutil.AssertVulnerabilityNotificationWithVulnerableEqual(t, testutil.TestPaginationKey, test.out.notification, ¬ification) }) } } func TestInsertVulnerabilityNotifications(t *testing.T) { - datastore, tx := openSessionForTest(t, "InsertVulnerabilityNotifications", true) + datastore, cleanup := testutil.CreateTestDBWithFixture(t, "InsertVulnerabilityNotifications") + defer cleanup() n1 := database.VulnerabilityNotification{} n3 := database.VulnerabilityNotification{ @@ -187,34 +191,37 @@ func TestInsertVulnerabilityNotifications(t *testing.T) { }, } + tx, err := datastore.Begin() + require.Nil(t, err) + // invalid case - err := tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n1}) - assert.NotNil(t, err) + err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n1}) + require.NotNil(t, err) // invalid case: unknown vulnerability - err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n3}) - assert.NotNil(t, err) + err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n3}) + require.NotNil(t, err) // invalid case: duplicated input notification - err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4, n4}) - assert.NotNil(t, err) - tx = restartSession(t, datastore, tx, false) + err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4, n4}) + require.NotNil(t, err) + tx = testutil.RestartTransaction(datastore, tx, false) // valid case - err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4}) - assert.Nil(t, err) + err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4}) + require.Nil(t, err) // invalid case: notification is already in database - err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4}) - assert.NotNil(t, err) + err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4}) + require.NotNil(t, err) - closeTest(t, datastore, tx) + require.Nil(t, tx.Rollback()) } func TestFindNewNotification(t *testing.T) { - tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNewNotification") + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestFindNewNotification") defer cleanup() - noti, ok, err := tx.FindNewNotification(time.Now()) + noti, ok, err := FindNewNotification(tx, time.Now()) if assert.Nil(t, err) && assert.True(t, ok) { assert.Equal(t, "test", noti.Name) assert.Equal(t, time.Time{}, noti.Notified) @@ -223,13 +230,13 @@ func TestFindNewNotification(t *testing.T) { } // can't find the notified - assert.Nil(t, tx.MarkNotificationAsRead("test")) + assert.Nil(t, MarkNotificationAsRead(tx, "test")) // if the notified time is before - noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(10 * time.Second))) + noti, ok, err = FindNewNotification(tx, time.Now().Add(-time.Duration(10*time.Second))) assert.Nil(t, err) assert.False(t, ok) // can find the notified after a period of time - noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(10 * time.Second))) + noti, ok, err = FindNewNotification(tx, time.Now().Add(time.Duration(10*time.Second))) if assert.Nil(t, err) && assert.True(t, ok) { assert.Equal(t, "test", noti.Name) assert.NotEqual(t, time.Time{}, noti.Notified) @@ -237,37 +244,37 @@ func TestFindNewNotification(t *testing.T) { assert.Equal(t, time.Time{}, noti.Deleted) } - assert.Nil(t, tx.DeleteNotification("test")) + assert.Nil(t, DeleteNotification(tx, "test")) // can't find in any time - noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(1000))) + noti, ok, err = FindNewNotification(tx, time.Now().Add(-time.Duration(1000))) assert.Nil(t, err) assert.False(t, ok) - noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000))) + noti, ok, err = FindNewNotification(tx, time.Now().Add(time.Duration(1000))) assert.Nil(t, err) assert.False(t, ok) } func TestMarkNotificationAsRead(t *testing.T) { - datastore, tx := openSessionForTest(t, "MarkNotificationAsRead", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "MarkNotificationAsRead") + defer cleanup() // invalid case: notification doesn't exist - assert.NotNil(t, tx.MarkNotificationAsRead("non-existing")) + assert.NotNil(t, MarkNotificationAsRead(tx, "non-existing")) // valid case - assert.Nil(t, tx.MarkNotificationAsRead("test")) + assert.Nil(t, MarkNotificationAsRead(tx, "test")) // valid case - assert.Nil(t, tx.MarkNotificationAsRead("test")) + assert.Nil(t, MarkNotificationAsRead(tx, "test")) } func TestDeleteNotification(t *testing.T) { - datastore, tx := openSessionForTest(t, "DeleteNotification", true) - defer closeTest(t, datastore, tx) + tx, cleanup := testutil.CreateTestTxWithFixtures(t, "DeleteNotification") + defer cleanup() // invalid case: notification doesn't exist - assert.NotNil(t, tx.DeleteNotification("non-existing")) + assert.NotNil(t, DeleteNotification(tx, "non-existing")) // valid case - assert.Nil(t, tx.DeleteNotification("test")) + assert.Nil(t, DeleteNotification(tx, "test")) // invalid case: notification is already deleted - assert.NotNil(t, tx.DeleteNotification("test")) + assert.NotNil(t, DeleteNotification(tx, "test")) } diff --git a/database/pgsql/notification.go b/database/pgsql/notification/vulnerability_notification.go similarity index 60% rename from database/pgsql/notification.go rename to database/pgsql/notification/vulnerability_notification.go index 7d2b750d..e0ac3a1c 100644 --- a/database/pgsql/notification.go +++ b/database/pgsql/notification/vulnerability_notification.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package pgsql +package notification import ( "database/sql" @@ -22,6 +22,8 @@ import ( "github.com/guregu/null/zero" "github.com/coreos/clair/database" + "github.com/coreos/clair/database/pgsql/util" + "github.com/coreos/clair/database/pgsql/vulnerability" "github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/pagination" ) @@ -54,26 +56,24 @@ const ( SELECT created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id FROM Vulnerability_Notification WHERE name = $1` - - searchNotificationVulnerableAncestry = ` - SELECT DISTINCT ON (a.id) - a.id, a.name - FROM vulnerability_affected_namespaced_feature AS vanf, - ancestry_layer AS al, ancestry_feature AS af, ancestry AS a - WHERE vanf.vulnerability_id = $1 - AND a.id >= $2 - AND al.ancestry_id = a.id - AND al.id = af.ancestry_layer_id - AND af.namespaced_feature_id = vanf.namespaced_feature_id - ORDER BY a.id ASC - LIMIT $3;` ) +func queryInsertNotifications(count int) string { + return util.QueryInsert(count, + "vulnerability_notification", + "name", + "created_at", + "old_vulnerability_id", + "new_vulnerability_id", + ) +} + var ( - errNotificationNotFound = errors.New("requested notification is not found") + errNotificationNotFound = errors.New("requested notification is not found") + errVulnerabilityNotFound = errors.New("vulnerability is not in database") ) -func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.VulnerabilityNotification) error { +func InsertVulnerabilityNotifications(tx *sql.Tx, notifications []database.VulnerabilityNotification) error { if len(notifications) == 0 { return nil } @@ -122,26 +122,26 @@ func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.V oldVulnIDs = append(oldVulnIDs, vulnID) } - ids, err := tx.findNotDeletedVulnerabilityIDs(newVulnIDs) + ids, err := vulnerability.FindNotDeletedVulnerabilityIDs(tx, newVulnIDs) if err != nil { return err } for i, id := range ids { if !id.Valid { - return handleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound) + return util.HandleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound) } newVulnIDMap[newVulnIDs[i]] = id } - ids, err = tx.findLatestDeletedVulnerabilityIDs(oldVulnIDs) + ids, err = vulnerability.FindLatestDeletedVulnerabilityIDs(tx, oldVulnIDs) if err != nil { return err } for i, id := range ids { if !id.Valid { - return handleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound) + return util.HandleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound) } oldVulnIDMap[oldVulnIDs[i]] = id } @@ -178,13 +178,13 @@ func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.V // multiple updaters, deadlock may happen. _, err = tx.Exec(queryInsertNotifications(len(notifications)), keys...) if err != nil { - return handleError("queryInsertNotifications", err) + return util.HandleError("queryInsertNotifications", err) } return nil } -func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.NotificationHook, bool, error) { +func FindNewNotification(tx *sql.Tx, notifiedBefore time.Time) (database.NotificationHook, bool, error) { var ( notification database.NotificationHook created zero.Time @@ -197,7 +197,7 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not if err == sql.ErrNoRows { return notification, false, nil } - return notification, false, handleError("searchNotificationAvailable", err) + return notification, false, util.HandleError("searchNotificationAvailable", err) } notification.Created = created.Time @@ -207,71 +207,7 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not return notification, true, nil } -func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentToken pagination.Token) (database.PagedVulnerableAncestries, error) { - vulnPage := database.PagedVulnerableAncestries{Limit: limit} - currentPage := Page{0} - if currentToken != pagination.FirstPageToken { - if err := tx.key.UnmarshalToken(currentToken, ¤tPage); err != nil { - return vulnPage, err - } - } - - if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan( - &vulnPage.Name, - &vulnPage.Description, - &vulnPage.Link, - &vulnPage.Severity, - &vulnPage.Metadata, - &vulnPage.Namespace.Name, - &vulnPage.Namespace.VersionFormat, - ); err != nil { - return vulnPage, handleError("searchVulnerabilityByID", err) - } - - // the last result is used for the next page's startID - rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1) - if err != nil { - return vulnPage, handleError("searchNotificationVulnerableAncestry", err) - } - defer rows.Close() - - ancestries := []affectedAncestry{} - for rows.Next() { - var ancestry affectedAncestry - err := rows.Scan(&ancestry.id, &ancestry.name) - if err != nil { - return vulnPage, handleError("searchNotificationVulnerableAncestry", err) - } - ancestries = append(ancestries, ancestry) - } - - lastIndex := 0 - if len(ancestries)-1 < limit { - lastIndex = len(ancestries) - vulnPage.End = true - } else { - // Use the last ancestry's ID as the next page. - lastIndex = len(ancestries) - 1 - vulnPage.Next, err = tx.key.MarshalToken(Page{ancestries[len(ancestries)-1].id}) - if err != nil { - return vulnPage, err - } - } - - vulnPage.Affected = map[int]string{} - for _, ancestry := range ancestries[0:lastIndex] { - vulnPage.Affected[int(ancestry.id)] = ancestry.name - } - - vulnPage.Current, err = tx.key.MarshalToken(currentPage) - if err != nil { - return vulnPage, err - } - - return vulnPage, nil -} - -func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token) ( +func FindVulnerabilityNotification(tx *sql.Tx, name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token, key pagination.Key) ( database.VulnerabilityNotificationWithVulnerable, bool, error) { var ( noti database.VulnerabilityNotificationWithVulnerable @@ -294,7 +230,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa if err == sql.ErrNoRows { return noti, false, nil } - return noti, false, handleError("searchNotification", err) + return noti, false, util.HandleError("searchNotification", err) } if created.Valid { @@ -310,7 +246,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa } if oldVulnID.Valid { - page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPageToken) + page, err := vulnerability.FindPagedVulnerableAncestries(tx, oldVulnID.Int64, limit, oldPageToken, key) if err != nil { return noti, false, err } @@ -318,7 +254,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa } if newVulnID.Valid { - page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPageToken) + page, err := vulnerability.FindPagedVulnerableAncestries(tx, newVulnID.Int64, limit, newPageToken, key) if err != nil { return noti, false, err } @@ -328,44 +264,44 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa return noti, true, nil } -func (tx *pgSession) MarkNotificationAsRead(name string) error { +func MarkNotificationAsRead(tx *sql.Tx, name string) error { if name == "" { return commonerr.NewBadRequestError("Empty notification name is not allowed") } r, err := tx.Exec(updatedNotificationAsRead, name) if err != nil { - return handleError("updatedNotificationAsRead", err) + return util.HandleError("updatedNotificationAsRead", err) } affected, err := r.RowsAffected() if err != nil { - return handleError("updatedNotificationAsRead", err) + return util.HandleError("updatedNotificationAsRead", err) } if affected <= 0 { - return handleError("updatedNotificationAsRead", errNotificationNotFound) + return util.HandleError("updatedNotificationAsRead", errNotificationNotFound) } return nil } -func (tx *pgSession) DeleteNotification(name string) error { +func DeleteNotification(tx *sql.Tx, name string) error { if name == "" { return commonerr.NewBadRequestError("Empty notification name is not allowed") } result, err := tx.Exec(removeNotification, name) if err != nil { - return handleError("removeNotification", err) + return util.HandleError("removeNotification", err) } affected, err := result.RowsAffected() if err != nil { - return handleError("removeNotification", err) + return util.HandleError("removeNotification", err) } if affected <= 0 { - return handleError("removeNotification", commonerr.ErrNotFound) + return util.HandleError("removeNotification", commonerr.ErrNotFound) } return nil