From b20482e0aebcf2cc67f61e8ff821ddcdffc53ac7 Mon Sep 17 00:00:00 2001 From: Jimmy Zelinskie Date: Thu, 6 Sep 2018 17:40:01 -0400 Subject: [PATCH 1/3] cmd/clair: document constants --- cmd/clair/main.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/cmd/clair/main.go b/cmd/clair/main.go index dbb360a8..da046351 100644 --- a/cmd/clair/main.go +++ b/cmd/clair/main.go @@ -1,4 +1,4 @@ -// Copyright 2017 clair authors +// Copyright 2018 clair authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -61,7 +61,17 @@ import ( _ "github.com/coreos/clair/ext/vulnsrc/ubuntu" ) -const maxDBConnectionAttempts = 20 +// MaxDBConnectionAttempts is the total number of tries that Clair will use to +// initially connect to a database at start-up. +const MaxDBConnectionAttempts = 20 + +// BinaryDependencies are the programs that Clair expects to be on the $PATH +// because it creates subprocesses of these programs. +var BinaryDependencies = []string{ + "git", + "rpm", + "xz", +} func waitForSignals(signals ...os.Signal) { interrupts := make(chan os.Signal, 1) @@ -136,7 +146,7 @@ func Boot(config *Config) { // Open database var db database.Datastore var dbError error - for attempts := 1; attempts <= maxDBConnectionAttempts; attempts++ { + for attempts := 1; attempts <= MaxDBConnectionAttempts; attempts++ { db, dbError = database.Open(config.Database) if dbError == nil { break @@ -180,7 +190,7 @@ func main() { flag.Parse() // Check for dependencies. - for _, bin := range []string{"git", "rpm", "xz"} { + for _, bin := range BinaryDependencies { _, err := exec.LookPath(bin) if err != nil { log.WithError(err).WithField("dependency", bin).Fatal("failed to find dependency") From d193b46449a64a554c3b54dd637a371769bfe195 Mon Sep 17 00:00:00 2001 From: Jimmy Zelinskie Date: Thu, 6 Sep 2018 19:15:06 -0400 Subject: [PATCH 2/3] pkg/pagination: init This change refactors a lot of the code dealing with pagination so that fernet implementation details do not leak. - Deletes database/pgsql/token - Introduces a pagination package - Renames idPageNumber to Page and add a constructor and method. --- cmd/clair/config.go | 13 ++-- database/pgsql/notification.go | 12 ++-- database/pgsql/notification_test.go | 24 ++++---- database/pgsql/page.go | 45 ++++++++++++++ database/pgsql/pgsql.go | 36 +++-------- database/pgsql/pgsql_test.go | 9 +-- database/pgsql/token/token.go | 49 --------------- pkg/pagination/pagination.go | 94 +++++++++++++++++++++++++++++ 8 files changed, 169 insertions(+), 113 deletions(-) create mode 100644 database/pgsql/page.go delete mode 100644 database/pgsql/token/token.go create mode 100644 pkg/pagination/pagination.go diff --git a/cmd/clair/config.go b/cmd/clair/config.go index 08f26066..09a01364 100644 --- a/cmd/clair/config.go +++ b/cmd/clair/config.go @@ -1,4 +1,4 @@ -// Copyright 2017 clair authors +// Copyright 2018 clair authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ import ( "os" "time" - "github.com/fernet/fernet-go" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" @@ -31,6 +30,7 @@ import ( "github.com/coreos/clair/ext/featurens" "github.com/coreos/clair/ext/notification" "github.com/coreos/clair/ext/vulnsrc" + "github.com/coreos/clair/pkg/pagination" ) // ErrDatasourceNotLoaded is returned when the datasource variable in the @@ -108,15 +108,10 @@ func LoadConfig(path string) (config *Config, err error) { // Generate a pagination key if none is provided. if v, ok := config.Database.Options["paginationkey"]; !ok || v == nil || v.(string) == "" { log.Warn("pagination key is empty, generating...") - var key fernet.Key - if err = key.Generate(); err != nil { - return - } - config.Database.Options["paginationkey"] = key.Encode() + config.Database.Options["paginationkey"] = pagination.Must(pagination.NewKey()).String() } else { - _, err = fernet.DecodeKey(config.Database.Options["paginationkey"].(string)) + _, err = pagination.KeyFromString(config.Database.Options["paginationkey"].(string)) if err != nil { - err = errors.New("Invalid Pagination key; must be 32-bit URL-safe base64") return } } diff --git a/database/pgsql/notification.go b/database/pgsql/notification.go index ebc346d3..5c9af262 100644 --- a/database/pgsql/notification.go +++ b/database/pgsql/notification.go @@ -165,10 +165,10 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentPage database.PageNumber) (database.PagedVulnerableAncestries, error) { vulnPage := database.PagedVulnerableAncestries{Limit: limit} - current := idPageNumber{0} + current := Page{0} if currentPage != "" { var err error - current, err = decryptPage(currentPage, tx.paginationKey) + current, err = PageFromPageNumber(tx.key, currentPage) if err != nil { return vulnPage, err } @@ -211,11 +211,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr } else { // Use the last ancestry's ID as the next PageNumber. lastIndex = len(ancestries) - 1 - vulnPage.Next, err = encryptPage( - idPageNumber{ - ancestries[len(ancestries)-1].id, - }, tx.paginationKey) - + vulnPage.Next, err = Page{ancestries[len(ancestries)-1].id}.PageNumber(tx.key) if err != nil { return vulnPage, err } @@ -226,7 +222,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr vulnPage.Affected[int(ancestry.id)] = ancestry.name } - vulnPage.Current, err = encryptPage(current, tx.paginationKey) + vulnPage.Current, err = current.PageNumber(tx.key) if err != nil { return vulnPage, err } diff --git a/database/pgsql/notification_test.go b/database/pgsql/notification_test.go index 0d930d08..18e27a5c 100644 --- a/database/pgsql/notification_test.go +++ b/database/pgsql/notification_test.go @@ -1,4 +1,4 @@ -// Copyright 2017 clair authors +// Copyright 2018 clair authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -73,22 +73,22 @@ func TestPagination(t *testing.T) { if assert.Nil(t, err) && assert.True(t, ok) { assert.Equal(t, "test", noti.Name) if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { - oldPageNum, err := decryptPage(noti.Old.Current, tx.paginationKey) + oldPage, err := PageFromPageNumber(tx.key, noti.Old.Current) if !assert.Nil(t, err) { assert.FailNow(t, "") } - assert.Equal(t, int64(0), oldPageNum.StartID) - newPageNum, err := decryptPage(noti.New.Current, tx.paginationKey) + assert.Equal(t, int64(0), oldPage.StartID) + newPage, err := PageFromPageNumber(tx.key, noti.New.Current) if !assert.Nil(t, err) { assert.FailNow(t, "") } - newPageNextNum, err := decryptPage(noti.New.Next, tx.paginationKey) + newPageNext, err := PageFromPageNumber(tx.key, noti.New.Next) if !assert.Nil(t, err) { assert.FailNow(t, "") } - assert.Equal(t, int64(0), newPageNum.StartID) - assert.Equal(t, int64(4), newPageNextNum.StartID) + assert.Equal(t, int64(0), newPage.StartID) + assert.Equal(t, int64(4), newPageNext.StartID) noti.Old.Current = "" noti.New.Current = "" @@ -98,26 +98,26 @@ func TestPagination(t *testing.T) { } } - page1, err := encryptPage(idPageNumber{0}, tx.paginationKey) + pageNum1, err := Page{0}.PageNumber(tx.key) if !assert.Nil(t, err) { assert.FailNow(t, "") } - page2, err := encryptPage(idPageNumber{4}, tx.paginationKey) + pageNum2, err := Page{4}.PageNumber(tx.key) if !assert.Nil(t, err) { assert.FailNow(t, "") } - noti, ok, err = tx.FindVulnerabilityNotification("test", 1, page1, page2) + noti, ok, err = tx.FindVulnerabilityNotification("test", 1, pageNum1, pageNum2) if assert.Nil(t, err) && assert.True(t, ok) { assert.Equal(t, "test", noti.Name) if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { - oldCurrentPage, err := decryptPage(noti.Old.Current, tx.paginationKey) + oldCurrentPage, err := PageFromPageNumber(tx.key, noti.Old.Current) if !assert.Nil(t, err) { assert.FailNow(t, "") } - newCurrentPage, err := decryptPage(noti.New.Current, tx.paginationKey) + newCurrentPage, err := PageFromPageNumber(tx.key, noti.New.Current) if !assert.Nil(t, err) { assert.FailNow(t, "") } diff --git a/database/pgsql/page.go b/database/pgsql/page.go new file mode 100644 index 00000000..101d1dff --- /dev/null +++ b/database/pgsql/page.go @@ -0,0 +1,45 @@ +// Copyright 2018 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pgsql + +import ( + "github.com/coreos/clair/database" + "github.com/coreos/clair/pkg/pagination" +) + +// Page is the representation of a page for the Postgres schema. +type Page struct { + // StartID is the ID being used as the basis for pagination across database + // results. It is used to search for an ancestry with ID >= StartID. + // + // StartID is required to be unique to every ancestry and always increasing. + StartID int64 +} + +// PageNumber converts a Page to a database.PageNumber. +func (p Page) PageNumber(key pagination.Key) (pn database.PageNumber, err error) { + token, err := key.MarshalToken(p) + if err != nil { + return pn, err + } + pn = database.PageNumber(token) + return pn, nil +} + +// PageFromPageNumber converts a database.PageNumber into a Page. +func PageFromPageNumber(key pagination.Key, pn database.PageNumber) (p Page, err error) { + err = key.UnmarshalToken(string(pn), &p) + return +} diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index 335b77c7..66e8930e 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -33,8 +33,8 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/database/pgsql/migrations" - "github.com/coreos/clair/database/pgsql/token" "github.com/coreos/clair/pkg/commonerr" + "github.com/coreos/clair/pkg/pagination" ) var ( @@ -91,41 +91,21 @@ type pgSQL struct { type pgSession struct { *sql.Tx - paginationKey string + key pagination.Key } -type idPageNumber struct { - // StartID is an implementation detail for paginating by an ID required to - // be unique to every ancestry and always increasing. - // - // StartID is used to search for ancestry with ID >= StartID - StartID int64 -} - -func encryptPage(page idPageNumber, paginationKey string) (result database.PageNumber, err error) { - resultBytes, err := token.Marshal(page, paginationKey) - if err != nil { - return result, err - } - result = database.PageNumber(resultBytes) - return result, nil -} - -func decryptPage(page database.PageNumber, paginationKey string) (result idPageNumber, err error) { - err = token.Unmarshal(string(page), paginationKey, &result) - return -} - -// Begin initiates a transaction to database. The expected transaction isolation -// level in this implementation is "Read Committed". +// Begin initiates a transaction to database. +// +// The expected transaction isolation level in this implementation is "Read +// Committed". func (pgSQL *pgSQL) Begin() (database.Session, error) { tx, err := pgSQL.DB.Begin() if err != nil { return nil, err } return &pgSession{ - Tx: tx, - paginationKey: pgSQL.config.PaginationKey, + Tx: tx, + key: pagination.Must(pagination.KeyFromString(pgSQL.config.PaginationKey)), }, nil } diff --git a/database/pgsql/pgsql_test.go b/database/pgsql/pgsql_test.go index 96241666..0e5b7234 100644 --- a/database/pgsql/pgsql_test.go +++ b/database/pgsql/pgsql_test.go @@ -24,13 +24,13 @@ import ( "strings" "testing" - fernet "github.com/fernet/fernet-go" "github.com/pborman/uuid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" yaml "gopkg.in/yaml.v2" "github.com/coreos/clair/database" + "github.com/coreos/clair/pkg/pagination" ) var ( @@ -215,18 +215,13 @@ func generateTestConfig(testName string, loadFixture bool, manageLife bool) data source = fmt.Sprintf(sourceEnv, dbName) } - var key fernet.Key - if err := key.Generate(); err != nil { - panic("failed to generate pagination key" + err.Error()) - } - return database.RegistrableComponentConfig{ Options: map[string]interface{}{ "source": source, "cachesize": 0, "managedatabaselifecycle": manageLife, "fixturepath": fixturePath, - "paginationkey": key.Encode(), + "paginationkey": pagination.MustGenerateNewKey().String(), }, } } diff --git a/database/pgsql/token/token.go b/database/pgsql/token/token.go deleted file mode 100644 index 3492e712..00000000 --- a/database/pgsql/token/token.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2017 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package token implements encryption/decryption for json encoded interfaces -package token - -import ( - "bytes" - "encoding/json" - "errors" - "time" - - "github.com/fernet/fernet-go" -) - -// Unmarshal decrypts a token using provided key -// and decode the result into interface. -func Unmarshal(token string, key string, v interface{}) error { - k, _ := fernet.DecodeKey(key) - msg := fernet.VerifyAndDecrypt([]byte(token), time.Hour, []*fernet.Key{k}) - if msg == nil { - return errors.New("invalid or expired pagination token") - } - - return json.NewDecoder(bytes.NewBuffer(msg)).Decode(&v) -} - -// Marshal encodes an interface into json bytes and encrypts it. -func Marshal(v interface{}, key string) ([]byte, error) { - var buf bytes.Buffer - err := json.NewEncoder(&buf).Encode(v) - if err != nil { - return nil, err - } - - k, _ := fernet.DecodeKey(key) - return fernet.EncryptAndSign(buf.Bytes(), k) -} diff --git a/pkg/pagination/pagination.go b/pkg/pagination/pagination.go new file mode 100644 index 00000000..5f4acb66 --- /dev/null +++ b/pkg/pagination/pagination.go @@ -0,0 +1,94 @@ +// Copyright 2018 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package pagination implements a series of utilities for dealing with +// paginating lists of objects for an API. +package pagination + +import ( + "bytes" + "encoding/json" + "errors" + "time" + + "github.com/fernet/fernet-go" +) + +// ErrInvalidToken is returned when a token fails to Unmarshal because it was +// invalid or expired. +var ErrInvalidToken = errors.New("invalid or expired pagination token") + +// ErrInvalidKeyString is returned when the string representing a key is malformed. +var ErrInvalidKeyString = errors.New("invalid pagination key string: must be 32-byte URL-safe base64") + +// Key represents the key used to cryptographically secure the token +// being used to keep track of pages. +type Key struct { + fkey *fernet.Key +} + +// NewKey generates a new random pagination key. +func NewKey() (k Key, err error) { + k.fkey = new(fernet.Key) + err = k.fkey.Generate() + return k, err +} + +// KeyFromString creates the key for a given string. +// +// Strings must be 32-byte URL-safe base64 representations of the key bytes. +func KeyFromString(keyString string) (k Key, err error) { + var fkey *fernet.Key + fkey, err = fernet.DecodeKey(keyString) + if err != nil { + return Key{}, ErrInvalidKeyString + } + return Key{fkey}, err +} + +// Must is a helper that wraps calls returning a Key and and error and panics +// if the error is non-nil. +func Must(k Key, err error) Key { + if err != nil { + panic(err) + } + return k +} + +// String implements the fmt.Stringer interface for Key. +func (k Key) String() string { + return k.fkey.Encode() +} + +// MarshalToken encodes an interface into JSON bytes and encrypts it. +func (k Key) MarshalToken(v interface{}) ([]byte, error) { + var buf bytes.Buffer + err := json.NewEncoder(&buf).Encode(v) + if err != nil { + return nil, err + } + + return fernet.EncryptAndSign(buf.Bytes(), k.fkey) +} + +// UnmarshalToken decrypts a token using provided key and decodes the result +// into the provided interface. +func (k Key) UnmarshalToken(token string, v interface{}) error { + msg := fernet.VerifyAndDecrypt([]byte(token), time.Hour, []*fernet.Key{k.fkey}) + if msg == nil { + return ErrInvalidToken + } + + return json.NewDecoder(bytes.NewBuffer(msg)).Decode(&v) +} From 05659389569549f445eefac650df260ab4f4f05b Mon Sep 17 00:00:00 2001 From: Jimmy Zelinskie Date: Fri, 7 Sep 2018 16:12:19 -0400 Subject: [PATCH 3/3] pkg/pagination: add token type This change pulls as much pagination logic out of the database implementation as possible. Database implementations should now be able to marshal whatever state they need into opaque tokens with the utilities in the pagination package. --- api/v3/rpc.go | 5 ++-- database/database.go | 15 ++++------ database/mock.go | 10 +++++-- database/models.go | 9 +++--- database/pgsql/notification.go | 23 ++++++++------- database/pgsql/notification_test.go | 19 +++++++----- database/pgsql/page.go | 45 ----------------------------- database/pgsql/pgsql.go | 9 ++++++ database/pgsql/pgsql_test.go | 2 +- pkg/pagination/pagination.go | 22 +++++++++----- 10 files changed, 68 insertions(+), 91 deletions(-) delete mode 100644 database/pgsql/page.go diff --git a/api/v3/rpc.go b/api/v3/rpc.go index 9c893bf2..1111c09d 100644 --- a/api/v3/rpc.go +++ b/api/v3/rpc.go @@ -25,6 +25,7 @@ import ( pb "github.com/coreos/clair/api/v3/clairpb" "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" + "github.com/coreos/clair/pkg/pagination" ) // NotificationServer implements NotificationService interface for serving RPC. @@ -214,8 +215,8 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot dbNotification, ok, err := tx.FindVulnerabilityNotification( req.GetName(), int(req.GetLimit()), - database.PageNumber(req.GetOldVulnerabilityPage()), - database.PageNumber(req.GetNewVulnerabilityPage()), + pagination.Token(req.GetOldVulnerabilityPage()), + pagination.Token(req.GetNewVulnerabilityPage()), ) if err != nil { diff --git a/database/database.go b/database/database.go index 5c427494..61edb140 100644 --- a/database/database.go +++ b/database/database.go @@ -20,6 +20,8 @@ import ( "errors" "fmt" "time" + + "github.com/coreos/clair/pkg/pagination" ) var ( @@ -168,16 +170,9 @@ type Session interface { // affected ancestries affected by old or new vulnerability. // // Because the number of affected ancestries maybe large, they are paginated - // and their pages are specified by the given encrypted PageNumbers, which, - // if empty, are always considered first page. - // - // Session interface implementation should have encrypt and decrypt - // functions for PageNumber. - FindVulnerabilityNotification(name string, limit int, - oldVulnerabilityPage PageNumber, - newVulnerabilityPage PageNumber) ( - noti VulnerabilityNotificationWithVulnerable, - found bool, err error) + // and their pages are specified by the paination token, which, if empty, are + // always considered first page. + FindVulnerabilityNotification(name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (noti VulnerabilityNotificationWithVulnerable, found bool, err error) // MarkNotificationNotified marks a Notification as notified now, assuming // the requested notification is in the database. diff --git a/database/mock.go b/database/mock.go index ed6e8f16..0cd09919 100644 --- a/database/mock.go +++ b/database/mock.go @@ -14,7 +14,11 @@ package database -import "time" +import ( + "time" + + "github.com/coreos/clair/pkg/pagination" +) // MockSession implements Session and enables overriding each available method. // The default behavior of each method is to simply panic. @@ -38,7 +42,7 @@ type MockSession struct { FctDeleteVulnerabilities func([]VulnerabilityID) error FctInsertVulnerabilityNotifications func([]VulnerabilityNotification) error FctFindNewNotification func(lastNotified time.Time) (NotificationHook, bool, error) - FctFindVulnerabilityNotification func(name string, limit int, oldPage PageNumber, newPage PageNumber) ( + FctFindVulnerabilityNotification func(name string, limit int, oldPage pagination.Token, newPage pagination.Token) ( vuln VulnerabilityNotificationWithVulnerable, ok bool, err error) FctMarkNotificationNotified func(name string) error FctDeleteNotification func(name string) error @@ -182,7 +186,7 @@ func (ms *MockSession) FindNewNotification(lastNotified time.Time) (Notification panic("required mock function not implemented") } -func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage PageNumber, newPage PageNumber) ( +func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage pagination.Token, newPage pagination.Token) ( VulnerabilityNotificationWithVulnerable, bool, error) { if ms.FctFindVulnerabilityNotification != nil { return ms.FctFindVulnerabilityNotification(name, limit, oldPage, newPage) diff --git a/database/models.go b/database/models.go index e5a82358..702a079f 100644 --- a/database/models.go +++ b/database/models.go @@ -18,6 +18,8 @@ import ( "database/sql/driver" "encoding/json" "time" + + "github.com/coreos/clair/pkg/pagination" ) // Processors are extentions to scan a layer's content. @@ -173,8 +175,8 @@ type PagedVulnerableAncestries struct { Affected map[int]string Limit int - Current PageNumber - Next PageNumber + Current pagination.Token + Next pagination.Token // End signals the end of the pages. End bool @@ -209,9 +211,6 @@ type VulnerabilityNotificationWithVulnerable struct { New *PagedVulnerableAncestries } -// PageNumber is used to do pagination. -type PageNumber string - // MetadataMap is for storing the metadata returned by vulnerability database. type MetadataMap map[string]interface{} diff --git a/database/pgsql/notification.go b/database/pgsql/notification.go index 5c9af262..3a27fb3c 100644 --- a/database/pgsql/notification.go +++ b/database/pgsql/notification.go @@ -23,6 +23,7 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" + "github.com/coreos/clair/pkg/pagination" ) var ( @@ -163,12 +164,12 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not return notification, true, nil } -func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentPage database.PageNumber) (database.PagedVulnerableAncestries, error) { +func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentToken pagination.Token) (database.PagedVulnerableAncestries, error) { vulnPage := database.PagedVulnerableAncestries{Limit: limit} - current := Page{0} - if currentPage != "" { + currentPage := Page{0} + if currentToken != pagination.FirstPageToken { var err error - current, err = PageFromPageNumber(tx.key, currentPage) + err = tx.key.UnmarshalToken(currentToken, ¤tPage) if err != nil { return vulnPage, err } @@ -188,7 +189,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr } // the last result is used for the next page's startID - rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, current.StartID, limit+1) + rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1) if err != nil { return vulnPage, handleError("searchNotificationVulnerableAncestry", err) } @@ -209,9 +210,9 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr lastIndex = len(ancestries) vulnPage.End = true } else { - // Use the last ancestry's ID as the next PageNumber. + // Use the last ancestry's ID as the next page. lastIndex = len(ancestries) - 1 - vulnPage.Next, err = Page{ancestries[len(ancestries)-1].id}.PageNumber(tx.key) + vulnPage.Next, err = tx.key.MarshalToken(Page{ancestries[len(ancestries)-1].id}) if err != nil { return vulnPage, err } @@ -222,7 +223,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr vulnPage.Affected[int(ancestry.id)] = ancestry.name } - vulnPage.Current, err = current.PageNumber(tx.key) + vulnPage.Current, err = tx.key.MarshalToken(currentPage) if err != nil { return vulnPage, err } @@ -230,7 +231,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr return vulnPage, nil } -func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPage database.PageNumber, newPage database.PageNumber) ( +func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token) ( database.VulnerabilityNotificationWithVulnerable, bool, error) { var ( noti database.VulnerabilityNotificationWithVulnerable @@ -270,7 +271,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa } if oldVulnID.Valid { - page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPage) + page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPageToken) if err != nil { return noti, false, err } @@ -278,7 +279,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa } if newVulnID.Valid { - page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPage) + page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPageToken) if err != nil { return noti, false, err } diff --git a/database/pgsql/notification_test.go b/database/pgsql/notification_test.go index 18e27a5c..ec119e99 100644 --- a/database/pgsql/notification_test.go +++ b/database/pgsql/notification_test.go @@ -73,17 +73,20 @@ func TestPagination(t *testing.T) { if assert.Nil(t, err) && assert.True(t, ok) { assert.Equal(t, "test", noti.Name) if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { - oldPage, err := PageFromPageNumber(tx.key, noti.Old.Current) + var oldPage Page + err := tx.key.UnmarshalToken(noti.Old.Current, &oldPage) if !assert.Nil(t, err) { assert.FailNow(t, "") } assert.Equal(t, int64(0), oldPage.StartID) - newPage, err := PageFromPageNumber(tx.key, noti.New.Current) + var newPage Page + err = tx.key.UnmarshalToken(noti.New.Current, &newPage) if !assert.Nil(t, err) { assert.FailNow(t, "") } - newPageNext, err := PageFromPageNumber(tx.key, noti.New.Next) + var newPageNext Page + err = tx.key.UnmarshalToken(noti.New.Next, &newPageNext) if !assert.Nil(t, err) { assert.FailNow(t, "") } @@ -98,12 +101,12 @@ func TestPagination(t *testing.T) { } } - pageNum1, err := Page{0}.PageNumber(tx.key) + pageNum1, err := tx.key.MarshalToken(Page{0}) if !assert.Nil(t, err) { assert.FailNow(t, "") } - pageNum2, err := Page{4}.PageNumber(tx.key) + pageNum2, err := tx.key.MarshalToken(Page{4}) if !assert.Nil(t, err) { assert.FailNow(t, "") } @@ -112,12 +115,14 @@ func TestPagination(t *testing.T) { if assert.Nil(t, err) && assert.True(t, ok) { assert.Equal(t, "test", noti.Name) if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { - oldCurrentPage, err := PageFromPageNumber(tx.key, noti.Old.Current) + var oldCurrentPage Page + err = tx.key.UnmarshalToken(noti.Old.Current, &oldCurrentPage) if !assert.Nil(t, err) { assert.FailNow(t, "") } - newCurrentPage, err := PageFromPageNumber(tx.key, noti.New.Current) + var newCurrentPage Page + err = tx.key.UnmarshalToken(noti.New.Current, &newCurrentPage) if !assert.Nil(t, err) { assert.FailNow(t, "") } diff --git a/database/pgsql/page.go b/database/pgsql/page.go deleted file mode 100644 index 101d1dff..00000000 --- a/database/pgsql/page.go +++ /dev/null @@ -1,45 +0,0 @@ -// Copyright 2018 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package pgsql - -import ( - "github.com/coreos/clair/database" - "github.com/coreos/clair/pkg/pagination" -) - -// Page is the representation of a page for the Postgres schema. -type Page struct { - // StartID is the ID being used as the basis for pagination across database - // results. It is used to search for an ancestry with ID >= StartID. - // - // StartID is required to be unique to every ancestry and always increasing. - StartID int64 -} - -// PageNumber converts a Page to a database.PageNumber. -func (p Page) PageNumber(key pagination.Key) (pn database.PageNumber, err error) { - token, err := key.MarshalToken(p) - if err != nil { - return pn, err - } - pn = database.PageNumber(token) - return pn, nil -} - -// PageFromPageNumber converts a database.PageNumber into a Page. -func PageFromPageNumber(key pagination.Key, pn database.PageNumber) (p Page, err error) { - err = key.UnmarshalToken(string(pn), &p) - return -} diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index 66e8930e..9af010b6 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -131,6 +131,15 @@ func (pgSQL *pgSQL) Ping() bool { return pgSQL.DB.Ping() == nil } +// Page is the representation of a page for the Postgres schema. +type Page struct { + // StartID is the ID being used as the basis for pagination across database + // results. It is used to search for an ancestry with ID >= StartID. + // + // StartID is required to be unique to every ancestry and always increasing. + StartID int64 +} + // Config is the configuration that is used by openDatabase. type Config struct { Source string diff --git a/database/pgsql/pgsql_test.go b/database/pgsql/pgsql_test.go index 0e5b7234..e4a8c8b4 100644 --- a/database/pgsql/pgsql_test.go +++ b/database/pgsql/pgsql_test.go @@ -221,7 +221,7 @@ func generateTestConfig(testName string, loadFixture bool, manageLife bool) data "cachesize": 0, "managedatabaselifecycle": manageLife, "fixturepath": fixturePath, - "paginationkey": pagination.MustGenerateNewKey().String(), + "paginationkey": pagination.Must(pagination.NewKey()).String(), }, } } diff --git a/pkg/pagination/pagination.go b/pkg/pagination/pagination.go index 5f4acb66..1e765cec 100644 --- a/pkg/pagination/pagination.go +++ b/pkg/pagination/pagination.go @@ -38,6 +38,13 @@ type Key struct { fkey *fernet.Key } +// Token represents an opaque pagination token keeping track of a user's +// progress iterating through a list of results. +type Token string + +// FirstPageToken is used to represent the first page of content. +var FirstPageToken = Token("") + // NewKey generates a new random pagination key. func NewKey() (k Key, err error) { k.fkey = new(fernet.Key) @@ -71,21 +78,22 @@ func (k Key) String() string { return k.fkey.Encode() } -// MarshalToken encodes an interface into JSON bytes and encrypts it. -func (k Key) MarshalToken(v interface{}) ([]byte, error) { +// MarshalToken encodes an interface into JSON bytes and produces a Token. +func (k Key) MarshalToken(v interface{}) (Token, error) { var buf bytes.Buffer err := json.NewEncoder(&buf).Encode(v) if err != nil { - return nil, err + return Token(""), err } - return fernet.EncryptAndSign(buf.Bytes(), k.fkey) + tokenBytes, err := fernet.EncryptAndSign(buf.Bytes(), k.fkey) + return Token(tokenBytes), err } -// UnmarshalToken decrypts a token using provided key and decodes the result +// UnmarshalToken decrypts a Token using provided key and decodes the result // into the provided interface. -func (k Key) UnmarshalToken(token string, v interface{}) error { - msg := fernet.VerifyAndDecrypt([]byte(token), time.Hour, []*fernet.Key{k.fkey}) +func (k Key) UnmarshalToken(t Token, v interface{}) error { + msg := fernet.VerifyAndDecrypt([]byte(t), time.Hour, []*fernet.Key{k.fkey}) if msg == nil { return ErrInvalidToken }