Merge pull request #613 from jzelinskie/pkg-pagination
Introduce pkg/pagination
This commit is contained in:
commit
e5c2e378a2
@ -25,6 +25,7 @@ import (
|
|||||||
pb "github.com/coreos/clair/api/v3/clairpb"
|
pb "github.com/coreos/clair/api/v3/clairpb"
|
||||||
"github.com/coreos/clair/database"
|
"github.com/coreos/clair/database"
|
||||||
"github.com/coreos/clair/pkg/commonerr"
|
"github.com/coreos/clair/pkg/commonerr"
|
||||||
|
"github.com/coreos/clair/pkg/pagination"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NotificationServer implements NotificationService interface for serving RPC.
|
// NotificationServer implements NotificationService interface for serving RPC.
|
||||||
@ -214,8 +215,8 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot
|
|||||||
dbNotification, ok, err := tx.FindVulnerabilityNotification(
|
dbNotification, ok, err := tx.FindVulnerabilityNotification(
|
||||||
req.GetName(),
|
req.GetName(),
|
||||||
int(req.GetLimit()),
|
int(req.GetLimit()),
|
||||||
database.PageNumber(req.GetOldVulnerabilityPage()),
|
pagination.Token(req.GetOldVulnerabilityPage()),
|
||||||
database.PageNumber(req.GetNewVulnerabilityPage()),
|
pagination.Token(req.GetNewVulnerabilityPage()),
|
||||||
)
|
)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright 2017 clair authors
|
// Copyright 2018 clair authors
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@ -20,7 +20,6 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/fernet/fernet-go"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
|
|
||||||
@ -31,6 +30,7 @@ import (
|
|||||||
"github.com/coreos/clair/ext/featurens"
|
"github.com/coreos/clair/ext/featurens"
|
||||||
"github.com/coreos/clair/ext/notification"
|
"github.com/coreos/clair/ext/notification"
|
||||||
"github.com/coreos/clair/ext/vulnsrc"
|
"github.com/coreos/clair/ext/vulnsrc"
|
||||||
|
"github.com/coreos/clair/pkg/pagination"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrDatasourceNotLoaded is returned when the datasource variable in the
|
// ErrDatasourceNotLoaded is returned when the datasource variable in the
|
||||||
@ -108,15 +108,10 @@ func LoadConfig(path string) (config *Config, err error) {
|
|||||||
// Generate a pagination key if none is provided.
|
// Generate a pagination key if none is provided.
|
||||||
if v, ok := config.Database.Options["paginationkey"]; !ok || v == nil || v.(string) == "" {
|
if v, ok := config.Database.Options["paginationkey"]; !ok || v == nil || v.(string) == "" {
|
||||||
log.Warn("pagination key is empty, generating...")
|
log.Warn("pagination key is empty, generating...")
|
||||||
var key fernet.Key
|
config.Database.Options["paginationkey"] = pagination.Must(pagination.NewKey()).String()
|
||||||
if err = key.Generate(); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
config.Database.Options["paginationkey"] = key.Encode()
|
|
||||||
} else {
|
} else {
|
||||||
_, err = fernet.DecodeKey(config.Database.Options["paginationkey"].(string))
|
_, err = pagination.KeyFromString(config.Database.Options["paginationkey"].(string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = errors.New("Invalid Pagination key; must be 32-bit URL-safe base64")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright 2017 clair authors
|
// Copyright 2018 clair authors
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@ -61,7 +61,17 @@ import (
|
|||||||
_ "github.com/coreos/clair/ext/vulnsrc/ubuntu"
|
_ "github.com/coreos/clair/ext/vulnsrc/ubuntu"
|
||||||
)
|
)
|
||||||
|
|
||||||
const maxDBConnectionAttempts = 20
|
// MaxDBConnectionAttempts is the total number of tries that Clair will use to
|
||||||
|
// initially connect to a database at start-up.
|
||||||
|
const MaxDBConnectionAttempts = 20
|
||||||
|
|
||||||
|
// BinaryDependencies are the programs that Clair expects to be on the $PATH
|
||||||
|
// because it creates subprocesses of these programs.
|
||||||
|
var BinaryDependencies = []string{
|
||||||
|
"git",
|
||||||
|
"rpm",
|
||||||
|
"xz",
|
||||||
|
}
|
||||||
|
|
||||||
func waitForSignals(signals ...os.Signal) {
|
func waitForSignals(signals ...os.Signal) {
|
||||||
interrupts := make(chan os.Signal, 1)
|
interrupts := make(chan os.Signal, 1)
|
||||||
@ -136,7 +146,7 @@ func Boot(config *Config) {
|
|||||||
// Open database
|
// Open database
|
||||||
var db database.Datastore
|
var db database.Datastore
|
||||||
var dbError error
|
var dbError error
|
||||||
for attempts := 1; attempts <= maxDBConnectionAttempts; attempts++ {
|
for attempts := 1; attempts <= MaxDBConnectionAttempts; attempts++ {
|
||||||
db, dbError = database.Open(config.Database)
|
db, dbError = database.Open(config.Database)
|
||||||
if dbError == nil {
|
if dbError == nil {
|
||||||
break
|
break
|
||||||
@ -180,7 +190,7 @@ func main() {
|
|||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
// Check for dependencies.
|
// Check for dependencies.
|
||||||
for _, bin := range []string{"git", "rpm", "xz"} {
|
for _, bin := range BinaryDependencies {
|
||||||
_, err := exec.LookPath(bin)
|
_, err := exec.LookPath(bin)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithError(err).WithField("dependency", bin).Fatal("failed to find dependency")
|
log.WithError(err).WithField("dependency", bin).Fatal("failed to find dependency")
|
||||||
|
@ -20,6 +20,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/clair/pkg/pagination"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -168,16 +170,9 @@ type Session interface {
|
|||||||
// affected ancestries affected by old or new vulnerability.
|
// affected ancestries affected by old or new vulnerability.
|
||||||
//
|
//
|
||||||
// Because the number of affected ancestries maybe large, they are paginated
|
// Because the number of affected ancestries maybe large, they are paginated
|
||||||
// and their pages are specified by the given encrypted PageNumbers, which,
|
// and their pages are specified by the paination token, which, if empty, are
|
||||||
// if empty, are always considered first page.
|
// always considered first page.
|
||||||
//
|
FindVulnerabilityNotification(name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (noti VulnerabilityNotificationWithVulnerable, found bool, err error)
|
||||||
// Session interface implementation should have encrypt and decrypt
|
|
||||||
// functions for PageNumber.
|
|
||||||
FindVulnerabilityNotification(name string, limit int,
|
|
||||||
oldVulnerabilityPage PageNumber,
|
|
||||||
newVulnerabilityPage PageNumber) (
|
|
||||||
noti VulnerabilityNotificationWithVulnerable,
|
|
||||||
found bool, err error)
|
|
||||||
|
|
||||||
// MarkNotificationNotified marks a Notification as notified now, assuming
|
// MarkNotificationNotified marks a Notification as notified now, assuming
|
||||||
// the requested notification is in the database.
|
// the requested notification is in the database.
|
||||||
|
@ -14,7 +14,11 @@
|
|||||||
|
|
||||||
package database
|
package database
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/clair/pkg/pagination"
|
||||||
|
)
|
||||||
|
|
||||||
// MockSession implements Session and enables overriding each available method.
|
// MockSession implements Session and enables overriding each available method.
|
||||||
// The default behavior of each method is to simply panic.
|
// The default behavior of each method is to simply panic.
|
||||||
@ -38,7 +42,7 @@ type MockSession struct {
|
|||||||
FctDeleteVulnerabilities func([]VulnerabilityID) error
|
FctDeleteVulnerabilities func([]VulnerabilityID) error
|
||||||
FctInsertVulnerabilityNotifications func([]VulnerabilityNotification) error
|
FctInsertVulnerabilityNotifications func([]VulnerabilityNotification) error
|
||||||
FctFindNewNotification func(lastNotified time.Time) (NotificationHook, bool, error)
|
FctFindNewNotification func(lastNotified time.Time) (NotificationHook, bool, error)
|
||||||
FctFindVulnerabilityNotification func(name string, limit int, oldPage PageNumber, newPage PageNumber) (
|
FctFindVulnerabilityNotification func(name string, limit int, oldPage pagination.Token, newPage pagination.Token) (
|
||||||
vuln VulnerabilityNotificationWithVulnerable, ok bool, err error)
|
vuln VulnerabilityNotificationWithVulnerable, ok bool, err error)
|
||||||
FctMarkNotificationNotified func(name string) error
|
FctMarkNotificationNotified func(name string) error
|
||||||
FctDeleteNotification func(name string) error
|
FctDeleteNotification func(name string) error
|
||||||
@ -182,7 +186,7 @@ func (ms *MockSession) FindNewNotification(lastNotified time.Time) (Notification
|
|||||||
panic("required mock function not implemented")
|
panic("required mock function not implemented")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage PageNumber, newPage PageNumber) (
|
func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage pagination.Token, newPage pagination.Token) (
|
||||||
VulnerabilityNotificationWithVulnerable, bool, error) {
|
VulnerabilityNotificationWithVulnerable, bool, error) {
|
||||||
if ms.FctFindVulnerabilityNotification != nil {
|
if ms.FctFindVulnerabilityNotification != nil {
|
||||||
return ms.FctFindVulnerabilityNotification(name, limit, oldPage, newPage)
|
return ms.FctFindVulnerabilityNotification(name, limit, oldPage, newPage)
|
||||||
|
@ -18,6 +18,8 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/coreos/clair/pkg/pagination"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Processors are extentions to scan a layer's content.
|
// Processors are extentions to scan a layer's content.
|
||||||
@ -173,8 +175,8 @@ type PagedVulnerableAncestries struct {
|
|||||||
Affected map[int]string
|
Affected map[int]string
|
||||||
|
|
||||||
Limit int
|
Limit int
|
||||||
Current PageNumber
|
Current pagination.Token
|
||||||
Next PageNumber
|
Next pagination.Token
|
||||||
|
|
||||||
// End signals the end of the pages.
|
// End signals the end of the pages.
|
||||||
End bool
|
End bool
|
||||||
@ -209,9 +211,6 @@ type VulnerabilityNotificationWithVulnerable struct {
|
|||||||
New *PagedVulnerableAncestries
|
New *PagedVulnerableAncestries
|
||||||
}
|
}
|
||||||
|
|
||||||
// PageNumber is used to do pagination.
|
|
||||||
type PageNumber string
|
|
||||||
|
|
||||||
// MetadataMap is for storing the metadata returned by vulnerability database.
|
// MetadataMap is for storing the metadata returned by vulnerability database.
|
||||||
type MetadataMap map[string]interface{}
|
type MetadataMap map[string]interface{}
|
||||||
|
|
||||||
|
@ -23,6 +23,7 @@ import (
|
|||||||
|
|
||||||
"github.com/coreos/clair/database"
|
"github.com/coreos/clair/database"
|
||||||
"github.com/coreos/clair/pkg/commonerr"
|
"github.com/coreos/clair/pkg/commonerr"
|
||||||
|
"github.com/coreos/clair/pkg/pagination"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -163,12 +164,12 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not
|
|||||||
return notification, true, nil
|
return notification, true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentPage database.PageNumber) (database.PagedVulnerableAncestries, error) {
|
func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentToken pagination.Token) (database.PagedVulnerableAncestries, error) {
|
||||||
vulnPage := database.PagedVulnerableAncestries{Limit: limit}
|
vulnPage := database.PagedVulnerableAncestries{Limit: limit}
|
||||||
current := idPageNumber{0}
|
currentPage := Page{0}
|
||||||
if currentPage != "" {
|
if currentToken != pagination.FirstPageToken {
|
||||||
var err error
|
var err error
|
||||||
current, err = decryptPage(currentPage, tx.paginationKey)
|
err = tx.key.UnmarshalToken(currentToken, ¤tPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return vulnPage, err
|
return vulnPage, err
|
||||||
}
|
}
|
||||||
@ -188,7 +189,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
|
|||||||
}
|
}
|
||||||
|
|
||||||
// the last result is used for the next page's startID
|
// the last result is used for the next page's startID
|
||||||
rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, current.StartID, limit+1)
|
rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
|
return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
|
||||||
}
|
}
|
||||||
@ -209,13 +210,9 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
|
|||||||
lastIndex = len(ancestries)
|
lastIndex = len(ancestries)
|
||||||
vulnPage.End = true
|
vulnPage.End = true
|
||||||
} else {
|
} else {
|
||||||
// Use the last ancestry's ID as the next PageNumber.
|
// Use the last ancestry's ID as the next page.
|
||||||
lastIndex = len(ancestries) - 1
|
lastIndex = len(ancestries) - 1
|
||||||
vulnPage.Next, err = encryptPage(
|
vulnPage.Next, err = tx.key.MarshalToken(Page{ancestries[len(ancestries)-1].id})
|
||||||
idPageNumber{
|
|
||||||
ancestries[len(ancestries)-1].id,
|
|
||||||
}, tx.paginationKey)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return vulnPage, err
|
return vulnPage, err
|
||||||
}
|
}
|
||||||
@ -226,7 +223,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
|
|||||||
vulnPage.Affected[int(ancestry.id)] = ancestry.name
|
vulnPage.Affected[int(ancestry.id)] = ancestry.name
|
||||||
}
|
}
|
||||||
|
|
||||||
vulnPage.Current, err = encryptPage(current, tx.paginationKey)
|
vulnPage.Current, err = tx.key.MarshalToken(currentPage)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return vulnPage, err
|
return vulnPage, err
|
||||||
}
|
}
|
||||||
@ -234,7 +231,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
|
|||||||
return vulnPage, nil
|
return vulnPage, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPage database.PageNumber, newPage database.PageNumber) (
|
func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token) (
|
||||||
database.VulnerabilityNotificationWithVulnerable, bool, error) {
|
database.VulnerabilityNotificationWithVulnerable, bool, error) {
|
||||||
var (
|
var (
|
||||||
noti database.VulnerabilityNotificationWithVulnerable
|
noti database.VulnerabilityNotificationWithVulnerable
|
||||||
@ -274,7 +271,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
|
|||||||
}
|
}
|
||||||
|
|
||||||
if oldVulnID.Valid {
|
if oldVulnID.Valid {
|
||||||
page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPage)
|
page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPageToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noti, false, err
|
return noti, false, err
|
||||||
}
|
}
|
||||||
@ -282,7 +279,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
|
|||||||
}
|
}
|
||||||
|
|
||||||
if newVulnID.Valid {
|
if newVulnID.Valid {
|
||||||
page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPage)
|
page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPageToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return noti, false, err
|
return noti, false, err
|
||||||
}
|
}
|
||||||
|
@ -1,4 +1,4 @@
|
|||||||
// Copyright 2017 clair authors
|
// Copyright 2018 clair authors
|
||||||
//
|
//
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
// you may not use this file except in compliance with the License.
|
// you may not use this file except in compliance with the License.
|
||||||
@ -73,22 +73,25 @@ func TestPagination(t *testing.T) {
|
|||||||
if assert.Nil(t, err) && assert.True(t, ok) {
|
if assert.Nil(t, err) && assert.True(t, ok) {
|
||||||
assert.Equal(t, "test", noti.Name)
|
assert.Equal(t, "test", noti.Name)
|
||||||
if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
|
if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
|
||||||
oldPageNum, err := decryptPage(noti.Old.Current, tx.paginationKey)
|
var oldPage Page
|
||||||
|
err := tx.key.UnmarshalToken(noti.Old.Current, &oldPage)
|
||||||
if !assert.Nil(t, err) {
|
if !assert.Nil(t, err) {
|
||||||
assert.FailNow(t, "")
|
assert.FailNow(t, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Equal(t, int64(0), oldPageNum.StartID)
|
assert.Equal(t, int64(0), oldPage.StartID)
|
||||||
newPageNum, err := decryptPage(noti.New.Current, tx.paginationKey)
|
var newPage Page
|
||||||
|
err = tx.key.UnmarshalToken(noti.New.Current, &newPage)
|
||||||
if !assert.Nil(t, err) {
|
if !assert.Nil(t, err) {
|
||||||
assert.FailNow(t, "")
|
assert.FailNow(t, "")
|
||||||
}
|
}
|
||||||
newPageNextNum, err := decryptPage(noti.New.Next, tx.paginationKey)
|
var newPageNext Page
|
||||||
|
err = tx.key.UnmarshalToken(noti.New.Next, &newPageNext)
|
||||||
if !assert.Nil(t, err) {
|
if !assert.Nil(t, err) {
|
||||||
assert.FailNow(t, "")
|
assert.FailNow(t, "")
|
||||||
}
|
}
|
||||||
assert.Equal(t, int64(0), newPageNum.StartID)
|
assert.Equal(t, int64(0), newPage.StartID)
|
||||||
assert.Equal(t, int64(4), newPageNextNum.StartID)
|
assert.Equal(t, int64(4), newPageNext.StartID)
|
||||||
|
|
||||||
noti.Old.Current = ""
|
noti.Old.Current = ""
|
||||||
noti.New.Current = ""
|
noti.New.Current = ""
|
||||||
@ -98,26 +101,28 @@ func TestPagination(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
page1, err := encryptPage(idPageNumber{0}, tx.paginationKey)
|
pageNum1, err := tx.key.MarshalToken(Page{0})
|
||||||
if !assert.Nil(t, err) {
|
if !assert.Nil(t, err) {
|
||||||
assert.FailNow(t, "")
|
assert.FailNow(t, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
page2, err := encryptPage(idPageNumber{4}, tx.paginationKey)
|
pageNum2, err := tx.key.MarshalToken(Page{4})
|
||||||
if !assert.Nil(t, err) {
|
if !assert.Nil(t, err) {
|
||||||
assert.FailNow(t, "")
|
assert.FailNow(t, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
noti, ok, err = tx.FindVulnerabilityNotification("test", 1, page1, page2)
|
noti, ok, err = tx.FindVulnerabilityNotification("test", 1, pageNum1, pageNum2)
|
||||||
if assert.Nil(t, err) && assert.True(t, ok) {
|
if assert.Nil(t, err) && assert.True(t, ok) {
|
||||||
assert.Equal(t, "test", noti.Name)
|
assert.Equal(t, "test", noti.Name)
|
||||||
if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
|
if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
|
||||||
oldCurrentPage, err := decryptPage(noti.Old.Current, tx.paginationKey)
|
var oldCurrentPage Page
|
||||||
|
err = tx.key.UnmarshalToken(noti.Old.Current, &oldCurrentPage)
|
||||||
if !assert.Nil(t, err) {
|
if !assert.Nil(t, err) {
|
||||||
assert.FailNow(t, "")
|
assert.FailNow(t, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
newCurrentPage, err := decryptPage(noti.New.Current, tx.paginationKey)
|
var newCurrentPage Page
|
||||||
|
err = tx.key.UnmarshalToken(noti.New.Current, &newCurrentPage)
|
||||||
if !assert.Nil(t, err) {
|
if !assert.Nil(t, err) {
|
||||||
assert.FailNow(t, "")
|
assert.FailNow(t, "")
|
||||||
}
|
}
|
||||||
|
@ -33,8 +33,8 @@ import (
|
|||||||
|
|
||||||
"github.com/coreos/clair/database"
|
"github.com/coreos/clair/database"
|
||||||
"github.com/coreos/clair/database/pgsql/migrations"
|
"github.com/coreos/clair/database/pgsql/migrations"
|
||||||
"github.com/coreos/clair/database/pgsql/token"
|
|
||||||
"github.com/coreos/clair/pkg/commonerr"
|
"github.com/coreos/clair/pkg/commonerr"
|
||||||
|
"github.com/coreos/clair/pkg/pagination"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -91,41 +91,21 @@ type pgSQL struct {
|
|||||||
type pgSession struct {
|
type pgSession struct {
|
||||||
*sql.Tx
|
*sql.Tx
|
||||||
|
|
||||||
paginationKey string
|
key pagination.Key
|
||||||
}
|
}
|
||||||
|
|
||||||
type idPageNumber struct {
|
// Begin initiates a transaction to database.
|
||||||
// StartID is an implementation detail for paginating by an ID required to
|
//
|
||||||
// be unique to every ancestry and always increasing.
|
// The expected transaction isolation level in this implementation is "Read
|
||||||
//
|
// Committed".
|
||||||
// StartID is used to search for ancestry with ID >= StartID
|
|
||||||
StartID int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func encryptPage(page idPageNumber, paginationKey string) (result database.PageNumber, err error) {
|
|
||||||
resultBytes, err := token.Marshal(page, paginationKey)
|
|
||||||
if err != nil {
|
|
||||||
return result, err
|
|
||||||
}
|
|
||||||
result = database.PageNumber(resultBytes)
|
|
||||||
return result, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func decryptPage(page database.PageNumber, paginationKey string) (result idPageNumber, err error) {
|
|
||||||
err = token.Unmarshal(string(page), paginationKey, &result)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Begin initiates a transaction to database. The expected transaction isolation
|
|
||||||
// level in this implementation is "Read Committed".
|
|
||||||
func (pgSQL *pgSQL) Begin() (database.Session, error) {
|
func (pgSQL *pgSQL) Begin() (database.Session, error) {
|
||||||
tx, err := pgSQL.DB.Begin()
|
tx, err := pgSQL.DB.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &pgSession{
|
return &pgSession{
|
||||||
Tx: tx,
|
Tx: tx,
|
||||||
paginationKey: pgSQL.config.PaginationKey,
|
key: pagination.Must(pagination.KeyFromString(pgSQL.config.PaginationKey)),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -151,6 +131,15 @@ func (pgSQL *pgSQL) Ping() bool {
|
|||||||
return pgSQL.DB.Ping() == nil
|
return pgSQL.DB.Ping() == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Page is the representation of a page for the Postgres schema.
|
||||||
|
type Page struct {
|
||||||
|
// StartID is the ID being used as the basis for pagination across database
|
||||||
|
// results. It is used to search for an ancestry with ID >= StartID.
|
||||||
|
//
|
||||||
|
// StartID is required to be unique to every ancestry and always increasing.
|
||||||
|
StartID int64
|
||||||
|
}
|
||||||
|
|
||||||
// Config is the configuration that is used by openDatabase.
|
// Config is the configuration that is used by openDatabase.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Source string
|
Source string
|
||||||
|
@ -24,13 +24,13 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
fernet "github.com/fernet/fernet-go"
|
|
||||||
"github.com/pborman/uuid"
|
"github.com/pborman/uuid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
yaml "gopkg.in/yaml.v2"
|
yaml "gopkg.in/yaml.v2"
|
||||||
|
|
||||||
"github.com/coreos/clair/database"
|
"github.com/coreos/clair/database"
|
||||||
|
"github.com/coreos/clair/pkg/pagination"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@ -215,18 +215,13 @@ func generateTestConfig(testName string, loadFixture bool, manageLife bool) data
|
|||||||
source = fmt.Sprintf(sourceEnv, dbName)
|
source = fmt.Sprintf(sourceEnv, dbName)
|
||||||
}
|
}
|
||||||
|
|
||||||
var key fernet.Key
|
|
||||||
if err := key.Generate(); err != nil {
|
|
||||||
panic("failed to generate pagination key" + err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
return database.RegistrableComponentConfig{
|
return database.RegistrableComponentConfig{
|
||||||
Options: map[string]interface{}{
|
Options: map[string]interface{}{
|
||||||
"source": source,
|
"source": source,
|
||||||
"cachesize": 0,
|
"cachesize": 0,
|
||||||
"managedatabaselifecycle": manageLife,
|
"managedatabaselifecycle": manageLife,
|
||||||
"fixturepath": fixturePath,
|
"fixturepath": fixturePath,
|
||||||
"paginationkey": key.Encode(),
|
"paginationkey": pagination.Must(pagination.NewKey()).String(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,49 +0,0 @@
|
|||||||
// Copyright 2017 clair authors
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
// Package token implements encryption/decryption for json encoded interfaces
|
|
||||||
package token
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fernet/fernet-go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Unmarshal decrypts a token using provided key
|
|
||||||
// and decode the result into interface.
|
|
||||||
func Unmarshal(token string, key string, v interface{}) error {
|
|
||||||
k, _ := fernet.DecodeKey(key)
|
|
||||||
msg := fernet.VerifyAndDecrypt([]byte(token), time.Hour, []*fernet.Key{k})
|
|
||||||
if msg == nil {
|
|
||||||
return errors.New("invalid or expired pagination token")
|
|
||||||
}
|
|
||||||
|
|
||||||
return json.NewDecoder(bytes.NewBuffer(msg)).Decode(&v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Marshal encodes an interface into json bytes and encrypts it.
|
|
||||||
func Marshal(v interface{}, key string) ([]byte, error) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
err := json.NewEncoder(&buf).Encode(v)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
k, _ := fernet.DecodeKey(key)
|
|
||||||
return fernet.EncryptAndSign(buf.Bytes(), k)
|
|
||||||
}
|
|
102
pkg/pagination/pagination.go
Normal file
102
pkg/pagination/pagination.go
Normal file
@ -0,0 +1,102 @@
|
|||||||
|
// Copyright 2018 clair authors
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Package pagination implements a series of utilities for dealing with
|
||||||
|
// paginating lists of objects for an API.
|
||||||
|
package pagination
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/fernet/fernet-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ErrInvalidToken is returned when a token fails to Unmarshal because it was
|
||||||
|
// invalid or expired.
|
||||||
|
var ErrInvalidToken = errors.New("invalid or expired pagination token")
|
||||||
|
|
||||||
|
// ErrInvalidKeyString is returned when the string representing a key is malformed.
|
||||||
|
var ErrInvalidKeyString = errors.New("invalid pagination key string: must be 32-byte URL-safe base64")
|
||||||
|
|
||||||
|
// Key represents the key used to cryptographically secure the token
|
||||||
|
// being used to keep track of pages.
|
||||||
|
type Key struct {
|
||||||
|
fkey *fernet.Key
|
||||||
|
}
|
||||||
|
|
||||||
|
// Token represents an opaque pagination token keeping track of a user's
|
||||||
|
// progress iterating through a list of results.
|
||||||
|
type Token string
|
||||||
|
|
||||||
|
// FirstPageToken is used to represent the first page of content.
|
||||||
|
var FirstPageToken = Token("")
|
||||||
|
|
||||||
|
// NewKey generates a new random pagination key.
|
||||||
|
func NewKey() (k Key, err error) {
|
||||||
|
k.fkey = new(fernet.Key)
|
||||||
|
err = k.fkey.Generate()
|
||||||
|
return k, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeyFromString creates the key for a given string.
|
||||||
|
//
|
||||||
|
// Strings must be 32-byte URL-safe base64 representations of the key bytes.
|
||||||
|
func KeyFromString(keyString string) (k Key, err error) {
|
||||||
|
var fkey *fernet.Key
|
||||||
|
fkey, err = fernet.DecodeKey(keyString)
|
||||||
|
if err != nil {
|
||||||
|
return Key{}, ErrInvalidKeyString
|
||||||
|
}
|
||||||
|
return Key{fkey}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Must is a helper that wraps calls returning a Key and and error and panics
|
||||||
|
// if the error is non-nil.
|
||||||
|
func Must(k Key, err error) Key {
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return k
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements the fmt.Stringer interface for Key.
|
||||||
|
func (k Key) String() string {
|
||||||
|
return k.fkey.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// MarshalToken encodes an interface into JSON bytes and produces a Token.
|
||||||
|
func (k Key) MarshalToken(v interface{}) (Token, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := json.NewEncoder(&buf).Encode(v)
|
||||||
|
if err != nil {
|
||||||
|
return Token(""), err
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenBytes, err := fernet.EncryptAndSign(buf.Bytes(), k.fkey)
|
||||||
|
return Token(tokenBytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalToken decrypts a Token using provided key and decodes the result
|
||||||
|
// into the provided interface.
|
||||||
|
func (k Key) UnmarshalToken(t Token, v interface{}) error {
|
||||||
|
msg := fernet.VerifyAndDecrypt([]byte(t), time.Hour, []*fernet.Key{k.fkey})
|
||||||
|
if msg == nil {
|
||||||
|
return ErrInvalidToken
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.NewDecoder(bytes.NewBuffer(msg)).Decode(&v)
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user