diff --git a/api/v3/rpc.go b/api/v3/rpc.go index 9c893bf2..1111c09d 100644 --- a/api/v3/rpc.go +++ b/api/v3/rpc.go @@ -25,6 +25,7 @@ import ( pb "github.com/coreos/clair/api/v3/clairpb" "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" + "github.com/coreos/clair/pkg/pagination" ) // NotificationServer implements NotificationService interface for serving RPC. @@ -214,8 +215,8 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot dbNotification, ok, err := tx.FindVulnerabilityNotification( req.GetName(), int(req.GetLimit()), - database.PageNumber(req.GetOldVulnerabilityPage()), - database.PageNumber(req.GetNewVulnerabilityPage()), + pagination.Token(req.GetOldVulnerabilityPage()), + pagination.Token(req.GetNewVulnerabilityPage()), ) if err != nil { diff --git a/cmd/clair/config.go b/cmd/clair/config.go index 08f26066..09a01364 100644 --- a/cmd/clair/config.go +++ b/cmd/clair/config.go @@ -1,4 +1,4 @@ -// Copyright 2017 clair authors +// Copyright 2018 clair authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -20,7 +20,6 @@ import ( "os" "time" - "github.com/fernet/fernet-go" log "github.com/sirupsen/logrus" "gopkg.in/yaml.v2" @@ -31,6 +30,7 @@ import ( "github.com/coreos/clair/ext/featurens" "github.com/coreos/clair/ext/notification" "github.com/coreos/clair/ext/vulnsrc" + "github.com/coreos/clair/pkg/pagination" ) // ErrDatasourceNotLoaded is returned when the datasource variable in the @@ -108,15 +108,10 @@ func LoadConfig(path string) (config *Config, err error) { // Generate a pagination key if none is provided. if v, ok := config.Database.Options["paginationkey"]; !ok || v == nil || v.(string) == "" { log.Warn("pagination key is empty, generating...") - var key fernet.Key - if err = key.Generate(); err != nil { - return - } - config.Database.Options["paginationkey"] = key.Encode() + config.Database.Options["paginationkey"] = pagination.Must(pagination.NewKey()).String() } else { - _, err = fernet.DecodeKey(config.Database.Options["paginationkey"].(string)) + _, err = pagination.KeyFromString(config.Database.Options["paginationkey"].(string)) if err != nil { - err = errors.New("Invalid Pagination key; must be 32-bit URL-safe base64") return } } diff --git a/cmd/clair/main.go b/cmd/clair/main.go index dbb360a8..da046351 100644 --- a/cmd/clair/main.go +++ b/cmd/clair/main.go @@ -1,4 +1,4 @@ -// Copyright 2017 clair authors +// Copyright 2018 clair authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -61,7 +61,17 @@ import ( _ "github.com/coreos/clair/ext/vulnsrc/ubuntu" ) -const maxDBConnectionAttempts = 20 +// MaxDBConnectionAttempts is the total number of tries that Clair will use to +// initially connect to a database at start-up. +const MaxDBConnectionAttempts = 20 + +// BinaryDependencies are the programs that Clair expects to be on the $PATH +// because it creates subprocesses of these programs. +var BinaryDependencies = []string{ + "git", + "rpm", + "xz", +} func waitForSignals(signals ...os.Signal) { interrupts := make(chan os.Signal, 1) @@ -136,7 +146,7 @@ func Boot(config *Config) { // Open database var db database.Datastore var dbError error - for attempts := 1; attempts <= maxDBConnectionAttempts; attempts++ { + for attempts := 1; attempts <= MaxDBConnectionAttempts; attempts++ { db, dbError = database.Open(config.Database) if dbError == nil { break @@ -180,7 +190,7 @@ func main() { flag.Parse() // Check for dependencies. - for _, bin := range []string{"git", "rpm", "xz"} { + for _, bin := range BinaryDependencies { _, err := exec.LookPath(bin) if err != nil { log.WithError(err).WithField("dependency", bin).Fatal("failed to find dependency") diff --git a/database/database.go b/database/database.go index 5c427494..61edb140 100644 --- a/database/database.go +++ b/database/database.go @@ -20,6 +20,8 @@ import ( "errors" "fmt" "time" + + "github.com/coreos/clair/pkg/pagination" ) var ( @@ -168,16 +170,9 @@ type Session interface { // affected ancestries affected by old or new vulnerability. // // Because the number of affected ancestries maybe large, they are paginated - // and their pages are specified by the given encrypted PageNumbers, which, - // if empty, are always considered first page. - // - // Session interface implementation should have encrypt and decrypt - // functions for PageNumber. - FindVulnerabilityNotification(name string, limit int, - oldVulnerabilityPage PageNumber, - newVulnerabilityPage PageNumber) ( - noti VulnerabilityNotificationWithVulnerable, - found bool, err error) + // and their pages are specified by the paination token, which, if empty, are + // always considered first page. + FindVulnerabilityNotification(name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (noti VulnerabilityNotificationWithVulnerable, found bool, err error) // MarkNotificationNotified marks a Notification as notified now, assuming // the requested notification is in the database. diff --git a/database/mock.go b/database/mock.go index ed6e8f16..0cd09919 100644 --- a/database/mock.go +++ b/database/mock.go @@ -14,7 +14,11 @@ package database -import "time" +import ( + "time" + + "github.com/coreos/clair/pkg/pagination" +) // MockSession implements Session and enables overriding each available method. // The default behavior of each method is to simply panic. @@ -38,7 +42,7 @@ type MockSession struct { FctDeleteVulnerabilities func([]VulnerabilityID) error FctInsertVulnerabilityNotifications func([]VulnerabilityNotification) error FctFindNewNotification func(lastNotified time.Time) (NotificationHook, bool, error) - FctFindVulnerabilityNotification func(name string, limit int, oldPage PageNumber, newPage PageNumber) ( + FctFindVulnerabilityNotification func(name string, limit int, oldPage pagination.Token, newPage pagination.Token) ( vuln VulnerabilityNotificationWithVulnerable, ok bool, err error) FctMarkNotificationNotified func(name string) error FctDeleteNotification func(name string) error @@ -182,7 +186,7 @@ func (ms *MockSession) FindNewNotification(lastNotified time.Time) (Notification panic("required mock function not implemented") } -func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage PageNumber, newPage PageNumber) ( +func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage pagination.Token, newPage pagination.Token) ( VulnerabilityNotificationWithVulnerable, bool, error) { if ms.FctFindVulnerabilityNotification != nil { return ms.FctFindVulnerabilityNotification(name, limit, oldPage, newPage) diff --git a/database/models.go b/database/models.go index e5a82358..702a079f 100644 --- a/database/models.go +++ b/database/models.go @@ -18,6 +18,8 @@ import ( "database/sql/driver" "encoding/json" "time" + + "github.com/coreos/clair/pkg/pagination" ) // Processors are extentions to scan a layer's content. @@ -173,8 +175,8 @@ type PagedVulnerableAncestries struct { Affected map[int]string Limit int - Current PageNumber - Next PageNumber + Current pagination.Token + Next pagination.Token // End signals the end of the pages. End bool @@ -209,9 +211,6 @@ type VulnerabilityNotificationWithVulnerable struct { New *PagedVulnerableAncestries } -// PageNumber is used to do pagination. -type PageNumber string - // MetadataMap is for storing the metadata returned by vulnerability database. type MetadataMap map[string]interface{} diff --git a/database/pgsql/notification.go b/database/pgsql/notification.go index ebc346d3..3a27fb3c 100644 --- a/database/pgsql/notification.go +++ b/database/pgsql/notification.go @@ -23,6 +23,7 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/pkg/commonerr" + "github.com/coreos/clair/pkg/pagination" ) var ( @@ -163,12 +164,12 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not return notification, true, nil } -func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentPage database.PageNumber) (database.PagedVulnerableAncestries, error) { +func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentToken pagination.Token) (database.PagedVulnerableAncestries, error) { vulnPage := database.PagedVulnerableAncestries{Limit: limit} - current := idPageNumber{0} - if currentPage != "" { + currentPage := Page{0} + if currentToken != pagination.FirstPageToken { var err error - current, err = decryptPage(currentPage, tx.paginationKey) + err = tx.key.UnmarshalToken(currentToken, ¤tPage) if err != nil { return vulnPage, err } @@ -188,7 +189,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr } // the last result is used for the next page's startID - rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, current.StartID, limit+1) + rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1) if err != nil { return vulnPage, handleError("searchNotificationVulnerableAncestry", err) } @@ -209,13 +210,9 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr lastIndex = len(ancestries) vulnPage.End = true } else { - // Use the last ancestry's ID as the next PageNumber. + // Use the last ancestry's ID as the next page. lastIndex = len(ancestries) - 1 - vulnPage.Next, err = encryptPage( - idPageNumber{ - ancestries[len(ancestries)-1].id, - }, tx.paginationKey) - + vulnPage.Next, err = tx.key.MarshalToken(Page{ancestries[len(ancestries)-1].id}) if err != nil { return vulnPage, err } @@ -226,7 +223,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr vulnPage.Affected[int(ancestry.id)] = ancestry.name } - vulnPage.Current, err = encryptPage(current, tx.paginationKey) + vulnPage.Current, err = tx.key.MarshalToken(currentPage) if err != nil { return vulnPage, err } @@ -234,7 +231,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr return vulnPage, nil } -func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPage database.PageNumber, newPage database.PageNumber) ( +func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token) ( database.VulnerabilityNotificationWithVulnerable, bool, error) { var ( noti database.VulnerabilityNotificationWithVulnerable @@ -274,7 +271,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa } if oldVulnID.Valid { - page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPage) + page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPageToken) if err != nil { return noti, false, err } @@ -282,7 +279,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa } if newVulnID.Valid { - page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPage) + page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPageToken) if err != nil { return noti, false, err } diff --git a/database/pgsql/notification_test.go b/database/pgsql/notification_test.go index 0d930d08..ec119e99 100644 --- a/database/pgsql/notification_test.go +++ b/database/pgsql/notification_test.go @@ -1,4 +1,4 @@ -// Copyright 2017 clair authors +// Copyright 2018 clair authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -73,22 +73,25 @@ func TestPagination(t *testing.T) { if assert.Nil(t, err) && assert.True(t, ok) { assert.Equal(t, "test", noti.Name) if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { - oldPageNum, err := decryptPage(noti.Old.Current, tx.paginationKey) + var oldPage Page + err := tx.key.UnmarshalToken(noti.Old.Current, &oldPage) if !assert.Nil(t, err) { assert.FailNow(t, "") } - assert.Equal(t, int64(0), oldPageNum.StartID) - newPageNum, err := decryptPage(noti.New.Current, tx.paginationKey) + assert.Equal(t, int64(0), oldPage.StartID) + var newPage Page + err = tx.key.UnmarshalToken(noti.New.Current, &newPage) if !assert.Nil(t, err) { assert.FailNow(t, "") } - newPageNextNum, err := decryptPage(noti.New.Next, tx.paginationKey) + var newPageNext Page + err = tx.key.UnmarshalToken(noti.New.Next, &newPageNext) if !assert.Nil(t, err) { assert.FailNow(t, "") } - assert.Equal(t, int64(0), newPageNum.StartID) - assert.Equal(t, int64(4), newPageNextNum.StartID) + assert.Equal(t, int64(0), newPage.StartID) + assert.Equal(t, int64(4), newPageNext.StartID) noti.Old.Current = "" noti.New.Current = "" @@ -98,26 +101,28 @@ func TestPagination(t *testing.T) { } } - page1, err := encryptPage(idPageNumber{0}, tx.paginationKey) + pageNum1, err := tx.key.MarshalToken(Page{0}) if !assert.Nil(t, err) { assert.FailNow(t, "") } - page2, err := encryptPage(idPageNumber{4}, tx.paginationKey) + pageNum2, err := tx.key.MarshalToken(Page{4}) if !assert.Nil(t, err) { assert.FailNow(t, "") } - noti, ok, err = tx.FindVulnerabilityNotification("test", 1, page1, page2) + noti, ok, err = tx.FindVulnerabilityNotification("test", 1, pageNum1, pageNum2) if assert.Nil(t, err) && assert.True(t, ok) { assert.Equal(t, "test", noti.Name) if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { - oldCurrentPage, err := decryptPage(noti.Old.Current, tx.paginationKey) + var oldCurrentPage Page + err = tx.key.UnmarshalToken(noti.Old.Current, &oldCurrentPage) if !assert.Nil(t, err) { assert.FailNow(t, "") } - newCurrentPage, err := decryptPage(noti.New.Current, tx.paginationKey) + var newCurrentPage Page + err = tx.key.UnmarshalToken(noti.New.Current, &newCurrentPage) if !assert.Nil(t, err) { assert.FailNow(t, "") } diff --git a/database/pgsql/pgsql.go b/database/pgsql/pgsql.go index 335b77c7..9af010b6 100644 --- a/database/pgsql/pgsql.go +++ b/database/pgsql/pgsql.go @@ -33,8 +33,8 @@ import ( "github.com/coreos/clair/database" "github.com/coreos/clair/database/pgsql/migrations" - "github.com/coreos/clair/database/pgsql/token" "github.com/coreos/clair/pkg/commonerr" + "github.com/coreos/clair/pkg/pagination" ) var ( @@ -91,41 +91,21 @@ type pgSQL struct { type pgSession struct { *sql.Tx - paginationKey string + key pagination.Key } -type idPageNumber struct { - // StartID is an implementation detail for paginating by an ID required to - // be unique to every ancestry and always increasing. - // - // StartID is used to search for ancestry with ID >= StartID - StartID int64 -} - -func encryptPage(page idPageNumber, paginationKey string) (result database.PageNumber, err error) { - resultBytes, err := token.Marshal(page, paginationKey) - if err != nil { - return result, err - } - result = database.PageNumber(resultBytes) - return result, nil -} - -func decryptPage(page database.PageNumber, paginationKey string) (result idPageNumber, err error) { - err = token.Unmarshal(string(page), paginationKey, &result) - return -} - -// Begin initiates a transaction to database. The expected transaction isolation -// level in this implementation is "Read Committed". +// Begin initiates a transaction to database. +// +// The expected transaction isolation level in this implementation is "Read +// Committed". func (pgSQL *pgSQL) Begin() (database.Session, error) { tx, err := pgSQL.DB.Begin() if err != nil { return nil, err } return &pgSession{ - Tx: tx, - paginationKey: pgSQL.config.PaginationKey, + Tx: tx, + key: pagination.Must(pagination.KeyFromString(pgSQL.config.PaginationKey)), }, nil } @@ -151,6 +131,15 @@ func (pgSQL *pgSQL) Ping() bool { return pgSQL.DB.Ping() == nil } +// Page is the representation of a page for the Postgres schema. +type Page struct { + // StartID is the ID being used as the basis for pagination across database + // results. It is used to search for an ancestry with ID >= StartID. + // + // StartID is required to be unique to every ancestry and always increasing. + StartID int64 +} + // Config is the configuration that is used by openDatabase. type Config struct { Source string diff --git a/database/pgsql/pgsql_test.go b/database/pgsql/pgsql_test.go index 96241666..e4a8c8b4 100644 --- a/database/pgsql/pgsql_test.go +++ b/database/pgsql/pgsql_test.go @@ -24,13 +24,13 @@ import ( "strings" "testing" - fernet "github.com/fernet/fernet-go" "github.com/pborman/uuid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" yaml "gopkg.in/yaml.v2" "github.com/coreos/clair/database" + "github.com/coreos/clair/pkg/pagination" ) var ( @@ -215,18 +215,13 @@ func generateTestConfig(testName string, loadFixture bool, manageLife bool) data source = fmt.Sprintf(sourceEnv, dbName) } - var key fernet.Key - if err := key.Generate(); err != nil { - panic("failed to generate pagination key" + err.Error()) - } - return database.RegistrableComponentConfig{ Options: map[string]interface{}{ "source": source, "cachesize": 0, "managedatabaselifecycle": manageLife, "fixturepath": fixturePath, - "paginationkey": key.Encode(), + "paginationkey": pagination.Must(pagination.NewKey()).String(), }, } } diff --git a/database/pgsql/token/token.go b/database/pgsql/token/token.go deleted file mode 100644 index 3492e712..00000000 --- a/database/pgsql/token/token.go +++ /dev/null @@ -1,49 +0,0 @@ -// Copyright 2017 clair authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package token implements encryption/decryption for json encoded interfaces -package token - -import ( - "bytes" - "encoding/json" - "errors" - "time" - - "github.com/fernet/fernet-go" -) - -// Unmarshal decrypts a token using provided key -// and decode the result into interface. -func Unmarshal(token string, key string, v interface{}) error { - k, _ := fernet.DecodeKey(key) - msg := fernet.VerifyAndDecrypt([]byte(token), time.Hour, []*fernet.Key{k}) - if msg == nil { - return errors.New("invalid or expired pagination token") - } - - return json.NewDecoder(bytes.NewBuffer(msg)).Decode(&v) -} - -// Marshal encodes an interface into json bytes and encrypts it. -func Marshal(v interface{}, key string) ([]byte, error) { - var buf bytes.Buffer - err := json.NewEncoder(&buf).Encode(v) - if err != nil { - return nil, err - } - - k, _ := fernet.DecodeKey(key) - return fernet.EncryptAndSign(buf.Bytes(), k) -} diff --git a/pkg/pagination/pagination.go b/pkg/pagination/pagination.go new file mode 100644 index 00000000..1e765cec --- /dev/null +++ b/pkg/pagination/pagination.go @@ -0,0 +1,102 @@ +// Copyright 2018 clair authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package pagination implements a series of utilities for dealing with +// paginating lists of objects for an API. +package pagination + +import ( + "bytes" + "encoding/json" + "errors" + "time" + + "github.com/fernet/fernet-go" +) + +// ErrInvalidToken is returned when a token fails to Unmarshal because it was +// invalid or expired. +var ErrInvalidToken = errors.New("invalid or expired pagination token") + +// ErrInvalidKeyString is returned when the string representing a key is malformed. +var ErrInvalidKeyString = errors.New("invalid pagination key string: must be 32-byte URL-safe base64") + +// Key represents the key used to cryptographically secure the token +// being used to keep track of pages. +type Key struct { + fkey *fernet.Key +} + +// Token represents an opaque pagination token keeping track of a user's +// progress iterating through a list of results. +type Token string + +// FirstPageToken is used to represent the first page of content. +var FirstPageToken = Token("") + +// NewKey generates a new random pagination key. +func NewKey() (k Key, err error) { + k.fkey = new(fernet.Key) + err = k.fkey.Generate() + return k, err +} + +// KeyFromString creates the key for a given string. +// +// Strings must be 32-byte URL-safe base64 representations of the key bytes. +func KeyFromString(keyString string) (k Key, err error) { + var fkey *fernet.Key + fkey, err = fernet.DecodeKey(keyString) + if err != nil { + return Key{}, ErrInvalidKeyString + } + return Key{fkey}, err +} + +// Must is a helper that wraps calls returning a Key and and error and panics +// if the error is non-nil. +func Must(k Key, err error) Key { + if err != nil { + panic(err) + } + return k +} + +// String implements the fmt.Stringer interface for Key. +func (k Key) String() string { + return k.fkey.Encode() +} + +// MarshalToken encodes an interface into JSON bytes and produces a Token. +func (k Key) MarshalToken(v interface{}) (Token, error) { + var buf bytes.Buffer + err := json.NewEncoder(&buf).Encode(v) + if err != nil { + return Token(""), err + } + + tokenBytes, err := fernet.EncryptAndSign(buf.Bytes(), k.fkey) + return Token(tokenBytes), err +} + +// UnmarshalToken decrypts a Token using provided key and decodes the result +// into the provided interface. +func (k Key) UnmarshalToken(t Token, v interface{}) error { + msg := fernet.VerifyAndDecrypt([]byte(t), time.Hour, []*fernet.Key{k.fkey}) + if msg == nil { + return ErrInvalidToken + } + + return json.NewDecoder(bytes.NewBuffer(msg)).Decode(&v) +}