pkg/pagination: add token type

This change pulls as much pagination logic out of the database
implementation as possible. Database implementations should now be able
to marshal whatever state they need into opaque tokens with the
utilities in the pagination package.
This commit is contained in:
Jimmy Zelinskie 2018-09-07 16:12:19 -04:00
parent d193b46449
commit 0565938956
10 changed files with 68 additions and 91 deletions

View File

@ -25,6 +25,7 @@ import (
pb "github.com/coreos/clair/api/v3/clairpb" pb "github.com/coreos/clair/api/v3/clairpb"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination"
) )
// NotificationServer implements NotificationService interface for serving RPC. // NotificationServer implements NotificationService interface for serving RPC.
@ -214,8 +215,8 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot
dbNotification, ok, err := tx.FindVulnerabilityNotification( dbNotification, ok, err := tx.FindVulnerabilityNotification(
req.GetName(), req.GetName(),
int(req.GetLimit()), int(req.GetLimit()),
database.PageNumber(req.GetOldVulnerabilityPage()), pagination.Token(req.GetOldVulnerabilityPage()),
database.PageNumber(req.GetNewVulnerabilityPage()), pagination.Token(req.GetNewVulnerabilityPage()),
) )
if err != nil { if err != nil {

View File

@ -20,6 +20,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"time" "time"
"github.com/coreos/clair/pkg/pagination"
) )
var ( var (
@ -168,16 +170,9 @@ type Session interface {
// affected ancestries affected by old or new vulnerability. // affected ancestries affected by old or new vulnerability.
// //
// Because the number of affected ancestries maybe large, they are paginated // Because the number of affected ancestries maybe large, they are paginated
// and their pages are specified by the given encrypted PageNumbers, which, // and their pages are specified by the paination token, which, if empty, are
// if empty, are always considered first page. // always considered first page.
// FindVulnerabilityNotification(name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (noti VulnerabilityNotificationWithVulnerable, found bool, err error)
// Session interface implementation should have encrypt and decrypt
// functions for PageNumber.
FindVulnerabilityNotification(name string, limit int,
oldVulnerabilityPage PageNumber,
newVulnerabilityPage PageNumber) (
noti VulnerabilityNotificationWithVulnerable,
found bool, err error)
// MarkNotificationNotified marks a Notification as notified now, assuming // MarkNotificationNotified marks a Notification as notified now, assuming
// the requested notification is in the database. // the requested notification is in the database.

View File

@ -14,7 +14,11 @@
package database package database
import "time" import (
"time"
"github.com/coreos/clair/pkg/pagination"
)
// MockSession implements Session and enables overriding each available method. // MockSession implements Session and enables overriding each available method.
// The default behavior of each method is to simply panic. // The default behavior of each method is to simply panic.
@ -38,7 +42,7 @@ type MockSession struct {
FctDeleteVulnerabilities func([]VulnerabilityID) error FctDeleteVulnerabilities func([]VulnerabilityID) error
FctInsertVulnerabilityNotifications func([]VulnerabilityNotification) error FctInsertVulnerabilityNotifications func([]VulnerabilityNotification) error
FctFindNewNotification func(lastNotified time.Time) (NotificationHook, bool, error) FctFindNewNotification func(lastNotified time.Time) (NotificationHook, bool, error)
FctFindVulnerabilityNotification func(name string, limit int, oldPage PageNumber, newPage PageNumber) ( FctFindVulnerabilityNotification func(name string, limit int, oldPage pagination.Token, newPage pagination.Token) (
vuln VulnerabilityNotificationWithVulnerable, ok bool, err error) vuln VulnerabilityNotificationWithVulnerable, ok bool, err error)
FctMarkNotificationNotified func(name string) error FctMarkNotificationNotified func(name string) error
FctDeleteNotification func(name string) error FctDeleteNotification func(name string) error
@ -182,7 +186,7 @@ func (ms *MockSession) FindNewNotification(lastNotified time.Time) (Notification
panic("required mock function not implemented") panic("required mock function not implemented")
} }
func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage PageNumber, newPage PageNumber) ( func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage pagination.Token, newPage pagination.Token) (
VulnerabilityNotificationWithVulnerable, bool, error) { VulnerabilityNotificationWithVulnerable, bool, error) {
if ms.FctFindVulnerabilityNotification != nil { if ms.FctFindVulnerabilityNotification != nil {
return ms.FctFindVulnerabilityNotification(name, limit, oldPage, newPage) return ms.FctFindVulnerabilityNotification(name, limit, oldPage, newPage)

View File

@ -18,6 +18,8 @@ import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"time" "time"
"github.com/coreos/clair/pkg/pagination"
) )
// Processors are extentions to scan a layer's content. // Processors are extentions to scan a layer's content.
@ -173,8 +175,8 @@ type PagedVulnerableAncestries struct {
Affected map[int]string Affected map[int]string
Limit int Limit int
Current PageNumber Current pagination.Token
Next PageNumber Next pagination.Token
// End signals the end of the pages. // End signals the end of the pages.
End bool End bool
@ -209,9 +211,6 @@ type VulnerabilityNotificationWithVulnerable struct {
New *PagedVulnerableAncestries New *PagedVulnerableAncestries
} }
// PageNumber is used to do pagination.
type PageNumber string
// MetadataMap is for storing the metadata returned by vulnerability database. // MetadataMap is for storing the metadata returned by vulnerability database.
type MetadataMap map[string]interface{} type MetadataMap map[string]interface{}

View File

@ -23,6 +23,7 @@ import (
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination"
) )
var ( var (
@ -163,12 +164,12 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not
return notification, true, nil return notification, true, nil
} }
func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentPage database.PageNumber) (database.PagedVulnerableAncestries, error) { func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentToken pagination.Token) (database.PagedVulnerableAncestries, error) {
vulnPage := database.PagedVulnerableAncestries{Limit: limit} vulnPage := database.PagedVulnerableAncestries{Limit: limit}
current := Page{0} currentPage := Page{0}
if currentPage != "" { if currentToken != pagination.FirstPageToken {
var err error var err error
current, err = PageFromPageNumber(tx.key, currentPage) err = tx.key.UnmarshalToken(currentToken, &currentPage)
if err != nil { if err != nil {
return vulnPage, err return vulnPage, err
} }
@ -188,7 +189,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
} }
// the last result is used for the next page's startID // the last result is used for the next page's startID
rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, current.StartID, limit+1) rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1)
if err != nil { if err != nil {
return vulnPage, handleError("searchNotificationVulnerableAncestry", err) return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
} }
@ -209,9 +210,9 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
lastIndex = len(ancestries) lastIndex = len(ancestries)
vulnPage.End = true vulnPage.End = true
} else { } else {
// Use the last ancestry's ID as the next PageNumber. // Use the last ancestry's ID as the next page.
lastIndex = len(ancestries) - 1 lastIndex = len(ancestries) - 1
vulnPage.Next, err = Page{ancestries[len(ancestries)-1].id}.PageNumber(tx.key) vulnPage.Next, err = tx.key.MarshalToken(Page{ancestries[len(ancestries)-1].id})
if err != nil { if err != nil {
return vulnPage, err return vulnPage, err
} }
@ -222,7 +223,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
vulnPage.Affected[int(ancestry.id)] = ancestry.name vulnPage.Affected[int(ancestry.id)] = ancestry.name
} }
vulnPage.Current, err = current.PageNumber(tx.key) vulnPage.Current, err = tx.key.MarshalToken(currentPage)
if err != nil { if err != nil {
return vulnPage, err return vulnPage, err
} }
@ -230,7 +231,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
return vulnPage, nil return vulnPage, nil
} }
func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPage database.PageNumber, newPage database.PageNumber) ( func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token) (
database.VulnerabilityNotificationWithVulnerable, bool, error) { database.VulnerabilityNotificationWithVulnerable, bool, error) {
var ( var (
noti database.VulnerabilityNotificationWithVulnerable noti database.VulnerabilityNotificationWithVulnerable
@ -270,7 +271,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
} }
if oldVulnID.Valid { if oldVulnID.Valid {
page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPage) page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPageToken)
if err != nil { if err != nil {
return noti, false, err return noti, false, err
} }
@ -278,7 +279,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
} }
if newVulnID.Valid { if newVulnID.Valid {
page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPage) page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPageToken)
if err != nil { if err != nil {
return noti, false, err return noti, false, err
} }

