pgsql: Move notification to its module

master
Sida Chen 5 years ago
parent 921acb26fe
commit dfa07f6d86

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
package notification
import (
"testing"
@ -22,6 +22,8 @@ import (
"github.com/stretchr/testify/require"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/page"
"github.com/coreos/clair/database/pgsql/testutil"
"github.com/coreos/clair/pkg/pagination"
)
@ -38,6 +40,8 @@ type findVulnerabilityNotificationOut struct {
err string
}
var testPaginationKey = pagination.Must(pagination.NewKey())
var findVulnerabilityNotificationTests = []struct {
title string
in findVulnerabilityNotificationIn
@ -77,21 +81,21 @@ var findVulnerabilityNotificationTests = []struct {
},
out: findVulnerabilityNotificationOut{
&database.VulnerabilityNotificationWithVulnerable{
NotificationHook: realNotification[1].NotificationHook,
NotificationHook: testutil.RealNotification[1].NotificationHook,
Old: &database.PagedVulnerableAncestries{
Vulnerability: realVulnerability[2],
Vulnerability: testutil.RealVulnerability[2],
Limit: 1,
Affected: make(map[int]string),
Current: mustMarshalToken(testPaginationKey, Page{0}),
Next: mustMarshalToken(testPaginationKey, Page{0}),
Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
End: true,
},
New: &database.PagedVulnerableAncestries{
Vulnerability: realVulnerability[1],
Vulnerability: testutil.RealVulnerability[1],
Limit: 1,
Affected: map[int]string{3: "ancestry-3"},
Current: mustMarshalToken(testPaginationKey, Page{0}),
Next: mustMarshalToken(testPaginationKey, Page{4}),
Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}),
End: false,
},
},
@ -100,32 +104,31 @@ var findVulnerabilityNotificationTests = []struct {
"",
},
},
{
title: "find existing notification of second page of new affected ancestry",
in: findVulnerabilityNotificationIn{
notificationName: "test",
pageSize: 1,
oldAffectedAncestryPage: pagination.FirstPageToken,
newAffectedAncestryPage: mustMarshalToken(testPaginationKey, Page{4}),
newAffectedAncestryPage: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}),
},
out: findVulnerabilityNotificationOut{
&database.VulnerabilityNotificationWithVulnerable{
NotificationHook: realNotification[1].NotificationHook,
NotificationHook: testutil.RealNotification[1].NotificationHook,
Old: &database.PagedVulnerableAncestries{
Vulnerability: realVulnerability[2],
Vulnerability: testutil.RealVulnerability[2],
Limit: 1,
Affected: make(map[int]string),
Current: mustMarshalToken(testPaginationKey, Page{0}),
Next: mustMarshalToken(testPaginationKey, Page{0}),
Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
End: true,
},
New: &database.PagedVulnerableAncestries{
Vulnerability: realVulnerability[1],
Vulnerability: testutil.RealVulnerability[1],
Limit: 1,
Affected: map[int]string{4: "ancestry-4"},
Current: mustMarshalToken(testPaginationKey, Page{4}),
Next: mustMarshalToken(testPaginationKey, Page{0}),
Current: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{4}),
Next: testutil.MustMarshalToken(testutil.TestPaginationKey, page.Page{0}),
End: true,
},
},
@ -137,12 +140,12 @@ var findVulnerabilityNotificationTests = []struct {
}
func TestFindVulnerabilityNotification(t *testing.T) {
datastore, tx := openSessionForTest(t, "pagination", true)
defer closeTest(t, datastore, tx)
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "pagination")
defer cleanup()
for _, test := range findVulnerabilityNotificationTests {
t.Run(test.title, func(t *testing.T) {
notification, ok, err := tx.FindVulnerabilityNotification(test.in.notificationName, test.in.pageSize, test.in.oldAffectedAncestryPage, test.in.newAffectedAncestryPage)
notification, ok, err := FindVulnerabilityNotification(tx, test.in.notificationName, test.in.pageSize, test.in.oldAffectedAncestryPage, test.in.newAffectedAncestryPage, testutil.TestPaginationKey)
if test.out.err != "" {
require.EqualError(t, err, test.out.err)
return
@ -155,13 +158,14 @@ func TestFindVulnerabilityNotification(t *testing.T) {
}
require.True(t, ok)
assertVulnerabilityNotificationWithVulnerableEqual(t, testPaginationKey, test.out.notification, &notification)
testutil.AssertVulnerabilityNotificationWithVulnerableEqual(t, testutil.TestPaginationKey, test.out.notification, &notification)
})
}
}
func TestInsertVulnerabilityNotifications(t *testing.T) {
datastore, tx := openSessionForTest(t, "InsertVulnerabilityNotifications", true)
datastore, cleanup := testutil.CreateTestDBWithFixture(t, "InsertVulnerabilityNotifications")
defer cleanup()
n1 := database.VulnerabilityNotification{}
n3 := database.VulnerabilityNotification{
@ -187,34 +191,37 @@ func TestInsertVulnerabilityNotifications(t *testing.T) {
},
}
tx, err := datastore.Begin()
require.Nil(t, err)
// invalid case
err := tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n1})
assert.NotNil(t, err)
err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n1})
require.NotNil(t, err)
// invalid case: unknown vulnerability
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n3})
assert.NotNil(t, err)
err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n3})
require.NotNil(t, err)
// invalid case: duplicated input notification
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4, n4})
assert.NotNil(t, err)
tx = restartSession(t, datastore, tx, false)
err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4, n4})
require.NotNil(t, err)
tx = testutil.RestartTransaction(datastore, tx, false)
// valid case
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4})
assert.Nil(t, err)
err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4})
require.Nil(t, err)
// invalid case: notification is already in database
err = tx.InsertVulnerabilityNotifications([]database.VulnerabilityNotification{n4})
assert.NotNil(t, err)
err = InsertVulnerabilityNotifications(tx, []database.VulnerabilityNotification{n4})
require.NotNil(t, err)
closeTest(t, datastore, tx)
require.Nil(t, tx.Rollback())
}
func TestFindNewNotification(t *testing.T) {
tx, cleanup := createTestPgSessionWithFixtures(t, "TestFindNewNotification")
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "TestFindNewNotification")
defer cleanup()
noti, ok, err := tx.FindNewNotification(time.Now())
noti, ok, err := FindNewNotification(tx, time.Now())
if assert.Nil(t, err) && assert.True(t, ok) {
assert.Equal(t, "test", noti.Name)
assert.Equal(t, time.Time{}, noti.Notified)
@ -223,13 +230,13 @@ func TestFindNewNotification(t *testing.T) {
}
// can't find the notified
assert.Nil(t, tx.MarkNotificationAsRead("test"))
assert.Nil(t, MarkNotificationAsRead(tx, "test"))
// if the notified time is before
noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(10 * time.Second)))
noti, ok, err = FindNewNotification(tx, time.Now().Add(-time.Duration(10*time.Second)))
assert.Nil(t, err)
assert.False(t, ok)
// can find the notified after a period of time
noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(10 * time.Second)))
noti, ok, err = FindNewNotification(tx, time.Now().Add(time.Duration(10*time.Second)))
if assert.Nil(t, err) && assert.True(t, ok) {
assert.Equal(t, "test", noti.Name)
assert.NotEqual(t, time.Time{}, noti.Notified)
@ -237,37 +244,37 @@ func TestFindNewNotification(t *testing.T) {
assert.Equal(t, time.Time{}, noti.Deleted)
}
assert.Nil(t, tx.DeleteNotification("test"))
assert.Nil(t, DeleteNotification(tx, "test"))
// can't find in any time
noti, ok, err = tx.FindNewNotification(time.Now().Add(-time.Duration(1000)))
noti, ok, err = FindNewNotification(tx, time.Now().Add(-time.Duration(1000)))
assert.Nil(t, err)
assert.False(t, ok)
noti, ok, err = tx.FindNewNotification(time.Now().Add(time.Duration(1000)))
noti, ok, err = FindNewNotification(tx, time.Now().Add(time.Duration(1000)))
assert.Nil(t, err)
assert.False(t, ok)
}
func TestMarkNotificationAsRead(t *testing.T) {
datastore, tx := openSessionForTest(t, "MarkNotificationAsRead", true)
defer closeTest(t, datastore, tx)
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "MarkNotificationAsRead")
defer cleanup()
// invalid case: notification doesn't exist
assert.NotNil(t, tx.MarkNotificationAsRead("non-existing"))
assert.NotNil(t, MarkNotificationAsRead(tx, "non-existing"))
// valid case
assert.Nil(t, tx.MarkNotificationAsRead("test"))
assert.Nil(t, MarkNotificationAsRead(tx, "test"))
// valid case
assert.Nil(t, tx.MarkNotificationAsRead("test"))
assert.Nil(t, MarkNotificationAsRead(tx, "test"))
}
func TestDeleteNotification(t *testing.T) {
datastore, tx := openSessionForTest(t, "DeleteNotification", true)
defer closeTest(t, datastore, tx)
tx, cleanup := testutil.CreateTestTxWithFixtures(t, "DeleteNotification")
defer cleanup()
// invalid case: notification doesn't exist
assert.NotNil(t, tx.DeleteNotification("non-existing"))
assert.NotNil(t, DeleteNotification(tx, "non-existing"))
// valid case
assert.Nil(t, tx.DeleteNotification("test"))
assert.Nil(t, DeleteNotification(tx, "test"))
// invalid case: notification is already deleted
assert.NotNil(t, tx.DeleteNotification("test"))
assert.NotNil(t, DeleteNotification(tx, "test"))
}

