diff --git a/api/v3/rpc.go b/api/v3/rpc.go index 9c893bf2..1111c09d 100644 --- a/api/v3/rpc.go +++ b/api/v3/rpc.go @@ -25,6 +25,7 @@ import ( pb "github.com/coreos/clair/api/v3/clairpb" "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" + "github.com/coreos/clair/pkg/pagination" ) // NotificationServer implements NotificationService interface for serving RPC. @@ -214,8 +215,8 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot dbNotification, ok, err := tx.FindVulnerabilityNotification( req.GetName(), int(req.GetLimit()), - database.PageNumber(req.GetOldVulnerabilityPage()), - database.PageNumber(req.GetNewVulnerabilityPage()), + pagination.Token(req.GetOldVulnerabilityPage()), + pagination.Token(req.GetNewVulnerabilityPage()), ) if err != nil { diff --git a/database/database.go b/database/database.go index 5c427494..61edb140 100644 --- a/database/database.go +++ b/database/database.go @@ -20,6 +20,8 @@ import ( "errors" "fmt" "time" + + "github.com/coreos/clair/pkg/pagination" ) var ( @@ -168,16 +170,9 @@ type Session interface { // affected ancestries affected by old or new vulnerability. // // Because the number of affected ancestries maybe large, they are paginated - // and their pages are specified by the given encrypted PageNumbers, which, - // if empty, are always considered first page. - // - // Session interface implementation should have encrypt and decrypt - // functions for PageNumber. - FindVulnerabilityNotification(name string, limit int, - oldVulnerabilityPage PageNumber, - newVulnerabilityPage PageNumber) ( - noti VulnerabilityNotificationWithVulnerable, - found bool, err error) + // and their pages are specified by the paination token, which, if empty, are + // always considered first page. + FindVulnerabilityNotification(name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (noti VulnerabilityNotificationWithVulnerable, found bool, err error) // MarkNotificationNotified marks a Notification as notified now, assuming // the requested notification is in the database. diff --git a/database/mock.go b/database/mock.go index ed6e8f16..0cd09919 100644 --- a/database/mock.go +++ b/database/mock.go @@ -14,7 +14,11 @@ package database -import "time" +import ( + "time" + + "github.com/coreos/clair/pkg/pagination" +) // MockSession implements Session and enables overriding each available method. // The default behavior of each method is to simply panic. @@ -38,7 +42,7 @@ type MockSession struct { FctDeleteVulnerabilities func([]VulnerabilityID) error FctInsertVulnerabilityNotifications func([]VulnerabilityNotification) error FctFindNewNotification func(lastNotified time.Time) (NotificationHook, bool, error) - FctFindVulnerabilityNotification func(name string, limit int, oldPage PageNumber, newPage PageNumber) ( + FctFindVulnerabilityNotification func(name string, limit int, oldPage pagination.Token, newPage pagination.Token) ( vuln VulnerabilityNotificationWithVulnerable, ok bool, err error) FctMarkNotificationNotified func(name string) error FctDeleteNotification func(name string) error @@ -182,7 +186,7 @@ func (ms *MockSession) FindNewNotification(lastNotified time.Time) (Notification panic("required mock function not implemented") } -func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage PageNumber, newPage PageNumber) ( +func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage pagination.Token, newPage pagination.Token) ( VulnerabilityNotificationWithVulnerable, bool, error) { if ms.FctFindVulnerabilityNotification != nil { return ms.FctFindVulnerabilityNotification(name, limit, oldPage, newPage) diff --git a/database/models.go b/database/models.go index e5a82358..702a079f 100644 --- a/database/models.go +++ b/database/models.go @@ -18,6 +18,8 @@ import ( "database/sql/driver" "encoding/json" "time" + + "github.com/coreos/clair/pkg/pagination" ) // Processors are extentions to scan a layer's content. @@ -173,8 +175,8 @@ type PagedVulnerableAncestries struct { Affected map[int]string Limit int - Current PageNumber - Next PageNumber + Current pagination.Token + Next pagination.Token // End signals the end of the pages. End bool @@ -209,9 +211,6 @@ type VulnerabilityNotificationWithVulnerable struct { New *PagedVulnerableAncestries } -// PageNumber is used to do pagination. -type PageNumber string - // MetadataMap is for storing the metadata returned by vulnerability database. type MetadataMap map[string]interface{} diff --git a/database/pgsql/notification.go b/database/pgsql/notification.go index 5c9af262..3a27fb3c 100644 --- a/database/pgsql/notification.go +++ b/database/pgsql/notification.go @@ -23,6 +23,7 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" + "github.com/coreos/clair/pkg/pagination" ) var ( @@ -163,12 +164,12 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not return notification, true, nil } -func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentPage database.PageNumber) (database.PagedVulnerableAncestries, error) { +func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentToken pagination.Token) (database.PagedVulnerableAncestries, error) { vulnPage := database.PagedVulnerableAncestries{Limit: limit} - current := Page{0} - if currentPage != "" { + currentPage := Page{0} + if currentToken != pagination.FirstPageToken { var err error - current, err = PageFromPageNumber(tx.key, currentPage) + err = tx.key.UnmarshalToken(currentToken, ¤tPage) if err != nil { return vulnPage, err } @@ -188,7 +189,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr } // the last result is used for the next page's startID - rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, current.StartID, limit+1) + rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1) if err != nil { return vulnPage, handleError("searchNotificationVulnerableAncestry", err) } @@ -209,9 +210,9 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr lastIndex = len(ancestries) vulnPage.End = true } else { - // Use the last ancestry's ID as the next PageNumber. + // Use the last ancestry's ID as the next page. lastIndex = len(ancestries) - 1 - vulnPage.Next, err = Page{ancestries[len(ancestries)-1].id}.PageNumber(tx.key) + vulnPage.Next, err = tx.key.MarshalToken(Page{ancestries[len(ancestries)-1].id}) if err != nil { return vulnPage, err } @@ -222,7 +223,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr vulnPage.Affected[int(ancestry.id)] = ancestry.name } - vulnPage.Current, err = current.PageNumber(tx.key) + vulnPage.Current, err = tx.key.MarshalToken(currentPage) if err != nil { return vulnPage, err } @@ -230,7 +231,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr return vulnPage, nil } -func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPage database.PageNumber, newPage database.PageNumber) ( +func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token) ( database.VulnerabilityNotificationWithVulnerable, bool, error) { var ( noti database.VulnerabilityNotificationWithVulnerable @@ -270,7 +271,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa } if oldVulnID.Valid { - page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPage) + page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPageToken) if err != nil { return noti, false, err } @@ -278,7 +279,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa } if newVulnID.Valid { - page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPage) + page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPageToken) if err != nil { return noti, false, err } diff --git a/database/pgsql/notification_test.go b/database/pgsql/notification_test.go index 18e27a5c..ec119e99 100644 --- a/database/pgsql/notification_test.go +++ b/database/pgsql/notification_test.go @@ -73,17 +73,20 @@ func TestPagination(t *testing.T) { if assert.Nil(t, err) && assert.True(t, ok) { assert.Equal(t, "test", noti.Name) if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { - oldPage, err := PageFromPageNumber(tx.key, noti.Old.Current) + var oldPage Page + err := tx.key.UnmarshalToken(noti.Old.Current, &oldPage) if !assert.Nil(t, err) { assert.FailNow(t, "") } assert.Equal(t, int64(0), oldPage.StartID) - newPage, err := PageFromPageNumber(tx.key, noti.New.Current) + var newPage Page + err = tx.key.UnmarshalToken(noti.New.Current, &newPage) if !assert.Nil(t, err) { assert.FailNow(t, "") } - newPageNext, err := PageFromPageNumber(tx.key, noti.New.Next) + var newPageNext Page + err = tx.key.UnmarshalToken(noti.New.Next, &newPageNext) if !assert.Nil(t, err) { assert.FailNow(t, "") } @@ -98,12 +101,12 @@ func TestPagination(t *testing.T) { } } - pageNum1, err := Page{0}.PageNumber(tx.key) + pageNum1, err := tx.key.MarshalToken(Page{0}) if !assert.Nil(t, err) { assert.FailNow(t, "") } - pageNum2, err := Page{4}.PageNumber(tx.key) + pageNum2, err := tx.key.MarshalToken(Page{4}) if !assert.Nil(t, err) { assert.FailNow(t, "") } @@ -112,12 +115,14 @@ func TestPagination(t *testing.T) { if assert.Nil(t, err) && assert.True(t, ok) { assert.Equal(t, "test", noti.Name) if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { - oldCurrentPage, err := PageFromPageNumber(tx.key, noti.Old.Current) + var oldCurrentPage Page + err = tx.key.UnmarshalToken(noti.Old.Current, &oldCurrentPage) if !assert.Nil(t, err) { assert.FailNow(t, "") } - newCurrentPage, err := PageFromPageNumber(tx.key, noti.New.Current) + var newCurrentPage Page + err = tx.key.UnmarshalToken(noti.New.Current, &newCurrentPage) if !assert.Nil(t, err) { assert.FailNow(t, "") } diff --git a/database/pgsql/page.go b/database/pgsql/page.go deleted file mode 100644 index 101d1dff..00000000 --- a/database/pgsql/page.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2018 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgsql - -import ( - "github.com/coreos/clair/database" - "github.com/coreos/clair/pkg/pagination" -) - -// Page is the representation of a page for the Postgres schema. -type Page struct { - // StartID is the ID being used as the basis for pagination across database - // results. It is used to search for an ancestry with ID >= StartID. - // - // StartID is required to be unique to every ancestry and always increasing. - StartID int64 -} - -// PageNumber converts a Page to a database.PageNumber. -func (p Page) PageNumber(key pagination.Key) (pn database.PageNumber, err error) { - token, err := key.MarshalToken(p) - if err != nil { - return pn, err - } - pn = database.PageNumber(token) - return pn, nil -} - -// PageFromPageNumber converts a database.PageNumber into a Page. -func PageFromPageNumber(key pagination.Key, pn database.PageNumber) (p Page, err error) { - err = key.UnmarshalToken(string(pn), &p) - return -} diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index 66e8930e..9af010b6 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -131,6 +131,15 @@ func (pgSQL *pgSQL) Ping() bool { return pgSQL.DB.Ping() == nil } +// Page is the representation of a page for the Postgres schema. +type Page struct { + // StartID is the ID being used as the basis for pagination across database + // results. It is used to search for an ancestry with ID >= StartID. + // + // StartID is required to be unique to every ancestry and always increasing. + StartID int64 +} + // Config is the configuration that is used by openDatabase. type Config struct { Source string diff --git a/database/pgsql/pgsql_test.go b/database/pgsql/pgsql_test.go index 0e5b7234..e4a8c8b4 100644 --- a/database/pgsql/pgsql_test.go +++ b/database/pgsql/pgsql_test.go @@ -221,7 +221,7 @@ func generateTestConfig(testName string, loadFixture bool, manageLife bool) data "cachesize": 0, "managedatabaselifecycle": manageLife, "fixturepath": fixturePath, - "paginationkey": pagination.MustGenerateNewKey().String(), + "paginationkey": pagination.Must(pagination.NewKey()).String(), }, } } diff --git a/pkg/pagination/pagination.go b/pkg/pagination/pagination.go index 5f4acb66..1e765cec 100644 --- a/pkg/pagination/pagination.go +++ b/pkg/pagination/pagination.go @@ -38,6 +38,13 @@ type Key struct { fkey *fernet.Key } +// Token represents an opaque pagination token keeping track of a user's +// progress iterating through a list of results. +type Token string + +// FirstPageToken is used to represent the first page of content. +var FirstPageToken = Token("") + // NewKey generates a new random pagination key. func NewKey() (k Key, err error) { k.fkey = new(fernet.Key) @@ -71,21 +78,22 @@ func (k Key) String() string { return k.fkey.Encode() } -// MarshalToken encodes an interface into JSON bytes and encrypts it. -func (k Key) MarshalToken(v interface{}) ([]byte, error) { +// MarshalToken encodes an interface into JSON bytes and produces a Token. +func (k Key) MarshalToken(v interface{}) (Token, error) { var buf bytes.Buffer err := json.NewEncoder(&buf).Encode(v) if err != nil { - return nil, err + return Token(""), err } - return fernet.EncryptAndSign(buf.Bytes(), k.fkey) + tokenBytes, err := fernet.EncryptAndSign(buf.Bytes(), k.fkey) + return Token(tokenBytes), err } -// UnmarshalToken decrypts a token using provided key and decodes the result +// UnmarshalToken decrypts a Token using provided key and decodes the result // into the provided interface. -func (k Key) UnmarshalToken(token string, v interface{}) error { - msg := fernet.VerifyAndDecrypt([]byte(token), time.Hour, []*fernet.Key{k.fkey}) +func (k Key) UnmarshalToken(t Token, v interface{}) error { + msg := fernet.VerifyAndDecrypt([]byte(t), time.Hour, []*fernet.Key{k.fkey}) if msg == nil { return ErrInvalidToken }