Merge pull request #613 from jzelinskie/pkg-pagination

Introduce pkg/pagination
This commit is contained in:
Jimmy Zelinskie 2018-09-07 16:34:59 -04:00 committed by GitHub
commit e5c2e378a2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 187 additions and 144 deletions

View File

@ -25,6 +25,7 @@ import (
pb "github.com/coreos/clair/api/v3/clairpb" pb "github.com/coreos/clair/api/v3/clairpb"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination"
) )
// NotificationServer implements NotificationService interface for serving RPC. // NotificationServer implements NotificationService interface for serving RPC.
@ -214,8 +215,8 @@ func (s *NotificationServer) GetNotification(ctx context.Context, req *pb.GetNot
dbNotification, ok, err := tx.FindVulnerabilityNotification( dbNotification, ok, err := tx.FindVulnerabilityNotification(
req.GetName(), req.GetName(),
int(req.GetLimit()), int(req.GetLimit()),
database.PageNumber(req.GetOldVulnerabilityPage()), pagination.Token(req.GetOldVulnerabilityPage()),
database.PageNumber(req.GetNewVulnerabilityPage()), pagination.Token(req.GetNewVulnerabilityPage()),
) )
if err != nil { if err != nil {

View File

@ -1,4 +1,4 @@
// Copyright 2017 clair authors // Copyright 2018 clair authors
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -20,7 +20,6 @@ import (
"os" "os"
"time" "time"
"github.com/fernet/fernet-go"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
@ -31,6 +30,7 @@ import (
"github.com/coreos/clair/ext/featurens" "github.com/coreos/clair/ext/featurens"
"github.com/coreos/clair/ext/notification" "github.com/coreos/clair/ext/notification"
"github.com/coreos/clair/ext/vulnsrc" "github.com/coreos/clair/ext/vulnsrc"
"github.com/coreos/clair/pkg/pagination"
) )
// ErrDatasourceNotLoaded is returned when the datasource variable in the // ErrDatasourceNotLoaded is returned when the datasource variable in the
@ -108,15 +108,10 @@ func LoadConfig(path string) (config *Config, err error) {
// Generate a pagination key if none is provided. // Generate a pagination key if none is provided.
if v, ok := config.Database.Options["paginationkey"]; !ok || v == nil || v.(string) == "" { if v, ok := config.Database.Options["paginationkey"]; !ok || v == nil || v.(string) == "" {
log.Warn("pagination key is empty, generating...") log.Warn("pagination key is empty, generating...")
var key fernet.Key config.Database.Options["paginationkey"] = pagination.Must(pagination.NewKey()).String()
if err = key.Generate(); err != nil {
return
}
config.Database.Options["paginationkey"] = key.Encode()
} else { } else {
_, err = fernet.DecodeKey(config.Database.Options["paginationkey"].(string)) _, err = pagination.KeyFromString(config.Database.Options["paginationkey"].(string))
if err != nil { if err != nil {
err = errors.New("Invalid Pagination key; must be 32-bit URL-safe base64")
return return
} }
} }

View File

@ -1,4 +1,4 @@
// Copyright 2017 clair authors // Copyright 2018 clair authors
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -61,7 +61,17 @@ import (
_ "github.com/coreos/clair/ext/vulnsrc/ubuntu" _ "github.com/coreos/clair/ext/vulnsrc/ubuntu"
) )
const maxDBConnectionAttempts = 20 // MaxDBConnectionAttempts is the total number of tries that Clair will use to
// initially connect to a database at start-up.
const MaxDBConnectionAttempts = 20
// BinaryDependencies are the programs that Clair expects to be on the $PATH
// because it creates subprocesses of these programs.
var BinaryDependencies = []string{
"git",
"rpm",
"xz",
}
func waitForSignals(signals ...os.Signal) { func waitForSignals(signals ...os.Signal) {
interrupts := make(chan os.Signal, 1) interrupts := make(chan os.Signal, 1)
@ -136,7 +146,7 @@ func Boot(config *Config) {
// Open database // Open database
var db database.Datastore var db database.Datastore
var dbError error var dbError error
for attempts := 1; attempts <= maxDBConnectionAttempts; attempts++ { for attempts := 1; attempts <= MaxDBConnectionAttempts; attempts++ {
db, dbError = database.Open(config.Database) db, dbError = database.Open(config.Database)
if dbError == nil { if dbError == nil {
break break
@ -180,7 +190,7 @@ func main() {
flag.Parse() flag.Parse()
// Check for dependencies. // Check for dependencies.
for _, bin := range []string{"git", "rpm", "xz"} { for _, bin := range BinaryDependencies {
_, err := exec.LookPath(bin) _, err := exec.LookPath(bin)
if err != nil { if err != nil {
log.WithError(err).WithField("dependency", bin).Fatal("failed to find dependency") log.WithError(err).WithField("dependency", bin).Fatal("failed to find dependency")

View File

@ -20,6 +20,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"time" "time"
"github.com/coreos/clair/pkg/pagination"
) )
var ( var (
@ -168,16 +170,9 @@ type Session interface {
// affected ancestries affected by old or new vulnerability. // affected ancestries affected by old or new vulnerability.
// //
// Because the number of affected ancestries maybe large, they are paginated // Because the number of affected ancestries maybe large, they are paginated
// and their pages are specified by the given encrypted PageNumbers, which, // and their pages are specified by the paination token, which, if empty, are
// if empty, are always considered first page. // always considered first page.
// FindVulnerabilityNotification(name string, limit int, oldVulnerabilityPage pagination.Token, newVulnerabilityPage pagination.Token) (noti VulnerabilityNotificationWithVulnerable, found bool, err error)
// Session interface implementation should have encrypt and decrypt
// functions for PageNumber.
FindVulnerabilityNotification(name string, limit int,
oldVulnerabilityPage PageNumber,
newVulnerabilityPage PageNumber) (
noti VulnerabilityNotificationWithVulnerable,
found bool, err error)
// MarkNotificationNotified marks a Notification as notified now, assuming // MarkNotificationNotified marks a Notification as notified now, assuming
// the requested notification is in the database. // the requested notification is in the database.

View File

@ -14,7 +14,11 @@
package database package database
import "time" import (
"time"
"github.com/coreos/clair/pkg/pagination"
)
// MockSession implements Session and enables overriding each available method. // MockSession implements Session and enables overriding each available method.
// The default behavior of each method is to simply panic. // The default behavior of each method is to simply panic.
@ -38,7 +42,7 @@ type MockSession struct {
FctDeleteVulnerabilities func([]VulnerabilityID) error FctDeleteVulnerabilities func([]VulnerabilityID) error
FctInsertVulnerabilityNotifications func([]VulnerabilityNotification) error FctInsertVulnerabilityNotifications func([]VulnerabilityNotification) error
FctFindNewNotification func(lastNotified time.Time) (NotificationHook, bool, error) FctFindNewNotification func(lastNotified time.Time) (NotificationHook, bool, error)
FctFindVulnerabilityNotification func(name string, limit int, oldPage PageNumber, newPage PageNumber) ( FctFindVulnerabilityNotification func(name string, limit int, oldPage pagination.Token, newPage pagination.Token) (
vuln VulnerabilityNotificationWithVulnerable, ok bool, err error) vuln VulnerabilityNotificationWithVulnerable, ok bool, err error)
FctMarkNotificationNotified func(name string) error FctMarkNotificationNotified func(name string) error
FctDeleteNotification func(name string) error FctDeleteNotification func(name string) error
@ -182,7 +186,7 @@ func (ms *MockSession) FindNewNotification(lastNotified time.Time) (Notification
panic("required mock function not implemented") panic("required mock function not implemented")
} }
func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage PageNumber, newPage PageNumber) ( func (ms *MockSession) FindVulnerabilityNotification(name string, limit int, oldPage pagination.Token, newPage pagination.Token) (
VulnerabilityNotificationWithVulnerable, bool, error) { VulnerabilityNotificationWithVulnerable, bool, error) {
if ms.FctFindVulnerabilityNotification != nil { if ms.FctFindVulnerabilityNotification != nil {
return ms.FctFindVulnerabilityNotification(name, limit, oldPage, newPage) return ms.FctFindVulnerabilityNotification(name, limit, oldPage, newPage)

View File

@ -18,6 +18,8 @@ import (
"database/sql/driver" "database/sql/driver"
"encoding/json" "encoding/json"
"time" "time"
"github.com/coreos/clair/pkg/pagination"
) )
// Processors are extentions to scan a layer's content. // Processors are extentions to scan a layer's content.
@ -173,8 +175,8 @@ type PagedVulnerableAncestries struct {
Affected map[int]string Affected map[int]string
Limit int Limit int
Current PageNumber Current pagination.Token
Next PageNumber Next pagination.Token
// End signals the end of the pages. // End signals the end of the pages.
End bool End bool
@ -209,9 +211,6 @@ type VulnerabilityNotificationWithVulnerable struct {
New *PagedVulnerableAncestries New *PagedVulnerableAncestries
} }
// PageNumber is used to do pagination.
type PageNumber string
// MetadataMap is for storing the metadata returned by vulnerability database. // MetadataMap is for storing the metadata returned by vulnerability database.
type MetadataMap map[string]interface{} type MetadataMap map[string]interface{}

View File

@ -23,6 +23,7 @@ import (
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination"
) )
var ( var (
@ -163,12 +164,12 @@ func (tx *pgSession) FindNewNotification(notifiedBefore time.Time) (database.Not
return notification, true, nil return notification, true, nil
} }
func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentPage database.PageNumber) (database.PagedVulnerableAncestries, error) { func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, currentToken pagination.Token) (database.PagedVulnerableAncestries, error) {
vulnPage := database.PagedVulnerableAncestries{Limit: limit} vulnPage := database.PagedVulnerableAncestries{Limit: limit}
current := idPageNumber{0} currentPage := Page{0}
if currentPage != "" { if currentToken != pagination.FirstPageToken {
var err error var err error
current, err = decryptPage(currentPage, tx.paginationKey) err = tx.key.UnmarshalToken(currentToken, &currentPage)
if err != nil { if err != nil {
return vulnPage, err return vulnPage, err
} }
@ -188,7 +189,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
} }
// the last result is used for the next page's startID // the last result is used for the next page's startID
rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, current.StartID, limit+1) rows, err := tx.Query(searchNotificationVulnerableAncestry, vulnID, currentPage.StartID, limit+1)
if err != nil { if err != nil {
return vulnPage, handleError("searchNotificationVulnerableAncestry", err) return vulnPage, handleError("searchNotificationVulnerableAncestry", err)
} }
@ -209,13 +210,9 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
lastIndex = len(ancestries) lastIndex = len(ancestries)
vulnPage.End = true vulnPage.End = true
} else { } else {
// Use the last ancestry's ID as the next PageNumber. // Use the last ancestry's ID as the next page.
lastIndex = len(ancestries) - 1 lastIndex = len(ancestries) - 1
vulnPage.Next, err = encryptPage( vulnPage.Next, err = tx.key.MarshalToken(Page{ancestries[len(ancestries)-1].id})
idPageNumber{
ancestries[len(ancestries)-1].id,
}, tx.paginationKey)
if err != nil { if err != nil {
return vulnPage, err return vulnPage, err
} }
@ -226,7 +223,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
vulnPage.Affected[int(ancestry.id)] = ancestry.name vulnPage.Affected[int(ancestry.id)] = ancestry.name
} }
vulnPage.Current, err = encryptPage(current, tx.paginationKey) vulnPage.Current, err = tx.key.MarshalToken(currentPage)
if err != nil { if err != nil {
return vulnPage, err return vulnPage, err
} }
@ -234,7 +231,7 @@ func (tx *pgSession) findPagedVulnerableAncestries(vulnID int64, limit int, curr
return vulnPage, nil return vulnPage, nil
} }
func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPage database.PageNumber, newPage database.PageNumber) ( func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPageToken pagination.Token, newPageToken pagination.Token) (
database.VulnerabilityNotificationWithVulnerable, bool, error) { database.VulnerabilityNotificationWithVulnerable, bool, error) {
var ( var (
noti database.VulnerabilityNotificationWithVulnerable noti database.VulnerabilityNotificationWithVulnerable
@ -274,7 +271,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
} }
if oldVulnID.Valid { if oldVulnID.Valid {
page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPage) page, err := tx.findPagedVulnerableAncestries(oldVulnID.Int64, limit, oldPageToken)
if err != nil { if err != nil {
return noti, false, err return noti, false, err
} }
@ -282,7 +279,7 @@ func (tx *pgSession) FindVulnerabilityNotification(name string, limit int, oldPa
} }
if newVulnID.Valid { if newVulnID.Valid {
page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPage) page, err := tx.findPagedVulnerableAncestries(newVulnID.Int64, limit, newPageToken)
if err != nil { if err != nil {
return noti, false, err return noti, false, err
} }

View File

@ -1,4 +1,4 @@
// Copyright 2017 clair authors // Copyright 2018 clair authors
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -73,22 +73,25 @@ func TestPagination(t *testing.T) {
if assert.Nil(t, err) && assert.True(t, ok) { if assert.Nil(t, err) && assert.True(t, ok) {
assert.Equal(t, "test", noti.Name) assert.Equal(t, "test", noti.Name)
if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
oldPageNum, err := decryptPage(noti.Old.Current, tx.paginationKey) var oldPage Page
err := tx.key.UnmarshalToken(noti.Old.Current, &oldPage)
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }
assert.Equal(t, int64(0), oldPageNum.StartID) assert.Equal(t, int64(0), oldPage.StartID)
newPageNum, err := decryptPage(noti.New.Current, tx.paginationKey) var newPage Page
err = tx.key.UnmarshalToken(noti.New.Current, &newPage)
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }
newPageNextNum, err := decryptPage(noti.New.Next, tx.paginationKey) var newPageNext Page
err = tx.key.UnmarshalToken(noti.New.Next, &newPageNext)
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }
assert.Equal(t, int64(0), newPageNum.StartID) assert.Equal(t, int64(0), newPage.StartID)
assert.Equal(t, int64(4), newPageNextNum.StartID) assert.Equal(t, int64(4), newPageNext.StartID)
noti.Old.Current = "" noti.Old.Current = ""
noti.New.Current = "" noti.New.Current = ""
@ -98,26 +101,28 @@ func TestPagination(t *testing.T) {
} }
} }
page1, err := encryptPage(idPageNumber{0}, tx.paginationKey) pageNum1, err := tx.key.MarshalToken(Page{0})
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }
page2, err := encryptPage(idPageNumber{4}, tx.paginationKey) pageNum2, err := tx.key.MarshalToken(Page{4})
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }
noti, ok, err = tx.FindVulnerabilityNotification("test", 1, page1, page2) noti, ok, err = tx.FindVulnerabilityNotification("test", 1, pageNum1, pageNum2)
if assert.Nil(t, err) && assert.True(t, ok) { if assert.Nil(t, err) && assert.True(t, ok) {
assert.Equal(t, "test", noti.Name) assert.Equal(t, "test", noti.Name)
if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) { if assert.NotNil(t, noti.Old) && assert.NotNil(t, noti.New) {
oldCurrentPage, err := decryptPage(noti.Old.Current, tx.paginationKey) var oldCurrentPage Page
err = tx.key.UnmarshalToken(noti.Old.Current, &oldCurrentPage)
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }
newCurrentPage, err := decryptPage(noti.New.Current, tx.paginationKey) var newCurrentPage Page
err = tx.key.UnmarshalToken(noti.New.Current, &newCurrentPage)
if !assert.Nil(t, err) { if !assert.Nil(t, err) {
assert.FailNow(t, "") assert.FailNow(t, "")
} }

View File

@ -33,8 +33,8 @@ import (
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/database/pgsql/migrations" "github.com/coreos/clair/database/pgsql/migrations"
"github.com/coreos/clair/database/pgsql/token"
"github.com/coreos/clair/pkg/commonerr" "github.com/coreos/clair/pkg/commonerr"
"github.com/coreos/clair/pkg/pagination"
) )
var ( var (
@ -91,33 +91,13 @@ type pgSQL struct {
type pgSession struct { type pgSession struct {
*sql.Tx *sql.Tx
paginationKey string key pagination.Key
} }
type idPageNumber struct { // Begin initiates a transaction to database.
// StartID is an implementation detail for paginating by an ID required to
// be unique to every ancestry and always increasing.
// //
// StartID is used to search for ancestry with ID >= StartID // The expected transaction isolation level in this implementation is "Read
StartID int64 // Committed".
}
func encryptPage(page idPageNumber, paginationKey string) (result database.PageNumber, err error) {
resultBytes, err := token.Marshal(page, paginationKey)
if err != nil {
return result, err
}
result = database.PageNumber(resultBytes)
return result, nil
}
func decryptPage(page database.PageNumber, paginationKey string) (result idPageNumber, err error) {
err = token.Unmarshal(string(page), paginationKey, &result)
return
}
// Begin initiates a transaction to database. The expected transaction isolation
// level in this implementation is "Read Committed".
func (pgSQL *pgSQL) Begin() (database.Session, error) { func (pgSQL *pgSQL) Begin() (database.Session, error) {
tx, err := pgSQL.DB.Begin() tx, err := pgSQL.DB.Begin()
if err != nil { if err != nil {
@ -125,7 +105,7 @@ func (pgSQL *pgSQL) Begin() (database.Session, error) {
} }
return &pgSession{ return &pgSession{
Tx: tx, Tx: tx,
paginationKey: pgSQL.config.PaginationKey, key: pagination.Must(pagination.KeyFromString(pgSQL.config.PaginationKey)),
}, nil }, nil
} }
@ -151,6 +131,15 @@ func (pgSQL *pgSQL) Ping() bool {
return pgSQL.DB.Ping() == nil return pgSQL.DB.Ping() == nil
} }
// Page is the representation of a page for the Postgres schema.
type Page struct {
// StartID is the ID being used as the basis for pagination across database
// results. It is used to search for an ancestry with ID >= StartID.
//
// StartID is required to be unique to every ancestry and always increasing.
StartID int64
}
// Config is the configuration that is used by openDatabase. // Config is the configuration that is used by openDatabase.
type Config struct { type Config struct {
Source string Source string

View File

@ -24,13 +24,13 @@ import (
"strings" "strings"
"testing" "testing"
fernet "github.com/fernet/fernet-go"
"github.com/pborman/uuid" "github.com/pborman/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
yaml "gopkg.in/yaml.v2" yaml "gopkg.in/yaml.v2"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/pkg/pagination"
) )
var ( var (
@ -215,18 +215,13 @@ func generateTestConfig(testName string, loadFixture bool, manageLife bool) data
source = fmt.Sprintf(sourceEnv, dbName) source = fmt.Sprintf(sourceEnv, dbName)
} }
var key fernet.Key
if err := key.Generate(); err != nil {
panic("failed to generate pagination key" + err.Error())
}
return database.RegistrableComponentConfig{ return database.RegistrableComponentConfig{
Options: map[string]interface{}{ Options: map[string]interface{}{
"source": source, "source": source,
"cachesize": 0, "cachesize": 0,
"managedatabaselifecycle": manageLife, "managedatabaselifecycle": manageLife,
"fixturepath": fixturePath, "fixturepath": fixturePath,
"paginationkey": key.Encode(), "paginationkey": pagination.Must(pagination.NewKey()).String(),
}, },
} }
} }

View File

@ -1,49 +0,0 @@
// Copyright 2017 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package token implements encryption/decryption for json encoded interfaces
package token
import (
"bytes"
"encoding/json"
"errors"
"time"
"github.com/fernet/fernet-go"
)
// Unmarshal decrypts a token using provided key
// and decode the result into interface.
func Unmarshal(token string, key string, v interface{}) error {
k, _ := fernet.DecodeKey(key)
msg := fernet.VerifyAndDecrypt([]byte(token), time.Hour, []*fernet.Key{k})
if msg == nil {
return errors.New("invalid or expired pagination token")
}
return json.NewDecoder(bytes.NewBuffer(msg)).Decode(&v)
}
// Marshal encodes an interface into json bytes and encrypts it.
func Marshal(v interface{}, key string) ([]byte, error) {
var buf bytes.Buffer
err := json.NewEncoder(&buf).Encode(v)
if err != nil {
return nil, err
}
k, _ := fernet.DecodeKey(key)
return fernet.EncryptAndSign(buf.Bytes(), k)
}

View File

@ -0,0 +1,102 @@
// Copyright 2018 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package pagination implements a series of utilities for dealing with
// paginating lists of objects for an API.
package pagination
import (
"bytes"
"encoding/json"
"errors"
"time"
"github.com/fernet/fernet-go"
)
// ErrInvalidToken is returned when a token fails to Unmarshal because it was
// invalid or expired.
var ErrInvalidToken = errors.New("invalid or expired pagination token")
// ErrInvalidKeyString is returned when the string representing a key is malformed.
var ErrInvalidKeyString = errors.New("invalid pagination key string: must be 32-byte URL-safe base64")
// Key represents the key used to cryptographically secure the token
// being used to keep track of pages.
type Key struct {
fkey *fernet.Key
}
// Token represents an opaque pagination token keeping track of a user's
// progress iterating through a list of results.
type Token string
// FirstPageToken is used to represent the first page of content.
var FirstPageToken = Token("")
// NewKey generates a new random pagination key.
func NewKey() (k Key, err error) {
k.fkey = new(fernet.Key)
err = k.fkey.Generate()
return k, err
}
// KeyFromString creates the key for a given string.
//
// Strings must be 32-byte URL-safe base64 representations of the key bytes.
func KeyFromString(keyString string) (k Key, err error) {
var fkey *fernet.Key
fkey, err = fernet.DecodeKey(keyString)
if err != nil {
return Key{}, ErrInvalidKeyString
}
return Key{fkey}, err
}
// Must is a helper that wraps calls returning a Key and and error and panics
// if the error is non-nil.
func Must(k Key, err error) Key {
if err != nil {
panic(err)
}
return k
}
// String implements the fmt.Stringer interface for Key.
func (k Key) String() string {
return k.fkey.Encode()
}
// MarshalToken encodes an interface into JSON bytes and produces a Token.
func (k Key) MarshalToken(v interface{}) (Token, error) {
var buf bytes.Buffer
err := json.NewEncoder(&buf).Encode(v)
if err != nil {
return Token(""), err
}
tokenBytes, err := fernet.EncryptAndSign(buf.Bytes(), k.fkey)
return Token(tokenBytes), err
}
// UnmarshalToken decrypts a Token using provided key and decodes the result
// into the provided interface.
func (k Key) UnmarshalToken(t Token, v interface{}) error {
msg := fernet.VerifyAndDecrypt([]byte(t), time.Hour, []*fernet.Key{k.fkey})
if msg == nil {
return ErrInvalidToken
}
return json.NewDecoder(bytes.NewBuffer(msg)).Decode(&v)
}