@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
package notification
import (
"database/sql"
@ -22,6 +22,8 @@ import (
"github.com/guregu/null/zero"
"github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/util"
"github.com/coreos/clair/database/pgsql/vulnerability"
"github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination"
)
@ -54,26 +56,24 @@ const (
SELECT created_at, notified_at, deleted_at, old_vulnerability_id, new_vulnerability_id
FROM Vulnerability_Notification
WHERE name = $1`
searchNotificationVulnerableAncestry = `
SELECT DISTINCT ON (a.id)
a.id, a.name
FROM vulnerability_affected_namespaced_feature AS vanf,
ancestry_layer AS al, ancestry_feature AS af, ancestry AS a
WHERE vanf.vulnerability_id = $1
AND a.id >= $2
AND al.ancestry_id = a.id
AND al.id = af.ancestry_layer_id
AND af.namespaced_feature_id = vanf.namespaced_feature_id
ORDER BY a.id ASC
LIMIT $3;`
)
func queryInsertNotifications(count int) string {
return util.QueryInsert(count,
"vulnerability_notification",
"name",
"created_at",
"old_vulnerability_id",
"new_vulnerability_id",
)
}
var (
errNotificationNotFound = errors.New("requested notification is not found")
errNotificationNotFound = errors.New("requested notification is not found")
errVulnerabilityNotFound = errors.New("vulnerability is not in database")
)
func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.VulnerabilityNotification) error {
func InsertVulnerabilityNotifications(tx *sql.Tx, notifications []database.VulnerabilityNotification) error {
if len(notifications) == 0 {
return nil
}
@ -122,26 +122,26 @@ func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.V
oldVulnIDs = append(oldVulnIDs, vulnID)
}
ids, err := tx.findNotDeletedVulnerabilityIDs(newVulnIDs)
ids, err := vulnerability.FindNotDeletedVulnerabilityIDs(tx, newVulnIDs)
if err != nil {
return err
}
for i, id := range ids {
if !id.Valid {
return handleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound)
return util.HandleError("findNotDeletedVulnerabilityIDs", errVulnerabilityNotFound)
}
newVulnIDMap[newVulnIDs[i]] = id
}
ids, err = tx.findLatestDeletedVulnerabilityIDs(oldVulnIDs)
ids, err = vulnerability.FindLatestDeletedVulnerabilityIDs(tx, oldVulnIDs)
if err != nil {
return err
}
for i, id := range ids {
if !id.Valid {
return handleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound)
return util.HandleError("findLatestDeletedVulnerabilityIDs", errVulnerabilityNotFound)
}
oldVulnIDMap[oldVulnIDs[i]] = id
}
@ -178,13 +178,13 @@ func (tx *pgSession) InsertVulnerabilityNotifications(notifications []database.V
// multiple updaters, deadlock may happen.
_, err = tx.Exec(queryInsertNotifications(len(notifications)), keys...)
if err != nil {
return handleError("queryInsertNotifications", err)
return util.HandleError("queryInsertNotifications", err)
}
return nil
}
func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.NotificationHook, bool, error) {
func FindNewNotification(tx *sql.Tx, notifiedBefore time.Time) (database.NotificationHook, bool, error) {
var (
notification database.NotificationHook
created zero.Time
@ -197,7 +197,7 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not
if err == sql.ErrNoRows {
return notification, false, nil
}
return notification, false, handleError("searchNotificationAvailable", err)
return notification, false, util.HandleError("searchNotificationAvailable", err)
}
notification.Created = created.Time
@ -207,71 +207,7 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not
return notification, true, nil
}
func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentToken pagination.Token) (database.PagedVulnerableAncestries, error) {
vulnPage := database.PagedVulnerableAncestries{Limit: limit}
currentPage := Page{0}
if currentToken != pagination.FirstPageToken {
if err := tx.key.UnmarshalToken(currentToken, &currentPage); err != nil {
return vulnPage, err
}
}
if err := tx.QueryRow(searchVulnerabilityByID, vulnID).Scan(
&vulnPage.Name,
&vulnPage.Description,
&vulnPage.Link,
&vulnPage.Severity,
&vulnPage.Metadata,
&vulnPage.Namespace.Name,
&vulnPage.Namespace.VersionFormat,
); err != nil {
return vulnPage, handleError("searchVulnerabilityByID", err)
}
// the last result is used for the next page's startID
rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1)
if err != nil {
return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
}
defer rows.Close()
ancestries := []affectedAncestry{}
for rows.Next() {
var ancestry affectedAncestry
err := rows.Scan(&ancestry.id, &ancestry.name)
if err != nil {
return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
}
ancestries = append(ancestries, ancestry)
}
lastIndex := 0
if len(ancestries)-1 < limit {
lastIndex = len(ancestries)
vulnPage.End = true
} else {
// Use the last ancestry's ID as the next page.
lastIndex = len(ancestries) - 1
vulnPage.Next, err = tx.key.MarshalToken(Page{ancestries[len(ancestries)-1].id})
if err != nil {
return vulnPage, err
}
}
vulnPage.Affected = map[int]string{}
for _, ancestry := range ancestries[0:lastIndex] {
vulnPage.Affected[int(ancestry.id)] = ancestry.name
}
vulnPage.Current, err = tx.key.MarshalToken(currentPage)
if err != nil {
return vulnPage, err
}
return vulnPage, nil
}
func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token) (
func FindVulnerabilityNotification(tx *sql.Tx, name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token, key pagination.Key) (
database.VulnerabilityNotificationWithVulnerable, bool, error) {
var (
noti database.VulnerabilityNotificationWithVulnerable
@ -294,7 +230,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
if err == sql.ErrNoRows {
return noti, false, nil
}
return noti, false, handleError("searchNotification", err)
return noti, false, util.HandleError("searchNotification", err)
}
if created.Valid {
@ -310,7 +246,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
}
if oldVulnID.Valid {
page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPageToken)
page, err := vulnerability.FindPagedVulnerableAncestries(tx, oldVulnID.Int64, limit, oldPageToken, key)
if err != nil {
return noti, false, err
}
@ -318,7 +254,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
}
if newVulnID.Valid {
page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPageToken)
page, err := vulnerability.FindPagedVulnerableAncestries(tx, newVulnID.Int64, limit, newPageToken, key)
if err != nil {
return noti, false, err
}
@ -328,44 +264,44 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
return noti, true, nil
}
func (tx *pgSession) MarkNotificationAsRead(name string) error {
func MarkNotificationAsRead(tx *sql.Tx, name string) error {
if name == "" {
return commonerr.NewBadRequestError("Empty notification name is not allowed")
}
r, err := tx.Exec(updatedNotificationAsRead, name)
if err != nil {
return handleError("updatedNotificationAsRead", err)
return util.HandleError("updatedNotificationAsRead", err)
}
affected, err := r.RowsAffected()
if err != nil {
return handleError("updatedNotificationAsRead", err)
return util.HandleError("updatedNotificationAsRead", err)
}
if affected <= 0 {
return handleError("updatedNotificationAsRead", errNotificationNotFound)
return util.HandleError("updatedNotificationAsRead", errNotificationNotFound)
}
return nil
}
func (tx *pgSession) DeleteNotification(name string) error {
func DeleteNotification(tx *sql.Tx, name string) error {
if name == "" {
return commonerr.NewBadRequestError("Empty notification name is not allowed")
}
result, err := tx.Exec(removeNotification, name)
if err != nil {
return handleError("removeNotification", err)
return util.HandleError("removeNotification", err)
}
affected, err := result.RowsAffected()
if err != nil {
return handleError("removeNotification", err)
return util.HandleError("removeNotification", err)
}
if affected <= 0 {
return handleError("removeNotification", commonerr.ErrNotFound)
return util.HandleError("removeNotification", commonerr.ErrNotFound)
}
return nil
Loading…
Cancel
Save