View File

@ -73,17 +73,20 @@ func TestPagination(t *testing.T) {
if assert.Nil(t, err) && assert.True(t, ok) { if assert.Nil(t, err) && assert.True(t, ok) {
assert.Equal(t, "test", noti.Name) assert.Equal(t, "test", noti.Name)
if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
oldPage, err := PageFromPageNumber(tx.key, noti.Old.Current) var oldPage Page
err := tx.key.UnmarshalToken(noti.Old.Current, &oldPage)
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }
assert.Equal(t, int64(0), oldPage.StartID) assert.Equal(t, int64(0), oldPage.StartID)
newPage, err := PageFromPageNumber(tx.key, noti.New.Current) var newPage Page
err = tx.key.UnmarshalToken(noti.New.Current, &newPage)
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }
newPageNext, err := PageFromPageNumber(tx.key, noti.New.Next) var newPageNext Page
err = tx.key.UnmarshalToken(noti.New.Next, &newPageNext)
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }
@ -98,12 +101,12 @@ func TestPagination(t *testing.T) {
} }
} }
pageNum1, err := Page{0}.PageNumber(tx.key) pageNum1, err := tx.key.MarshalToken(Page{0})
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }
pageNum2, err := Page{4}.PageNumber(tx.key) pageNum2, err := tx.key.MarshalToken(Page{4})
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }
@ -112,12 +115,14 @@ func TestPagination(t *testing.T) {
if assert.Nil(t, err) && assert.True(t, ok) { if assert.Nil(t, err) && assert.True(t, ok) {
assert.Equal(t, "test", noti.Name) assert.Equal(t, "test", noti.Name)
if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
oldCurrentPage, err := PageFromPageNumber(tx.key, noti.Old.Current) var oldCurrentPage Page
err = tx.key.UnmarshalToken(noti.Old.Current, &oldCurrentPage)
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }
newCurrentPage, err := PageFromPageNumber(tx.key, noti.New.Current) var newCurrentPage Page
err = tx.key.UnmarshalToken(noti.New.Current, &newCurrentPage)
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }

View File

@ -1,45 +0,0 @@
// Copyright 2018 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pgsql
import (
"github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/pagination"
)
// Page is the representation of a page for the Postgres schema.
type Page struct {
// StartID is the ID being used as the basis for pagination across database
// results. It is used to search for an ancestry with ID >= StartID.
//
// StartID is required to be unique to every ancestry and always increasing.
StartID int64
}
// PageNumber converts a Page to a database.PageNumber.
func (p Page) PageNumber(key pagination.Key) (pn database.PageNumber, err error) {
token, err := key.MarshalToken(p)
if err != nil {
return pn, err
}
pn = database.PageNumber(token)
return pn, nil
}
// PageFromPageNumber converts a database.PageNumber into a Page.
func PageFromPageNumber(key pagination.Key, pn database.PageNumber) (p Page, err error) {
err = key.UnmarshalToken(string(pn), &p)
return
}

View File

@ -131,6 +131,15 @@ func (pgSQL *pgSQL) Ping() bool {
return pgSQL.DB.Ping() == nil return pgSQL.DB.Ping() == nil
} }
// Page is the representation of a page for the Postgres schema.
type Page struct {
// StartID is the ID being used as the basis for pagination across database
// results. It is used to search for an ancestry with ID >= StartID.
//
// StartID is required to be unique to every ancestry and always increasing.
StartID int64
}
// Config is the configuration that is used by openDatabase. // Config is the configuration that is used by openDatabase.
type Config struct { type Config struct {
Source string Source string

View File

@ -221,7 +221,7 @@ func generateTestConfig(testName string, loadFixture bool, manageLife bool) data
"cachesize": 0, "cachesize": 0,
"managedatabaselifecycle": manageLife, "managedatabaselifecycle": manageLife,
"fixturepath": fixturePath, "fixturepath": fixturePath,
"paginationkey": pagination.MustGenerateNewKey().String(), "paginationkey": pagination.Must(pagination.NewKey()).String(),
}, },
} }
} }

View File

@ -38,6 +38,13 @@ type Key struct {
fkey *fernet.Key fkey *fernet.Key
} }
// Token represents an opaque pagination token keeping track of a user's
// progress iterating through a list of results.
type Token string
// FirstPageToken is used to represent the first page of content.
var FirstPageToken = Token("")
// NewKey generates a new random pagination key. // NewKey generates a new random pagination key.
func NewKey() (k Key, err error) { func NewKey() (k Key, err error) {
k.fkey = new(fernet.Key) k.fkey = new(fernet.Key)
@ -71,21 +78,22 @@ func (k Key) String() string {
return k.fkey.Encode() return k.fkey.Encode()
} }
// MarshalToken encodes an interface into JSON bytes and encrypts it. // MarshalToken encodes an interface into JSON bytes and produces a Token.
func (k Key) MarshalToken(v interface{}) ([]byte, error) { func (k Key) MarshalToken(v interface{}) (Token, error) {
var buf bytes.Buffer var buf bytes.Buffer
err := json.NewEncoder(&buf).Encode(v) err := json.NewEncoder(&buf).Encode(v)
if err != nil { if err != nil {
return nil, err return Token(""), err
} }
return fernet.EncryptAndSign(buf.Bytes(), k.fkey) tokenBytes, err := fernet.EncryptAndSign(buf.Bytes(), k.fkey)
return Token(tokenBytes), err
} }
// UnmarshalToken decrypts a token using provided key and decodes the result // UnmarshalToken decrypts a Token using provided key and decodes the result
// into the provided interface. // into the provided interface.
func (k Key) UnmarshalToken(token string, v interface{}) error { func (k Key) UnmarshalToken(t Token, v interface{}) error {
msg := fernet.VerifyAndDecrypt([]byte(token), time.Hour, []*fernet.Key{k.fkey}) msg := fernet.VerifyAndDecrypt([]byte(t), time.Hour, []*fernet.Key{k.fkey})
if msg == nil { if msg == nil {
return ErrInvalidToken return ErrInvalidToken
} }