Merge pull request #685 from jzelinskie/updater-cleanup

updater: remove FindLock(), use errgroup to avoid races
This commit is contained in:
Jimmy Zelinskie 2019-02-14 14:57:59 -05:00 committed by GitHub
commit cafe0976a4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
76 changed files with 3586 additions and 541 deletions

View File

@ -183,23 +183,16 @@ type Session interface {
// FindKeyValue retrieves a value from the given key. // FindKeyValue retrieves a value from the given key.
FindKeyValue(key string) (value string, found bool, err error) FindKeyValue(key string) (value string, found bool, err error)
// Lock creates or renew a Lock in the database with the given name, owner // Lock acquires or renews a lock in the database with the given name, owner
// and duration. // and duration without blocking. After the specified duration, the lock
// expires if it hasn't already been unlocked in order to prevent a deadlock.
// //
// After the specified duration, the Lock expires by itself if it hasn't been // If the acquisition of a lock is not successful, expiration should be
// unlocked, and thus, let other users create a Lock with the same name. // the time that existing lock expires.
// However, the owner can renew its Lock by setting renew to true.
// Lock should not block, it should instead returns whether the Lock has been
// successfully acquired/renewed. If it's the case, the expiration time of
// that Lock is returned as well.
Lock(name string, owner string, duration time.Duration, renew bool) (success bool, expiration time.Time, err error) Lock(name string, owner string, duration time.Duration, renew bool) (success bool, expiration time.Time, err error)
// Unlock releases an existing Lock. // Unlock releases an existing Lock.
Unlock(name, owner string) error Unlock(name, owner string) error
// FindLock returns the owner of a Lock specified by the name, and its
// expiration time if it exists.
FindLock(name string) (owner string, expiration time.Time, found bool, err error)
} }
// Datastore represents a persistent data store // Datastore represents a persistent data store

View File

@ -15,6 +15,8 @@
package database package database
import ( import (
"time"
"github.com/deckarep/golang-set" "github.com/deckarep/golang-set"
) )
@ -304,3 +306,47 @@ func MergeLayers(l *Layer, new *Layer) *Layer {
return l return l
} }
// AcquireLock acquires a named global lock for a duration.
//
// If renewal is true, the lock is extended as long as the same owner is
// attempting to renew the lock.
func AcquireLock(datastore Datastore, name, owner string, duration time.Duration, renewal bool) (success bool, expiration time.Time) {
// any error will cause the function to catch the error and return false.
tx, err := datastore.Begin()
if err != nil {
return false, time.Time{}
}
defer tx.Rollback()
locked, t, err := tx.Lock(name, owner, duration, renewal)
if err != nil {
return false, time.Time{}
}
if locked {
if err := tx.Commit(); err != nil {
return false, time.Time{}
}
}
return locked, t
}
// ReleaseLock releases a named global lock.
func ReleaseLock(datastore Datastore, name, owner string) {
tx, err := datastore.Begin()
if err != nil {
return
}
defer tx.Rollback()
if err := tx.Unlock(name, owner); err != nil {
return
}
if err := tx.Commit(); err != nil {
return
}
}

View File

@ -48,7 +48,6 @@ type MockSession struct {
FctFindKeyValue func(key string) (string, bool, error) FctFindKeyValue func(key string) (string, bool, error)
FctLock func(name string, owner string, duration time.Duration, renew bool) (bool, time.Time, error) FctLock func(name string, owner string, duration time.Duration, renew bool) (bool, time.Time, error)
FctUnlock func(name, owner string) error FctUnlock func(name, owner string) error
FctFindLock func(name string) (string, time.Time, bool, error)
} }
func (ms *MockSession) Commit() error { func (ms *MockSession) Commit() error {
@ -220,13 +219,6 @@ func (ms *MockSession) Unlock(name, owner string) error {
panic("required mock function not implemented") panic("required mock function not implemented")
} }
func (ms *MockSession) FindLock(name string) (string, time.Time, bool, error) {
if ms.FctFindLock != nil {
return ms.FctFindLock(name)
}
panic("required mock function not implemented")
}
// MockDatastore implements Datastore and enables overriding each available method. // MockDatastore implements Datastore and enables overriding each available method.
// The default behavior of each method is to simply panic. // The default behavior of each method is to simply panic.
type MockDatastore struct { type MockDatastore struct {

View File

@ -25,7 +25,7 @@ import (
const ( const (
soiLock = `INSERT INTO lock(name, owner, until) VALUES ($1, $2, $3)` soiLock = `INSERT INTO lock(name, owner, until) VALUES ($1, $2, $3)`
searchLock = `SELECT owner, until FROM Lock WHERE name = $1` searchLock = `SELECT until FROM Lock WHERE name = $1`
updateLock = `UPDATE Lock SET until = $3 WHERE name = $1 AND owner = $2` updateLock = `UPDATE Lock SET until = $3 WHERE name = $1 AND owner = $2`
removeLock = `DELETE FROM Lock WHERE name = $1 AND owner = $2` removeLock = `DELETE FROM Lock WHERE name = $1 AND owner = $2`
removeLockExpired = `DELETE FROM LOCK WHERE until < CURRENT_TIMESTAMP` removeLockExpired = `DELETE FROM LOCK WHERE until < CURRENT_TIMESTAMP`
@ -67,7 +67,9 @@ func (tx *pgSession) Lock(name string, owner string, duration time.Duration, ren
_, err := tx.Exec(soiLock, name, owner, until) _, err := tx.Exec(soiLock, name, owner, until)
if err != nil { if err != nil {
if isErrUniqueViolation(err) { if isErrUniqueViolation(err) {
return false, until, nil // Return the existing locks expiration.
err := tx.QueryRow(searchLock, name).Scan(&until)
return false, until, handleError("searchLock", err)
} }
return false, until, handleError("insertLock", err) return false, until, handleError("insertLock", err)
} }
@ -86,25 +88,6 @@ func (tx *pgSession) Unlock(name, owner string) error {
return err return err
} }
// FindLock returns the owner of a lock specified by its name and its
// expiration time.
func (tx *pgSession) FindLock(name string) (string, time.Time, bool, error) {
if name == "" {
return "", time.Time{}, false, commonerr.NewBadRequestError("could not find an invalid lock")
}
defer observeQueryTime("FindLock", "all", time.Now())
var owner string
var until time.Time
err := tx.QueryRow(searchLock, name).Scan(&owner, &until)
if err != nil {
return owner, until, false, handleError("searchLock", err)
}
return owner, until, true, nil
}
// pruneLocks removes every expired locks from the database // pruneLocks removes every expired locks from the database
func (tx *pgSession) pruneLocks() error { func (tx *pgSession) pruneLocks() error {
defer observeQueryTime("pruneLocks", "all", time.Now()) defer observeQueryTime("pruneLocks", "all", time.Now())

View File

@ -26,7 +26,6 @@ func TestLock(t *testing.T) {
defer datastore.Close() defer datastore.Close()
var l bool var l bool
var et time.Time
// Create a first lock. // Create a first lock.
l, _, err := tx.Lock("test1", "owner1", time.Minute, false) l, _, err := tx.Lock("test1", "owner1", time.Minute, false)
@ -62,19 +61,11 @@ func TestLock(t *testing.T) {
assert.Nil(t, err) assert.Nil(t, err)
tx = restartSession(t, datastore, tx, true) tx = restartSession(t, datastore, tx, true)
l, et, err = tx.Lock("test1", "owner2", time.Minute, false) l, _, err = tx.Lock("test1", "owner2", time.Minute, false)
assert.Nil(t, err) assert.Nil(t, err)
assert.True(t, l) assert.True(t, l)
tx = restartSession(t, datastore, tx, true) tx = restartSession(t, datastore, tx, true)
// LockInfo
o, et2, ok, err := tx.FindLock("test1")
assert.True(t, ok)
assert.Nil(t, err)
assert.Equal(t, "owner2", o)
assert.Equal(t, et.Second(), et2.Second())
tx = restartSession(t, datastore, tx, true)
// Create a second lock which is actually already expired ... // Create a second lock which is actually already expired ...
l, _, err = tx.Lock("test2", "owner1", -time.Minute, false) l, _, err = tx.Lock("test2", "owner1", -time.Minute, false)
assert.Nil(t, err) assert.Nil(t, err)

View File

@ -86,3 +86,11 @@ func Appenders() map[string]Appender {
return ret return ret
} }
// CleanAll is a utility function that calls Clean() on every registered
// Appender.
func CleanAll() {
for _, appender := range Appenders() {
appender.Clean()
}
}

View File

@ -93,3 +93,11 @@ func ListUpdaters() []string {
} }
return r return r
} }
// CleanAll is a utility function that calls Clean() on every registered
// Updater.
func CleanAll() {
for _, updater := range Updaters() {
updater.Clean()
}
}

16
glide.lock generated
View File

@ -1,5 +1,5 @@
hash: 208de0ba40f951c17ac45683952efdd6b14f5efbfb70dcdb493c52954d96cb75 hash: 9db0d8cd37c64c634552e66bab73f6644f7443150209d88ce9972c7259a22867
updated: 2018-10-22T22:58:38.105092-04:00 updated: 2019-01-10T13:32:31.077199-05:00
imports: imports:
- name: github.com/beorn7/perks - name: github.com/beorn7/perks
version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9 version: 4c0e84591b9aa9e6dcfdf3e020114cd81f89d5f9
@ -31,7 +31,7 @@ imports:
- ptypes/struct - ptypes/struct
- ptypes/timestamp - ptypes/timestamp
- name: github.com/google/uuid - name: github.com/google/uuid
version: e704694aed0ea004bb7eb1fc2e911d048a54606a version: 9b3b1e0f5f99ae461456d768e7d301a7acdaa2d8
- name: github.com/grpc-ecosystem/go-grpc-prometheus - name: github.com/grpc-ecosystem/go-grpc-prometheus
version: 2500245aa6110c562d17020fb31a2c133d737799 version: 2500245aa6110c562d17020fb31a2c133d737799
- name: github.com/grpc-ecosystem/grpc-gateway - name: github.com/grpc-ecosystem/grpc-gateway
@ -65,7 +65,7 @@ imports:
subpackages: subpackages:
- difflib - difflib
- name: github.com/prometheus/client_golang - name: github.com/prometheus/client_golang
version: 1cafe34db7fdec6022e17e00e1c1ea501022f3e4 version: 505eaef017263e299324067d40ca2c48f6a2cf50
subpackages: subpackages:
- prometheus - prometheus
- prometheus/internal - prometheus/internal
@ -88,7 +88,7 @@ imports:
- name: github.com/sirupsen/logrus - name: github.com/sirupsen/logrus
version: ba1b36c82c5e05c4f912a88eab0dcd91a171688f version: ba1b36c82c5e05c4f912a88eab0dcd91a171688f
- name: github.com/stretchr/testify - name: github.com/stretchr/testify
version: f35b8ab0b5a2cef36673838d662e249dd9c94686 version: ffdc059bfe9ce6a4e144ba849dbedead332c6053
subpackages: subpackages:
- assert - assert
- require - require
@ -102,6 +102,10 @@ imports:
- internal/timeseries - internal/timeseries
- lex/httplex - lex/httplex
- trace - trace
- name: golang.org/x/sync
version: 37e7f081c4d4c64e13b10787722085407fe5d15f
subpackages:
- errgroup
- name: golang.org/x/sys - name: golang.org/x/sys
version: b90f89a1e7a9c1f6b918820b3daa7f08488c8594 version: b90f89a1e7a9c1f6b918820b3daa7f08488c8594
subpackages: subpackages:
@ -135,5 +139,5 @@ imports:
- tap - tap
- transport - transport
- name: gopkg.in/yaml.v2 - name: gopkg.in/yaml.v2
version: 5420a8b6744d3b0345ab293f6fcba19c978f1183 version: 51d6538a90f86fe93ac480b35f37b2be17fef232
testImports: [] testImports: []

View File

@ -31,3 +31,4 @@ import:
version: ^1.7.1 version: ^1.7.1
- package: gopkg.in/yaml.v2 - package: gopkg.in/yaml.v2
version: ^2.2.1 version: ^2.2.1
- package: golang.org/x/sync/errgroup

View File

@ -102,7 +102,7 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop
if interrupted { if interrupted {
running = false running = false
} }
unlock(datastore, notification.Name, whoAmI) database.ReleaseLock(datastore, notification.Name, whoAmI)
done <- true done <- true
}() }()
@ -113,7 +113,7 @@ func RunNotifier(config *notification.Config, datastore database.Datastore, stop
case <-done: case <-done:
break outer break outer
case <-time.After(notifierLockRefreshDuration): case <-time.After(notifierLockRefreshDuration):
lock(datastore, notification.Name, whoAmI, notifierLockDuration, true) database.AcquireLock(datastore, notification.Name, whoAmI, notifierLockDuration, true)
case <-stopper.Chan(): case <-stopper.Chan():
running = false running = false
break break
@ -141,7 +141,7 @@ func findTask(datastore database.Datastore, renotifyInterval time.Duration, whoA
} }
// Lock the notification. // Lock the notification.
if hasLock, _ := lock(datastore, notification.Name, whoAmI, notifierLockDuration, false); hasLock { if hasLock, _ := database.AcquireLock(datastore, notification.Name, whoAmI, notifierLockDuration, false); hasLock {
log.WithField(logNotiName, notification.Name).Info("found and locked a notification") log.WithField(logNotiName, notification.Name).Info("found and locked a notification")
return &notification return &notification
} }
@ -208,44 +208,3 @@ func markNotificationAsRead(datastore database.Datastore, name string) error {
} }
return tx.Commit() return tx.Commit()
} }
// unlock removes a lock with provided name, owner. Internally, it handles
// database transaction and catches error.
func unlock(datastore database.Datastore, name, owner string) {
tx, err := datastore.Begin()
if err != nil {
return
}
defer tx.Rollback()
if err := tx.Unlock(name, owner); err != nil {
return
}
if err := tx.Commit(); err != nil {
return
}
}
func lock(datastore database.Datastore, name string, owner string, duration time.Duration, renew bool) (bool, time.Time) {
// any error will cause the function to catch the error and return false.
tx, err := datastore.Begin()
if err != nil {
return false, time.Time{}
}
defer tx.Rollback()
locked, t, err := tx.Lock(name, owner, duration, renew)
if err != nil {
return false, time.Time{}
}
if locked {
if err := tx.Commit(); err != nil {
return false, time.Time{}
}
}
return locked, t
}

60
pkg/timeutil/timeutil.go Normal file
View File

@ -0,0 +1,60 @@
// Copyright 2018 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package timeutil implements extra utilities dealing with time not found
// in the standard library.
package timeutil
import (
"math"
"math/rand"
"time"
log "github.com/sirupsen/logrus"
"github.com/coreos/clair/pkg/stopper"
)
// ApproxSleep is a stoppable time.Sleep that adds a slight random variation to
// the wakeup time in order to prevent thundering herds.
func ApproxSleep(approxWakeup time.Time, st *stopper.Stopper) (stopped bool) {
waitUntil := approxWakeup.Add(time.Duration(rand.ExpFloat64()/0.5) * time.Second)
log.WithField("wakeup", waitUntil).Debug("updater sleeping")
now := time.Now().UTC()
if !waitUntil.Before(now) {
if !st.Sleep(waitUntil.Sub(now)) {
return true
}
}
return false
}
// ExpBackoff doubles the backoff time, if the result is longer than the
// parameter max, max will be returned.
func ExpBackoff(prev, max time.Duration) time.Duration {
t := 2 * prev
if t > max {
t = max
}
if t == 0 {
return time.Second
}
return t
}
// FractionalDuration calculates the fraction of a Duration rounding half way
// from zero.
func FractionalDuration(fraction float64, d time.Duration) time.Duration {
return time.Duration(math.Round(float64(d) * fraction))
}

View File

@ -0,0 +1,40 @@
// Copyright 2018 clair authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package timeutil
import (
"testing"
"time"
)
func TestExpBackoff(t *testing.T) {
tests := []struct {
prev, max, want time.Duration
}{
{time.Duration(0), 1 * time.Minute, 1 * time.Second},
{1 * time.Second, 1 * time.Minute, 2 * time.Second},
{16 * time.Second, 1 * time.Minute, 32 * time.Second},
{32 * time.Second, 1 * time.Minute, 1 * time.Minute},
{1 * time.Minute, 1 * time.Minute, 1 * time.Minute},
{2 * time.Minute, 1 * time.Minute, 1 * time.Minute},
}
for i, tt := range tests {
got := ExpBackoff(tt.prev, tt.max)
if tt.want != got {
t.Errorf("case %d: want=%v got=%v", i, tt.want, got)
}
}
}

View File

@ -1,4 +1,4 @@
// Copyright 2017 clair authors // Copyright 2019 clair authors
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -15,8 +15,9 @@
package clair package clair
import ( import (
"context"
"errors"
"fmt" "fmt"
"math/rand"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@ -24,11 +25,13 @@ import (
"github.com/pborman/uuid" "github.com/pborman/uuid"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sync/errgroup"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/coreos/clair/ext/vulnmdsrc" "github.com/coreos/clair/ext/vulnmdsrc"
"github.com/coreos/clair/ext/vulnsrc" "github.com/coreos/clair/ext/vulnsrc"
"github.com/coreos/clair/pkg/stopper" "github.com/coreos/clair/pkg/stopper"
"github.com/coreos/clair/pkg/timeutil"
) )
const ( const (
@ -87,133 +90,126 @@ func RunUpdater(config *UpdaterConfig, datastore database.Datastore, st *stopper
return return
} }
// Clean up any resources the updater left behind.
defer func() {
vulnmdsrc.CleanAll()
vulnsrc.CleanAll()
log.Info("updater service stopped")
}()
// Create a new unique identity for tracking who owns global locks.
whoAmI := uuid.New() whoAmI := uuid.New()
log.WithField("lock identifier", whoAmI).Info("updater service started") log.WithField("owner", whoAmI).Info("updater service started")
sleepDuration := updaterSleepBetweenLoopsDuration
for { for {
var stop bool
// Determine if this is the first update and define the next update time. // Determine if this is the first update and define the next update time.
// The next update time is (last update time + interval) or now if this is the first update. // The next update time is (last update time + interval) or now if this is the first update.
nextUpdate := time.Now().UTC() nextUpdate := time.Now().UTC()
lastUpdate, firstUpdate, err := GetLastUpdateTime(datastore) lastUpdate, isFirstUpdate, err := GetLastUpdateTime(datastore)
if err != nil { if err != nil {
log.WithError(err).Error("an error occurred while getting the last update time") log.WithError(err).Error("an error occurred while getting the last update time")
nextUpdate = nextUpdate.Add(config.Interval) nextUpdate = nextUpdate.Add(config.Interval)
} else if !firstUpdate { }
log.WithFields(log.Fields{
"firstUpdate": isFirstUpdate,
"nextUpdate": nextUpdate,
}).Debug("fetched last update time")
if !isFirstUpdate {
nextUpdate = lastUpdate.Add(config.Interval) nextUpdate = lastUpdate.Add(config.Interval)
} }
// If the next update timer is in the past, then try to update. // If the next update timer is in the past, then try to update.
if nextUpdate.Before(time.Now().UTC()) { if nextUpdate.Before(time.Now().UTC()) {
// Attempt to get a lock on the the update. // Attempt to get a lock on the update.
log.Debug("attempting to obtain update lock") log.Debug("attempting to obtain update lock")
hasLock, hasLockUntil := lock(datastore, updaterLockName, whoAmI, updaterLockDuration, false) acquiredLock, lockExpiration := database.AcquireLock(datastore, updaterLockName, whoAmI, updaterLockDuration, false)
if hasLock { if lockExpiration.IsZero() {
// Launch update in a new go routine. // Any failures to acquire the lock should instantly expire.
doneC := make(chan bool, 1) var instantExpiration time.Duration
go func() { sleepDuration = instantExpiration
update(datastore, firstUpdate)
doneC <- true
}()
for done := false; !done && !stop; {
select {
case <-doneC:
done = true
case <-time.After(updaterLockRefreshDuration):
// Refresh the lock until the update is done.
lock(datastore, updaterLockName, whoAmI, updaterLockDuration, true)
case <-st.Chan():
stop = true
}
} }
// Unlock the updater. if acquiredLock {
unlock(datastore, updaterLockName, whoAmI) sleepDuration, err = updateWhileRenewingLock(datastore, whoAmI, isFirstUpdate, st)
if err != nil {
if stop { if err == errReceivedStopSignal {
break log.Debug("updater received stop signal")
return
} }
log.WithError(err).Debug("failed to acquired lock")
// Sleep for a short duration to prevent pinning the CPU on a sleepDuration = timeutil.ExpBackoff(sleepDuration, config.Interval)
// consistent failure.
if stopped := sleepUpdater(time.Now().Add(updaterSleepBetweenLoopsDuration), st); stopped {
break
} }
continue
} else { } else {
lockOwner, lockExpiration, ok, err := findLock(datastore, updaterLockName) sleepDuration = updaterSleepBetweenLoopsDuration
if !ok || err != nil { }
log.Debug("update lock is already taken")
nextUpdate = hasLockUntil
} else { } else {
log.WithFields(log.Fields{"lock owner": lockOwner, "lock expiration": lockExpiration}).Debug("update lock is already taken") sleepDuration = time.Until(nextUpdate)
nextUpdate = lockExpiration
}
}
} }
if stopped := sleepUpdater(nextUpdate, st); stopped { if stopped := timeutil.ApproxSleep(time.Now().Add(sleepDuration), st); stopped {
break return
} }
} }
// Clean resources.
for _, appenders := range vulnmdsrc.Appenders() {
appenders.Clean()
}
for _, updaters := range vulnsrc.Updaters() {
updaters.Clean()
}
log.Info("updater service stopped")
} }
// sleepUpdater sleeps the updater for an approximate duration, but remains var errReceivedStopSignal = errors.New("stopped")
// able to be cancelled by a stopper.
func sleepUpdater(approxWakeup time.Time, st *stopper.Stopper) (stopped bool) { func updateWhileRenewingLock(datastore database.Datastore, whoAmI string, isFirstUpdate bool, st *stopper.Stopper) (sleepDuration time.Duration, err error) {
waitUntil := approxWakeup.Add(time.Duration(rand.ExpFloat64()/0.5) * time.Second) g, ctx := errgroup.WithContext(context.Background())
log.WithField("scheduled time", waitUntil).Debug("updater sleeping") g.Go(func() error {
if !waitUntil.Before(time.Now().UTC()) { return update(ctx, datastore, isFirstUpdate)
if !st.Sleep(waitUntil.Sub(time.Now())) { })
return true
g.Go(func() error {
var refreshDuration = updaterLockRefreshDuration
for {
select {
case <-time.After(timeutil.FractionalDuration(0.9, refreshDuration)):
success, lockExpiration := database.AcquireLock(datastore, updaterLockName, whoAmI, updaterLockRefreshDuration, true)
if !success {
return errors.New("failed to extend lock")
}
refreshDuration = time.Until(lockExpiration)
case <-ctx.Done():
database.ReleaseLock(datastore, updaterLockName, whoAmI)
return ctx.Err()
} }
} }
return false })
g.Go(func() error {
select {
case <-st.Chan():
return errReceivedStopSignal
case <-ctx.Done():
return ctx.Err()
}
})
err = g.Wait()
return
} }
// update fetches all the vulnerabilities from the registered fetchers, updates // update fetches all the vulnerabilities from the registered fetchers, updates
// vulnerabilities, and updater flags, and logs notes from updaters. // vulnerabilities, and updater flags, and logs notes from updaters.
func update(datastore database.Datastore, firstUpdate bool) { func update(ctx context.Context, datastore database.Datastore, firstUpdate bool) error {
defer setUpdaterDuration(time.Now()) defer setUpdaterDuration(time.Now())
log.Info("updating vulnerabilities") log.Info("updating vulnerabilities")
// Fetch updates. // Fetch updates.
success, vulnerabilities, flags, notes := fetch(datastore) success, vulnerabilities, flags, notes := fetchUpdates(ctx, datastore)
// do vulnerability namespacing again to merge potentially duplicated namespaces, vulnerabilities := deduplicate(vulnerabilities)
// vulnerabilities from each updater.
vulnerabilities = doVulnerabilitiesNamespacing(vulnerabilities)
// deduplicate fetched namespaces and store them into database.
nsMap := map[database.Namespace]struct{}{}
for _, vuln := range vulnerabilities {
nsMap[vuln.Namespace] = struct{}{}
}
namespaces := make([]database.Namespace, 0, len(nsMap))
for ns := range nsMap {
namespaces = append(namespaces, ns)
}
if err := database.PersistNamespacesAndCommit(datastore, namespaces); err != nil { if err := database.PersistNamespacesAndCommit(datastore, namespaces); err != nil {
log.WithError(err).Error("Unable to insert namespaces") log.WithError(err).Error("Unable to insert namespaces")
return return err
} }
changes, err := updateVulnerabilities(datastore, vulnerabilities) changes, err := updateVulnerabilities(ctx, datastore, vulnerabilities)
defer func() { defer func() {
if err != nil { if err != nil {
@ -223,21 +219,21 @@ func update(datastore database.Datastore, firstUpdate bool) {
if err != nil { if err != nil {
log.WithError(err).Error("Unable to update vulnerabilities") log.WithError(err).Error("Unable to update vulnerabilities")
return return err
} }
if !firstUpdate { if !firstUpdate {
err = createVulnerabilityNotifications(datastore, changes) err = createVulnerabilityNotifications(datastore, changes)
if err != nil { if err != nil {
log.WithError(err).Error("Unable to create notifications") log.WithError(err).Error("Unable to create notifications")
return return err
} }
} }
err = updateUpdaterFlags(datastore, flags) err = updateUpdaterFlags(datastore, flags)
if err != nil { if err != nil {
log.WithError(err).Error("Unable to update updater flags") log.WithError(err).Error("Unable to update updater flags")
return return err
} }
for _, note := range notes { for _, note := range notes {
@ -249,17 +245,87 @@ func update(datastore database.Datastore, firstUpdate bool) {
err = setLastUpdateTime(datastore) err = setLastUpdateTime(datastore)
if err != nil { if err != nil {
log.WithError(err).Error("Unable to set last update time") log.WithError(err).Error("Unable to set last update time")
return return err
} }
} }
log.Info("update finished") log.Info("update finished")
return nil
}
func deduplicate(vulns []database.VulnerabilityWithAffected) ([]database.Namespace, []database.VulnerabilityWithAffected) {
// do vulnerability namespacing again to merge potentially duplicated
// vulnerabilities from each updater.
vulnerabilities := doVulnerabilitiesNamespacing(vulns)
nsMap := map[database.Namespace]struct{}{}
for _, vuln := range vulnerabilities {
nsMap[vuln.Namespace] = struct{}{}
}
namespaces := make([]database.Namespace, 0, len(nsMap))
for ns := range nsMap {
namespaces = append(namespaces, ns)
}
return namespaces, vulnerabilities
} }
func setUpdaterDuration(start time.Time) { func setUpdaterDuration(start time.Time) {
promUpdaterDurationSeconds.Set(time.Since(start).Seconds()) promUpdaterDurationSeconds.Set(time.Since(start).Seconds())
} }
// fetchUpdates asynchronously runs all of the enabled Updaters, aggregates
// their results, and appends metadata to the vulnerabilities found.
func fetchUpdates(ctx context.Context, datastore database.Datastore) (success bool, vulns []database.VulnerabilityWithAffected, flags map[string]string, notes []string) {
flags = make(map[string]string)
log.Info("fetching vulnerability updates")
var mu sync.RWMutex
g, ctx := errgroup.WithContext(ctx)
for updaterName, updater := range vulnsrc.Updaters() {
// Shadow the loop variables to avoid closing over the wrong thing.
// See: https://golang.org/doc/faq#closures_and_goroutines
updaterName := updaterName
updater := updater
g.Go(func() error {
if !updaterEnabled(updaterName) {
return nil
}
// TODO(jzelinskie): add context to Update()
response, err := updater.Update(datastore)
if err != nil {
promUpdaterErrorsTotal.Inc()
log.WithError(err).WithField("updater", updaterName).Error("an error occurred when fetching an update")
return err
}
namespacedVulns := doVulnerabilitiesNamespacing(response.Vulnerabilities)
mu.Lock()
vulns = append(vulns, namespacedVulns...)
notes = append(notes, response.Notes...)
if response.FlagName != "" && response.FlagValue != "" {
flags[response.FlagName] = response.FlagValue
}
mu.Unlock()
return nil
})
}
if err := g.Wait(); err == nil {
success = true
}
vulns = addMetadata(ctx, datastore, vulns)
return
}
// fetch get data from the registered fetchers, in parallel. // fetch get data from the registered fetchers, in parallel.
func fetch(datastore database.Datastore) (bool, []database.VulnerabilityWithAffected, map[string]string, []string) { func fetch(datastore database.Datastore) (bool, []database.VulnerabilityWithAffected, map[string]string, []string) {
var vulnerabilities []database.VulnerabilityWithAffected var vulnerabilities []database.VulnerabilityWithAffected
@ -304,12 +370,12 @@ func fetch(datastore database.Datastore) (bool, []database.VulnerabilityWithAffe
} }
close(responseC) close(responseC)
return status, addMetadata(datastore, vulnerabilities), flags, notes return status, addMetadata(context.TODO(), datastore, vulnerabilities), flags, notes
} }
// Add metadata to the specified vulnerabilities using the registered // addMetadata asynchronously updates a list of vulnerabilities with metadata
// MetadataFetchers, in parallel. // from the vulnerability metadata sources.
func addMetadata(datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) []database.VulnerabilityWithAffected { func addMetadata(ctx context.Context, datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) []database.VulnerabilityWithAffected {
if len(vulnmdsrc.Appenders()) == 0 || len(vulnerabilities) == 0 { if len(vulnmdsrc.Appenders()) == 0 || len(vulnerabilities) == 0 {
return vulnerabilities return vulnerabilities
} }
@ -325,31 +391,39 @@ func addMetadata(datastore database.Datastore, vulnerabilities []database.Vulner
}) })
} }
var wg sync.WaitGroup g, ctx := errgroup.WithContext(ctx)
wg.Add(len(vulnmdsrc.Appenders())) for name, metadataAppender := range vulnmdsrc.Appenders() {
// Shadow the loop variables to avoid closing over the wrong thing.
// See: https://golang.org/doc/faq#closures_and_goroutines
name := name
metadataAppender := metadataAppender
for n, a := range vulnmdsrc.Appenders() { g.Go(func() error {
go func(name string, appender vulnmdsrc.Appender) { // TODO(jzelinskie): add ctx to BuildCache()
defer wg.Done() if err := metadataAppender.BuildCache(datastore); err != nil {
// Build up a metadata cache.
if err := appender.BuildCache(datastore); err != nil {
promUpdaterErrorsTotal.Inc() promUpdaterErrorsTotal.Inc()
log.WithError(err).WithField("appender name", name).Error("an error occurred when loading metadata fetcher") log.WithError(err).WithField("appender", name).Error("an error occurred when fetching vulnerability metadata")
return return err
}
defer metadataAppender.PurgeCache()
for i, vulnerability := range lockableVulnerabilities {
metadataAppender.Append(vulnerability.Name, vulnerability.appendFunc)
if i%10 == 0 {
select {
case <-ctx.Done():
return nil
default:
}
}
} }
// Append vulnerability metadata to each vulnerability. return nil
for _, vulnerability := range lockableVulnerabilities { })
appender.Append(vulnerability.Name, vulnerability.appendFunc)
} }
// Purge the metadata cache. g.Wait()
appender.PurgeCache()
}(n, a)
}
wg.Wait()
return vulnerabilities return vulnerabilities
} }
@ -465,15 +539,6 @@ func doVulnerabilitiesNamespacing(vulnerabilities []database.VulnerabilityWithAf
return response return response
} }
func findLock(datastore database.Datastore, updaterLockName string) (string, time.Time, bool, error) {
tx, err := datastore.Begin()
if err != nil {
log.WithError(err).Error()
}
defer tx.Rollback()
return tx.FindLock(updaterLockName)
}
// updateUpdaterFlags updates the flags specified by updaters, every transaction // updateUpdaterFlags updates the flags specified by updaters, every transaction
// is independent of each other. // is independent of each other.
func updateUpdaterFlags(datastore database.Datastore, flags map[string]string) error { func updateUpdaterFlags(datastore database.Datastore, flags map[string]string) error {
@ -619,7 +684,7 @@ func createVulnerabilityNotifications(datastore database.Datastore, changes []vu
// updateVulnerabilities upserts unique vulnerabilities into the database and // updateVulnerabilities upserts unique vulnerabilities into the database and
// computes vulnerability changes. // computes vulnerability changes.
func updateVulnerabilities(datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) ([]vulnerabilityChange, error) { func updateVulnerabilities(ctx context.Context, datastore database.Datastore, vulnerabilities []database.VulnerabilityWithAffected) ([]vulnerabilityChange, error) {
log.WithField("count", len(vulnerabilities)).Debug("updating vulnerabilities") log.WithField("count", len(vulnerabilities)).Debug("updating vulnerabilities")
if len(vulnerabilities) == 0 { if len(vulnerabilities) == 0 {
return nil, nil return nil, nil
@ -637,13 +702,19 @@ func updateVulnerabilities(datastore database.Datastore, vulnerabilities []datab
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer tx.Rollback() defer tx.Rollback()
oldVulnNullable, err := tx.FindVulnerabilities(ids) oldVulnNullable, err := tx.FindVulnerabilities(ids)
if err != nil { if err != nil {
return nil, err return nil, err
} }
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
oldVuln := []database.VulnerabilityWithAffected{} oldVuln := []database.VulnerabilityWithAffected{}
for _, vuln := range oldVulnNullable { for _, vuln := range oldVulnNullable {
if vuln.Valid { if vuln.Valid {
@ -656,6 +727,12 @@ func updateVulnerabilities(datastore database.Datastore, vulnerabilities []datab
return nil, err return nil, err
} }
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
toRemove := []database.VulnerabilityID{} toRemove := []database.VulnerabilityID{}
toAdd := []database.VulnerabilityWithAffected{} toAdd := []database.VulnerabilityWithAffected{}
for _, change := range changes { for _, change := range changes {

View File

@ -1,4 +1,4 @@
// Copyright 2017 clair authors // Copyright 2019 clair authors
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
@ -15,13 +15,13 @@
package clair package clair
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"testing" "testing"
"github.com/stretchr/testify/assert"
"github.com/coreos/clair/database" "github.com/coreos/clair/database"
"github.com/stretchr/testify/assert"
) )
type mockUpdaterDatastore struct { type mockUpdaterDatastore struct {
@ -270,27 +270,27 @@ func TestCreatVulnerabilityNotification(t *testing.T) {
} }
datastore := newmockUpdaterDatastore() datastore := newmockUpdaterDatastore()
change, err := updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{}) change, err := updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, change, 0) assert.Len(t, change, 0)
change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v1}) change, err = updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{v1})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, change, 1) assert.Len(t, change, 1)
assert.Nil(t, change[0].old) assert.Nil(t, change[0].old)
assertVulnerability(t, *change[0].new, v1) assertVulnerability(t, *change[0].new, v1)
change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v1}) change, err = updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{v1})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, change, 0) assert.Len(t, change, 0)
change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v2}) change, err = updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{v2})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, change, 1) assert.Len(t, change, 1)
assertVulnerability(t, *change[0].new, v2) assertVulnerability(t, *change[0].new, v2)
assertVulnerability(t, *change[0].old, v1) assertVulnerability(t, *change[0].old, v1)
change, err = updateVulnerabilities(datastore, []database.VulnerabilityWithAffected{v3}) change, err = updateVulnerabilities(context.TODO(), datastore, []database.VulnerabilityWithAffected{v3})
assert.Nil(t, err) assert.Nil(t, err)
assert.Len(t, change, 1) assert.Len(t, change, 1)
assertVulnerability(t, *change[0].new, v3) assertVulnerability(t, *change[0].new, v3)

1
vendor/github.com/google/uuid/go.mod generated vendored Normal file
View File

@ -0,0 +1 @@
module github.com/google/uuid

View File

@ -1,4 +1,4 @@
// Copyright 2016 Google Inc. All rights reserved. // Copyright 2018 Google Inc. All rights reserved.
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
@ -35,20 +35,43 @@ const (
var rander = rand.Reader // random function var rander = rand.Reader // random function
// Parse decodes s into a UUID or returns an error. Both the UUID form of // Parse decodes s into a UUID or returns an error. Both the standard UUID
// xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and // forms of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx are decoded. // urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx are decoded as well as the
// Microsoft encoding {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx} and the raw hex
// encoding: xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx.
func Parse(s string) (UUID, error) { func Parse(s string) (UUID, error) {
var uuid UUID var uuid UUID
if len(s) != 36 { switch len(s) {
if len(s) != 36+9 { // xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
return uuid, fmt.Errorf("invalid UUID length: %d", len(s)) case 36:
}
// urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
case 36 + 9:
if strings.ToLower(s[:9]) != "urn:uuid:" { if strings.ToLower(s[:9]) != "urn:uuid:" {
return uuid, fmt.Errorf("invalid urn prefix: %q", s[:9]) return uuid, fmt.Errorf("invalid urn prefix: %q", s[:9])
} }
s = s[9:] s = s[9:]
// {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}
case 36 + 2:
s = s[1:]
// xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
case 32:
var ok bool
for i := range uuid {
uuid[i], ok = xtob(s[i*2], s[i*2+1])
if !ok {
return uuid, errors.New("invalid UUID format")
} }
}
return uuid, nil
default:
return uuid, fmt.Errorf("invalid UUID length: %d", len(s))
}
// s is now at least 36 bytes long
// it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' { if s[8] != '-' || s[13] != '-' || s[18] != '-' || s[23] != '-' {
return uuid, errors.New("invalid UUID format") return uuid, errors.New("invalid UUID format")
} }
@ -70,15 +93,29 @@ func Parse(s string) (UUID, error) {
// ParseBytes is like Parse, except it parses a byte slice instead of a string. // ParseBytes is like Parse, except it parses a byte slice instead of a string.
func ParseBytes(b []byte) (UUID, error) { func ParseBytes(b []byte) (UUID, error) {
var uuid UUID var uuid UUID
if len(b) != 36 { switch len(b) {
if len(b) != 36+9 { case 36: // xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
return uuid, fmt.Errorf("invalid UUID length: %d", len(b)) case 36 + 9: // urn:uuid:xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
}
if !bytes.Equal(bytes.ToLower(b[:9]), []byte("urn:uuid:")) { if !bytes.Equal(bytes.ToLower(b[:9]), []byte("urn:uuid:")) {
return uuid, fmt.Errorf("invalid urn prefix: %q", b[:9]) return uuid, fmt.Errorf("invalid urn prefix: %q", b[:9])
} }
b = b[9:] b = b[9:]
case 36 + 2: // {xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx}
b = b[1:]
case 32: // xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
var ok bool
for i := 0; i < 32; i += 2 {
uuid[i/2], ok = xtob(b[i], b[i+1])
if !ok {
return uuid, errors.New("invalid UUID format")
} }
}
return uuid, nil
default:
return uuid, fmt.Errorf("invalid UUID length: %d", len(b))
}
// s is now at least 36 bytes long
// it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx
if b[8] != '-' || b[13] != '-' || b[18] != '-' || b[23] != '-' { if b[8] != '-' || b[13] != '-' || b[18] != '-' || b[23] != '-' {
return uuid, errors.New("invalid UUID format") return uuid, errors.New("invalid UUID format")
} }

View File

@ -59,12 +59,22 @@ var tests = []test{
{"f47ac10b-58cc-4372-e567-0e02b2c3d479", 4, Future, true}, {"f47ac10b-58cc-4372-e567-0e02b2c3d479", 4, Future, true},
{"f47ac10b-58cc-4372-f567-0e02b2c3d479", 4, Future, true}, {"f47ac10b-58cc-4372-f567-0e02b2c3d479", 4, Future, true},
{"f47ac10b158cc-5372-a567-0e02b2c3d479", 0, Invalid, false}, {"f47ac10b158cc-5372-a567-0e02b2c3d479", 0, Invalid, false},
{"f47ac10b-58cc25372-a567-0e02b2c3d479", 0, Invalid, false}, {"f47ac10b-58cc25372-a567-0e02b2c3d479", 0, Invalid, false},
{"f47ac10b-58cc-53723a567-0e02b2c3d479", 0, Invalid, false}, {"f47ac10b-58cc-53723a567-0e02b2c3d479", 0, Invalid, false},
{"f47ac10b-58cc-5372-a56740e02b2c3d479", 0, Invalid, false}, {"f47ac10b-58cc-5372-a56740e02b2c3d479", 0, Invalid, false},
{"f47ac10b-58cc-5372-a567-0e02-2c3d479", 0, Invalid, false}, {"f47ac10b-58cc-5372-a567-0e02-2c3d479", 0, Invalid, false},
{"g47ac10b-58cc-4372-a567-0e02b2c3d479", 0, Invalid, false}, {"g47ac10b-58cc-4372-a567-0e02b2c3d479", 0, Invalid, false},
{"{f47ac10b-58cc-0372-8567-0e02b2c3d479}", 0, RFC4122, true},
{"{f47ac10b-58cc-0372-8567-0e02b2c3d479", 0, Invalid, false},
{"f47ac10b-58cc-0372-8567-0e02b2c3d479}", 0, Invalid, false},
{"f47ac10b58cc037285670e02b2c3d479", 0, RFC4122, true},
{"f47ac10b58cc037285670e02b2c3d4790", 0, Invalid, false},
{"f47ac10b58cc037285670e02b2c3d47", 0, Invalid, false},
} }
var constants = []struct { var constants = []struct {

View File

@ -1,3 +1,23 @@
## 0.9.2 / 2018-12-06
* [FEATURE] Support for Go modules. #501
* [FEATURE] `Timer.ObserveDuration` returns observed duration. #509
* [ENHANCEMENT] Improved doc comments and error messages. #504
* [BUGFIX] Fix race condition during metrics gathering. #512
* [BUGFIX] Fix testutil metric comparison for Histograms and empty labels. #494
#498
## 0.9.1 / 2018-11-03
* [FEATURE] Add `WriteToTextfile` function to facilitate the creation of
*.prom files for the textfile collector of the node exporter. #489
* [ENHANCEMENT] More descriptive error messages for inconsistent label
cardinality. #487
* [ENHANCEMENT] Exposition: Use a GZIP encoder pool to avoid allocations in
high-frequency scrape scenarios. #366
* [ENHANCEMENT] Exposition: Streaming serving of metrics data while encoding.
#482
* [ENHANCEMENT] API client: Add a way to return the body of a 5xx response.
#479
## 0.9.0 / 2018-10-15 ## 0.9.0 / 2018-10-15
* [CHANGE] Go1.6 is no longer supported. * [CHANGE] Go1.6 is no longer supported.
* [CHANGE] More refinements of the `Registry` consistency checks: Duplicated * [CHANGE] More refinements of the `Registry` consistency checks: Duplicated

View File

@ -28,18 +28,53 @@ unexport GOBIN
GO ?= go GO ?= go
GOFMT ?= $(GO)fmt GOFMT ?= $(GO)fmt
FIRST_GOPATH := $(firstword $(subst :, ,$(shell $(GO) env GOPATH))) FIRST_GOPATH := $(firstword $(subst :, ,$(shell $(GO) env GOPATH)))
GOOPTS ?=
GO_VERSION ?= $(shell $(GO) version)
GO_VERSION_NUMBER ?= $(word 3, $(GO_VERSION))
PRE_GO_111 ?= $(shell echo $(GO_VERSION_NUMBER) | grep -E 'go1\.(10|[0-9])\.')
unexport GOVENDOR
ifeq (, $(PRE_GO_111))
ifneq (,$(wildcard go.mod))
# Enforce Go modules support just in case the directory is inside GOPATH (and for Travis CI).
GO111MODULE := on
ifneq (,$(wildcard vendor))
# Always use the local vendor/ directory to satisfy the dependencies.
GOOPTS := $(GOOPTS) -mod=vendor
endif
endif
else
ifneq (,$(wildcard go.mod))
ifneq (,$(wildcard vendor))
$(warning This repository requires Go >= 1.11 because of Go modules)
$(warning Some recipes may not work as expected as the current Go runtime is '$(GO_VERSION_NUMBER)')
endif
else
# This repository isn't using Go modules (yet).
GOVENDOR := $(FIRST_GOPATH)/bin/govendor
endif
unexport GO111MODULE
endif
PROMU := $(FIRST_GOPATH)/bin/promu PROMU := $(FIRST_GOPATH)/bin/promu
STATICCHECK := $(FIRST_GOPATH)/bin/staticcheck STATICCHECK := $(FIRST_GOPATH)/bin/staticcheck
GOVENDOR := $(FIRST_GOPATH)/bin/govendor
pkgs = ./... pkgs = ./...
GO_VERSION ?= $(shell $(GO) version)
GO_BUILD_PLATFORM ?= $(subst /,-,$(lastword $(GO_VERSION)))
PROMU_VERSION ?= 0.2.0
PROMU_URL := https://github.com/prometheus/promu/releases/download/v$(PROMU_VERSION)/promu-$(PROMU_VERSION).$(GO_BUILD_PLATFORM).tar.gz
PREFIX ?= $(shell pwd) PREFIX ?= $(shell pwd)
BIN_DIR ?= $(shell pwd) BIN_DIR ?= $(shell pwd)
DOCKER_IMAGE_TAG ?= $(subst /,-,$(shell git rev-parse --abbrev-ref HEAD)) DOCKER_IMAGE_TAG ?= $(subst /,-,$(shell git rev-parse --abbrev-ref HEAD))
DOCKER_REPO ?= prom DOCKER_REPO ?= prom
.PHONY: all .PHONY: all
all: style staticcheck unused build test all: precheck style staticcheck unused build test
# This rule is used to forward a target like "build" to "common-build". This # This rule is used to forward a target like "build" to "common-build". This
# allows a new "build" target to be defined in a Makefile which includes this # allows a new "build" target to be defined in a Makefile which includes this
@ -70,37 +105,54 @@ common-check_license:
.PHONY: common-test-short .PHONY: common-test-short
common-test-short: common-test-short:
@echo ">> running short tests" @echo ">> running short tests"
$(GO) test -short $(pkgs) GO111MODULE=$(GO111MODULE) $(GO) test -short $(GOOPTS) $(pkgs)
.PHONY: common-test .PHONY: common-test
common-test: common-test:
@echo ">> running all tests" @echo ">> running all tests"
$(GO) test -race $(pkgs) GO111MODULE=$(GO111MODULE) $(GO) test -race $(GOOPTS) $(pkgs)
.PHONY: common-format .PHONY: common-format
common-format: common-format:
@echo ">> formatting code" @echo ">> formatting code"
$(GO) fmt $(pkgs) GO111MODULE=$(GO111MODULE) $(GO) fmt $(GOOPTS) $(pkgs)
.PHONY: common-vet .PHONY: common-vet
common-vet: common-vet:
@echo ">> vetting code" @echo ">> vetting code"
$(GO) vet $(pkgs) GO111MODULE=$(GO111MODULE) $(GO) vet $(GOOPTS) $(pkgs)
.PHONY: common-staticcheck .PHONY: common-staticcheck
common-staticcheck: $(STATICCHECK) common-staticcheck: $(STATICCHECK)
@echo ">> running staticcheck" @echo ">> running staticcheck"
ifdef GO111MODULE
GO111MODULE=$(GO111MODULE) $(STATICCHECK) -ignore "$(STATICCHECK_IGNORE)" -checks "SA*" $(pkgs)
else
$(STATICCHECK) -ignore "$(STATICCHECK_IGNORE)" $(pkgs) $(STATICCHECK) -ignore "$(STATICCHECK_IGNORE)" $(pkgs)
endif
.PHONY: common-unused .PHONY: common-unused
common-unused: $(GOVENDOR) common-unused: $(GOVENDOR)
ifdef GOVENDOR
@echo ">> running check for unused packages" @echo ">> running check for unused packages"
@$(GOVENDOR) list +unused | grep . && exit 1 || echo 'No unused packages' @$(GOVENDOR) list +unused | grep . && exit 1 || echo 'No unused packages'
else
ifdef GO111MODULE
@echo ">> running check for unused/missing packages in go.mod"
GO111MODULE=$(GO111MODULE) $(GO) mod tidy
@git diff --exit-code -- go.sum go.mod
ifneq (,$(wildcard vendor))
@echo ">> running check for unused packages in vendor/"
GO111MODULE=$(GO111MODULE) $(GO) mod vendor
@git diff --exit-code -- go.sum go.mod vendor/
endif
endif
endif
.PHONY: common-build .PHONY: common-build
common-build: promu common-build: promu
@echo ">> building binaries" @echo ">> building binaries"
$(PROMU) build --prefix $(PREFIX) GO111MODULE=$(GO111MODULE) $(PROMU) build --prefix $(PREFIX)
.PHONY: common-tarball .PHONY: common-tarball
common-tarball: promu common-tarball: promu
@ -120,13 +172,52 @@ common-docker-tag-latest:
docker tag "$(DOCKER_REPO)/$(DOCKER_IMAGE_NAME):$(DOCKER_IMAGE_TAG)" "$(DOCKER_REPO)/$(DOCKER_IMAGE_NAME):latest" docker tag "$(DOCKER_REPO)/$(DOCKER_IMAGE_NAME):$(DOCKER_IMAGE_TAG)" "$(DOCKER_REPO)/$(DOCKER_IMAGE_NAME):latest"
.PHONY: promu .PHONY: promu
promu: promu: $(PROMU)
GOOS= GOARCH= $(GO) get -u github.com/prometheus/promu
$(PROMU):
curl -s -L $(PROMU_URL) | tar -xvz -C /tmp
mkdir -v -p $(FIRST_GOPATH)/bin
cp -v /tmp/promu-$(PROMU_VERSION).$(GO_BUILD_PLATFORM)/promu $(PROMU)
.PHONY: proto
proto:
@echo ">> generating code from proto files"
@./scripts/genproto.sh
.PHONY: $(STATICCHECK) .PHONY: $(STATICCHECK)
$(STATICCHECK): $(STATICCHECK):
GOOS= GOARCH= $(GO) get -u honnef.co/go/tools/cmd/staticcheck ifdef GO111MODULE
# Get staticcheck from a temporary directory to avoid modifying the local go.{mod,sum}.
# See https://github.com/golang/go/issues/27643.
# For now, we are using the next branch of staticcheck because master isn't compatible yet with Go modules.
tmpModule=$$(mktemp -d 2>&1) && \
mkdir -p $${tmpModule}/staticcheck && \
cd "$${tmpModule}"/staticcheck && \
GO111MODULE=on $(GO) mod init example.com/staticcheck && \
GO111MODULE=on GOOS= GOARCH= $(GO) get -u honnef.co/go/tools/cmd/staticcheck@next && \
rm -rf $${tmpModule};
else
GOOS= GOARCH= GO111MODULE=off $(GO) get -u honnef.co/go/tools/cmd/staticcheck
endif
ifdef GOVENDOR
.PHONY: $(GOVENDOR) .PHONY: $(GOVENDOR)
$(GOVENDOR): $(GOVENDOR):
GOOS= GOARCH= $(GO) get -u github.com/kardianos/govendor GOOS= GOARCH= $(GO) get -u github.com/kardianos/govendor
endif
.PHONY: precheck
precheck::
define PRECHECK_COMMAND_template =
precheck:: $(1)_precheck
PRECHECK_COMMAND_$(1) ?= $(1) $$(strip $$(PRECHECK_OPTIONS_$(1)))
.PHONY: $(1)_precheck
$(1)_precheck:
@if ! $$(PRECHECK_COMMAND_$(1)) 1>/dev/null 2>&1; then \
echo "Execution of '$$(PRECHECK_COMMAND_$(1))' command failed. Is $(1) installed?"; \
exit 1; \
fi
endef

View File

@ -1 +1 @@
0.9.0 0.9.2

View File

@ -56,10 +56,12 @@ type HealthStatus string
const ( const (
// Possible values for ErrorType. // Possible values for ErrorType.
ErrBadData ErrorType = "bad_data" ErrBadData ErrorType = "bad_data"
ErrTimeout = "timeout" ErrTimeout ErrorType = "timeout"
ErrCanceled = "canceled" ErrCanceled ErrorType = "canceled"
ErrExec = "execution" ErrExec ErrorType = "execution"
ErrBadResponse = "bad_response" ErrBadResponse ErrorType = "bad_response"
ErrServer ErrorType = "server_error"
ErrClient ErrorType = "client_error"
// Possible values for HealthStatus. // Possible values for HealthStatus.
HealthGood HealthStatus = "up" HealthGood HealthStatus = "up"
@ -71,6 +73,7 @@ const (
type Error struct { type Error struct {
Type ErrorType Type ErrorType
Msg string Msg string
Detail string
} }
func (e *Error) Error() string { func (e *Error) Error() string {
@ -460,6 +463,16 @@ func apiError(code int) bool {
return code == statusAPIError || code == http.StatusBadRequest return code == statusAPIError || code == http.StatusBadRequest
} }
func errorTypeAndMsgFor(resp *http.Response) (ErrorType, string) {
switch resp.StatusCode / 100 {
case 4:
return ErrClient, fmt.Sprintf("client error: %d", resp.StatusCode)
case 5:
return ErrServer, fmt.Sprintf("server error: %d", resp.StatusCode)
}
return ErrBadResponse, fmt.Sprintf("bad response code %d", resp.StatusCode)
}
func (c apiClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, error) { func (c apiClient) Do(ctx context.Context, req *http.Request) (*http.Response, []byte, error) {
resp, body, err := c.Client.Do(ctx, req) resp, body, err := c.Client.Do(ctx, req)
if err != nil { if err != nil {
@ -469,9 +482,11 @@ func (c apiClient) Do(ctx context.Context, req *http.Request) (*http.Response, [
code := resp.StatusCode code := resp.StatusCode
if code/100 != 2 && !apiError(code) { if code/100 != 2 && !apiError(code) {
errorType, errorMsg := errorTypeAndMsgFor(resp)
return resp, body, &Error{ return resp, body, &Error{
Type: ErrBadResponse, Type: errorType,
Msg: fmt.Sprintf("bad response code %d", resp.StatusCode), Msg: errorMsg,
Detail: string(body),
} }
} }

View File

@ -18,6 +18,7 @@ package v1
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
@ -32,6 +33,7 @@ import (
type apiTest struct { type apiTest struct {
do func() (interface{}, error) do func() (interface{}, error)
inErr error inErr error
inStatusCode int
inRes interface{} inRes interface{}
reqPath string reqPath string
@ -75,7 +77,9 @@ func (c *apiTestClient) Do(ctx context.Context, req *http.Request) (*http.Respon
} }
resp := &http.Response{} resp := &http.Response{}
if test.inErr != nil { if test.inStatusCode != 0 {
resp.StatusCode = test.inStatusCode
} else if test.inErr != nil {
resp.StatusCode = statusAPIError resp.StatusCode = statusAPIError
} else { } else {
resp.StatusCode = http.StatusOK resp.StatusCode = http.StatusOK
@ -194,6 +198,42 @@ func TestAPIs(t *testing.T) {
}, },
err: fmt.Errorf("some error"), err: fmt.Errorf("some error"),
}, },
{
do: doQuery("2", testTime),
inRes: "some body",
inStatusCode: 500,
inErr: &Error{
Type: ErrServer,
Msg: "server error: 500",
Detail: "some body",
},
reqMethod: "GET",
reqPath: "/api/v1/query",
reqParam: url.Values{
"query": []string{"2"},
"time": []string{testTime.Format(time.RFC3339Nano)},
},
err: errors.New("server_error: server error: 500"),
},
{
do: doQuery("2", testTime),
inRes: "some body",
inStatusCode: 404,
inErr: &Error{
Type: ErrClient,
Msg: "client error: 404",
Detail: "some body",
},
reqMethod: "GET",
reqPath: "/api/v1/query",
reqParam: url.Values{
"query": []string{"2"},
"time": []string{testTime.Format(time.RFC3339Nano)},
},
err: errors.New("client_error: client error: 404"),
},
{ {
do: doQueryRange("2", Range{ do: doQueryRange("2", Range{
@ -498,29 +538,34 @@ func TestAPIs(t *testing.T) {
var tests []apiTest var tests []apiTest
tests = append(tests, queryTests...) tests = append(tests, queryTests...)
for _, test := range tests { for i, test := range tests {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
client.curTest = test client.curTest = test
res, err := test.do() res, err := test.do()
if test.err != nil { if test.err != nil {
if err == nil { if err == nil {
t.Errorf("expected error %q but got none", test.err) t.Fatalf("expected error %q but got none", test.err)
continue
} }
if err.Error() != test.err.Error() { if err.Error() != test.err.Error() {
t.Errorf("unexpected error: want %s, got %s", test.err, err) t.Errorf("unexpected error: want %s, got %s", test.err, err)
} }
continue if apiErr, ok := err.(*Error); ok {
if apiErr.Detail != test.inRes {
t.Errorf("%q should be %q", apiErr.Detail, test.inRes)
}
}
return
} }
if err != nil { if err != nil {
t.Errorf("unexpected error: %s", err) t.Fatalf("unexpected error: %s", err)
continue
} }
if !reflect.DeepEqual(res, test.res) { if !reflect.DeepEqual(res, test.res) {
t.Errorf("unexpected result: want %v, got %v", test.res, res) t.Errorf("unexpected result: want %v, got %v", test.res, res)
} }
})
} }
} }
@ -534,8 +579,8 @@ type testClient struct {
type apiClientTest struct { type apiClientTest struct {
code int code int
response interface{} response interface{}
expected string expectedBody string
err *Error expectedErr *Error
} }
func (c *testClient) URL(ep string, args map[string]string) *url.URL { func (c *testClient) URL(ep string, args map[string]string) *url.URL {
@ -575,98 +620,108 @@ func (c *testClient) Do(ctx context.Context, req *http.Request) (*http.Response,
func TestAPIClientDo(t *testing.T) { func TestAPIClientDo(t *testing.T) {
tests := []apiClientTest{ tests := []apiClientTest{
{ {
code: statusAPIError,
response: &apiResponse{ response: &apiResponse{
Status: "error", Status: "error",
Data: json.RawMessage(`null`), Data: json.RawMessage(`null`),
ErrorType: ErrBadData, ErrorType: ErrBadData,
Error: "failed", Error: "failed",
}, },
err: &Error{ expectedErr: &Error{
Type: ErrBadData, Type: ErrBadData,
Msg: "failed", Msg: "failed",
}, },
code: statusAPIError, expectedBody: `null`,
expected: `null`,
}, },
{ {
code: statusAPIError,
response: &apiResponse{ response: &apiResponse{
Status: "error", Status: "error",
Data: json.RawMessage(`"test"`), Data: json.RawMessage(`"test"`),
ErrorType: ErrTimeout, ErrorType: ErrTimeout,
Error: "timed out", Error: "timed out",
}, },
err: &Error{ expectedErr: &Error{
Type: ErrTimeout, Type: ErrTimeout,
Msg: "timed out", Msg: "timed out",
}, },
code: statusAPIError, expectedBody: `test`,
expected: `test`,
}, },
{ {
response: "bad json",
err: &Error{
Type: ErrBadResponse,
Msg: "bad response code 500",
},
code: http.StatusInternalServerError, code: http.StatusInternalServerError,
response: "500 error details",
expectedErr: &Error{
Type: ErrServer,
Msg: "server error: 500",
Detail: "500 error details",
},
}, },
{ {
code: http.StatusNotFound,
response: "404 error details",
expectedErr: &Error{
Type: ErrClient,
Msg: "client error: 404",
Detail: "404 error details",
},
},
{
code: http.StatusBadRequest,
response: &apiResponse{ response: &apiResponse{
Status: "error", Status: "error",
Data: json.RawMessage(`null`), Data: json.RawMessage(`null`),
ErrorType: ErrBadData, ErrorType: ErrBadData,
Error: "end timestamp must not be before start time", Error: "end timestamp must not be before start time",
}, },
err: &Error{ expectedErr: &Error{
Type: ErrBadData, Type: ErrBadData,
Msg: "end timestamp must not be before start time", Msg: "end timestamp must not be before start time",
}, },
code: http.StatusBadRequest,
}, },
{ {
code: statusAPIError,
response: "bad json", response: "bad json",
err: &Error{ expectedErr: &Error{
Type: ErrBadResponse, Type: ErrBadResponse,
Msg: "invalid character 'b' looking for beginning of value", Msg: "invalid character 'b' looking for beginning of value",
}, },
code: statusAPIError,
}, },
{ {
code: statusAPIError,
response: &apiResponse{ response: &apiResponse{
Status: "success", Status: "success",
Data: json.RawMessage(`"test"`), Data: json.RawMessage(`"test"`),
}, },
err: &Error{ expectedErr: &Error{
Type: ErrBadResponse, Type: ErrBadResponse,
Msg: "inconsistent body for response code", Msg: "inconsistent body for response code",
}, },
code: statusAPIError,
}, },
{ {
code: statusAPIError,
response: &apiResponse{ response: &apiResponse{
Status: "success", Status: "success",
Data: json.RawMessage(`"test"`), Data: json.RawMessage(`"test"`),
ErrorType: ErrTimeout, ErrorType: ErrTimeout,
Error: "timed out", Error: "timed out",
}, },
err: &Error{ expectedErr: &Error{
Type: ErrBadResponse, Type: ErrBadResponse,
Msg: "inconsistent body for response code", Msg: "inconsistent body for response code",
}, },
code: statusAPIError,
}, },
{ {
code: http.StatusOK,
response: &apiResponse{ response: &apiResponse{
Status: "error", Status: "error",
Data: json.RawMessage(`"test"`), Data: json.RawMessage(`"test"`),
ErrorType: ErrTimeout, ErrorType: ErrTimeout,
Error: "timed out", Error: "timed out",
}, },
err: &Error{ expectedErr: &Error{
Type: ErrBadResponse, Type: ErrBadResponse,
Msg: "inconsistent body for response code", Msg: "inconsistent body for response code",
}, },
code: http.StatusOK,
}, },
} }
@ -677,30 +732,37 @@ func TestAPIClientDo(t *testing.T) {
} }
client := &apiClient{tc} client := &apiClient{tc}
for _, test := range tests { for i, test := range tests {
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
tc.ch <- test tc.ch <- test
_, body, err := client.Do(context.Background(), tc.req) _, body, err := client.Do(context.Background(), tc.req)
if test.err != nil { if test.expectedErr != nil {
if err == nil { if err == nil {
t.Errorf("expected error %q but got none", test.err) t.Fatalf("expected error %q but got none", test.expectedErr)
continue
} }
if test.err.Error() != err.Error() { if test.expectedErr.Error() != err.Error() {
t.Errorf("unexpected error: want %q, got %q", test.err, err) t.Errorf("unexpected error: want %q, got %q", test.expectedErr, err)
} }
continue if test.expectedErr.Detail != "" {
apiErr := err.(*Error)
if apiErr.Detail != test.expectedErr.Detail {
t.Errorf("unexpected error details: want %q, got %q", test.expectedErr.Detail, apiErr.Detail)
}
}
return
} }
if err != nil { if err != nil {
t.Errorf("unexpeceted error %s", err) t.Fatalf("unexpeceted error %s", err)
continue
} }
want, got := test.expected, string(body) want, got := test.expectedBody, string(body)
if want != got { if want != got {
t.Errorf("unexpected body: want %q, got %q", want, got) t.Errorf("unexpected body: want %q, got %q", want, got)
} }
})
} }
} }

12
vendor/github.com/prometheus/client_golang/go.mod generated vendored Normal file
View File

@ -0,0 +1,12 @@
module github.com/prometheus/client_golang
require (
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973
github.com/golang/protobuf v1.2.0
github.com/matttproud/golang_protobuf_extensions v1.0.1 // indirect
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910
github.com/prometheus/common v0.0.0-20181126121408-4724e9255275
github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a
golang.org/x/net v0.0.0-20181201002055-351d144fa1fc
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f // indirect
)

16
vendor/github.com/prometheus/client_golang/go.sum generated vendored Normal file
View File

@ -0,0 +1,16 @@
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973 h1:xJ4a3vCFaGF/jqvzLMYoU8P317H5OQ+Via4RmuPwCS0=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/golang/protobuf v1.2.0 h1:P3YflyNX/ehuJFLhxviNdFxQPkGK5cDcApsge1SqnvM=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/matttproud/golang_protobuf_extensions v1.0.1 h1:4hp9jkHxhMHkqkrB3Ix0jegS5sx/RkqARlsWZ6pIwiU=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910 h1:idejC8f05m9MGOsuEi1ATq9shN03HrxNkD/luQvxCv8=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/common v0.0.0-20181126121408-4724e9255275 h1:PnBWHBf+6L0jOqq0gIVUe6Yk0/QMZ640k6NvkxcBf+8=
github.com/prometheus/common v0.0.0-20181126121408-4724e9255275/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a h1:9a8MnZMP0X2nLJdBg+pBmGgkJlSaKC2KaQmTCk1XDtE=
github.com/prometheus/procfs v0.0.0-20181204211112-1dc9a6cbc91a/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
golang.org/x/net v0.0.0-20181201002055-351d144fa1fc h1:a3CU5tJYVj92DY2LaA1kUkrsqD5/3mLDhx2NcNqyW+0=
golang.org/x/net v0.0.0-20181201002055-351d144fa1fc/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f h1:Bl/8QSvNqXvPGPGXa2z5xUTmV7VDcZyvRZ+QQXkXTZQ=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@ -136,7 +136,7 @@ func NewCounterVec(opts CounterOpts, labelNames []string) *CounterVec {
return &CounterVec{ return &CounterVec{
metricVec: newMetricVec(desc, func(lvs ...string) Metric { metricVec: newMetricVec(desc, func(lvs ...string) Metric {
if len(lvs) != len(desc.variableLabels) { if len(lvs) != len(desc.variableLabels) {
panic(errInconsistentCardinality) panic(makeInconsistentCardinalityError(desc.fqName, desc.variableLabels, lvs))
} }
result := &counter{desc: desc, labelPairs: makeLabelPairs(desc, lvs)} result := &counter{desc: desc, labelPairs: makeLabelPairs(desc, lvs)}
result.init(result) // Init self-collection. result.init(result) // Init self-collection.

View File

@ -93,7 +93,7 @@ func NewDesc(fqName, help string, variableLabels []string, constLabels Labels) *
// First add only the const label names and sort them... // First add only the const label names and sort them...
for labelName := range constLabels { for labelName := range constLabels {
if !checkLabelName(labelName) { if !checkLabelName(labelName) {
d.err = fmt.Errorf("%q is not a valid label name", labelName) d.err = fmt.Errorf("%q is not a valid label name for metric %q", labelName, fqName)
return d return d
} }
labelNames = append(labelNames, labelName) labelNames = append(labelNames, labelName)
@ -115,7 +115,7 @@ func NewDesc(fqName, help string, variableLabels []string, constLabels Labels) *
// dimension with a different mix between preset and variable labels. // dimension with a different mix between preset and variable labels.
for _, labelName := range variableLabels { for _, labelName := range variableLabels {
if !checkLabelName(labelName) { if !checkLabelName(labelName) {
d.err = fmt.Errorf("%q is not a valid label name", labelName) d.err = fmt.Errorf("%q is not a valid label name for metric %q", labelName, fqName)
return d return d
} }
labelNames = append(labelNames, "$"+labelName) labelNames = append(labelNames, "$"+labelName)

View File

@ -13,7 +13,13 @@
package prometheus_test package prometheus_test
import "github.com/prometheus/client_golang/prometheus" import (
"log"
"net/http"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
)
// ClusterManager is an example for a system that might have been built without // ClusterManager is an example for a system that might have been built without
// Prometheus in mind. It models a central manager of jobs running in a // Prometheus in mind. It models a central manager of jobs running in a
@ -124,4 +130,13 @@ func ExampleCollector() {
// variables to then do something with them. // variables to then do something with them.
NewClusterManager("db", reg) NewClusterManager("db", reg)
NewClusterManager("ca", reg) NewClusterManager("ca", reg)
// Add the standard process and Go metrics to the custom registry.
reg.MustRegister(
prometheus.NewProcessCollector(prometheus.ProcessCollectorOpts{}),
prometheus.NewGoCollector(),
)
http.Handle("/metrics", promhttp.HandlerFor(reg, promhttp.HandlerOpts{}))
log.Fatal(http.ListenAndServe(":8080", nil))
} }

View File

@ -282,7 +282,7 @@ func ExampleRegister() {
// taskCounter unregistered. // taskCounter unregistered.
// taskCounterVec not registered: a previously registered descriptor with the same fully-qualified name as Desc{fqName: "worker_pool_completed_tasks_total", help: "Total number of tasks completed.", constLabels: {}, variableLabels: [worker_id]} has different label names or a different help string // taskCounterVec not registered: a previously registered descriptor with the same fully-qualified name as Desc{fqName: "worker_pool_completed_tasks_total", help: "Total number of tasks completed.", constLabels: {}, variableLabels: [worker_id]} has different label names or a different help string
// taskCounterVec registered. // taskCounterVec registered.
// Worker initialization failed: inconsistent label cardinality // Worker initialization failed: inconsistent label cardinality: expected 1 label values but got 2 in []string{"42", "spurious arg"}
// notMyCounter is nil. // notMyCounter is nil.
// taskCounterForWorker42 registered. // taskCounterForWorker42 registered.
// taskCounterForWorker2001 registered. // taskCounterForWorker2001 registered.

View File

@ -147,7 +147,7 @@ func NewGaugeVec(opts GaugeOpts, labelNames []string) *GaugeVec {
return &GaugeVec{ return &GaugeVec{
metricVec: newMetricVec(desc, func(lvs ...string) Metric { metricVec: newMetricVec(desc, func(lvs ...string) Metric {
if len(lvs) != len(desc.variableLabels) { if len(lvs) != len(desc.variableLabels) {
panic(errInconsistentCardinality) panic(makeInconsistentCardinalityError(desc.fqName, desc.variableLabels, lvs))
} }
result := &gauge{desc: desc, labelPairs: makeLabelPairs(desc, lvs)} result := &gauge{desc: desc, labelPairs: makeLabelPairs(desc, lvs)}
result.init(result) // Init self-collection. result.init(result) // Init self-collection.

View File

@ -165,7 +165,7 @@ func NewHistogram(opts HistogramOpts) Histogram {
func newHistogram(desc *Desc, opts HistogramOpts, labelValues ...string) Histogram { func newHistogram(desc *Desc, opts HistogramOpts, labelValues ...string) Histogram {
if len(desc.variableLabels) != len(labelValues) { if len(desc.variableLabels) != len(labelValues) {
panic(errInconsistentCardinality) panic(makeInconsistentCardinalityError(desc.fqName, desc.variableLabels, labelValues))
} }
for _, n := range desc.variableLabels { for _, n := range desc.variableLabels {

View File

@ -15,9 +15,7 @@ package prometheus
import ( import (
"bufio" "bufio"
"bytes"
"compress/gzip" "compress/gzip"
"fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
@ -41,19 +39,10 @@ const (
acceptEncodingHeader = "Accept-Encoding" acceptEncodingHeader = "Accept-Encoding"
) )
var bufPool sync.Pool var gzipPool = sync.Pool{
New: func() interface{} {
func getBuf() *bytes.Buffer { return gzip.NewWriter(nil)
buf := bufPool.Get() },
if buf == nil {
return &bytes.Buffer{}
}
return buf.(*bytes.Buffer)
}
func giveBuf(buf *bytes.Buffer) {
buf.Reset()
bufPool.Put(buf)
} }
// Handler returns an HTTP handler for the DefaultGatherer. It is // Handler returns an HTTP handler for the DefaultGatherer. It is
@ -71,58 +60,40 @@ func Handler() http.Handler {
// Deprecated: Use promhttp.HandlerFor(DefaultGatherer, promhttp.HandlerOpts{}) // Deprecated: Use promhttp.HandlerFor(DefaultGatherer, promhttp.HandlerOpts{})
// instead. See there for further documentation. // instead. See there for further documentation.
func UninstrumentedHandler() http.Handler { func UninstrumentedHandler() http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { return http.HandlerFunc(func(rsp http.ResponseWriter, req *http.Request) {
mfs, err := DefaultGatherer.Gather() mfs, err := DefaultGatherer.Gather()
if err != nil { if err != nil {
http.Error(w, "An error has occurred during metrics collection:\n\n"+err.Error(), http.StatusInternalServerError) httpError(rsp, err)
return return
} }
contentType := expfmt.Negotiate(req.Header) contentType := expfmt.Negotiate(req.Header)
buf := getBuf() header := rsp.Header()
defer giveBuf(buf) header.Set(contentTypeHeader, string(contentType))
writer, encoding := decorateWriter(req, buf)
enc := expfmt.NewEncoder(writer, contentType) w := io.Writer(rsp)
var lastErr error if gzipAccepted(req.Header) {
header.Set(contentEncodingHeader, "gzip")
gz := gzipPool.Get().(*gzip.Writer)
defer gzipPool.Put(gz)
gz.Reset(w)
defer gz.Close()
w = gz
}
enc := expfmt.NewEncoder(w, contentType)
for _, mf := range mfs { for _, mf := range mfs {
if err := enc.Encode(mf); err != nil { if err := enc.Encode(mf); err != nil {
lastErr = err httpError(rsp, err)
http.Error(w, "An error has occurred during metrics encoding:\n\n"+err.Error(), http.StatusInternalServerError)
return return
} }
} }
if closer, ok := writer.(io.Closer); ok {
closer.Close()
}
if lastErr != nil && buf.Len() == 0 {
http.Error(w, "No metrics encoded, last error:\n\n"+lastErr.Error(), http.StatusInternalServerError)
return
}
header := w.Header()
header.Set(contentTypeHeader, string(contentType))
header.Set(contentLengthHeader, fmt.Sprint(buf.Len()))
if encoding != "" {
header.Set(contentEncodingHeader, encoding)
}
w.Write(buf.Bytes())
}) })
} }
// decorateWriter wraps a writer to handle gzip compression if requested. It
// returns the decorated writer and the appropriate "Content-Encoding" header
// (which is empty if no compression is enabled).
func decorateWriter(request *http.Request, writer io.Writer) (io.Writer, string) {
header := request.Header.Get(acceptEncodingHeader)
parts := strings.Split(header, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "gzip" || strings.HasPrefix(part, "gzip;") {
return gzip.NewWriter(writer), "gzip"
}
}
return writer, ""
}
var instLabels = []string{"method", "code"} var instLabels = []string{"method", "code"}
type nower interface { type nower interface {
@ -503,3 +474,31 @@ func sanitizeCode(s int) string {
return strconv.Itoa(s) return strconv.Itoa(s)
} }
} }
// gzipAccepted returns whether the client will accept gzip-encoded content.
func gzipAccepted(header http.Header) bool {
a := header.Get(acceptEncodingHeader)
parts := strings.Split(a, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "gzip" || strings.HasPrefix(part, "gzip;") {
return true
}
}
return false
}
// httpError removes any content-encoding header and then calls http.Error with
// the provided error and http.StatusInternalServerErrer. Error contents is
// supposed to be uncompressed plain text. However, same as with a plain
// http.Error, any header settings will be void if the header has already been
// sent. The error message will still be written to the writer, but it will
// probably be of limited use.
func httpError(rsp http.ResponseWriter, err error) {
rsp.Header().Del(contentEncodingHeader)
http.Error(
rsp,
"An error has occurred while serving metrics:\n\n"+err.Error(),
http.StatusInternalServerError,
)
}

View File

@ -37,9 +37,22 @@ const reservedLabelPrefix = "__"
var errInconsistentCardinality = errors.New("inconsistent label cardinality") var errInconsistentCardinality = errors.New("inconsistent label cardinality")
func makeInconsistentCardinalityError(fqName string, labels, labelValues []string) error {
return fmt.Errorf(
"%s: %q has %d variable labels named %q but %d values %q were provided",
errInconsistentCardinality, fqName,
len(labels), labels,
len(labelValues), labelValues,
)
}
func validateValuesInLabels(labels Labels, expectedNumberOfValues int) error { func validateValuesInLabels(labels Labels, expectedNumberOfValues int) error {
if len(labels) != expectedNumberOfValues { if len(labels) != expectedNumberOfValues {
return errInconsistentCardinality return fmt.Errorf(
"%s: expected %d label values but got %d in %#v",
errInconsistentCardinality, expectedNumberOfValues,
len(labels), labels,
)
} }
for name, val := range labels { for name, val := range labels {
@ -53,7 +66,11 @@ func validateValuesInLabels(labels Labels, expectedNumberOfValues int) error {
func validateLabelValues(vals []string, expectedNumberOfValues int) error { func validateLabelValues(vals []string, expectedNumberOfValues int) error {
if len(vals) != expectedNumberOfValues { if len(vals) != expectedNumberOfValues {
return errInconsistentCardinality return fmt.Errorf(
"%s: expected %d label values but got %d in %#v",
errInconsistentCardinality, expectedNumberOfValues,
len(vals), vals,
)
} }
for _, val := range vals { for _, val := range vals {

View File

@ -32,7 +32,6 @@
package promhttp package promhttp
import ( import (
"bytes"
"compress/gzip" "compress/gzip"
"fmt" "fmt"
"io" "io"
@ -53,19 +52,10 @@ const (
acceptEncodingHeader = "Accept-Encoding" acceptEncodingHeader = "Accept-Encoding"
) )
var bufPool sync.Pool var gzipPool = sync.Pool{
New: func() interface{} {
func getBuf() *bytes.Buffer { return gzip.NewWriter(nil)
buf := bufPool.Get() },
if buf == nil {
return &bytes.Buffer{}
}
return buf.(*bytes.Buffer)
}
func giveBuf(buf *bytes.Buffer) {
buf.Reset()
bufPool.Put(buf)
} }
// Handler returns an http.Handler for the prometheus.DefaultGatherer, using // Handler returns an http.Handler for the prometheus.DefaultGatherer, using
@ -100,19 +90,18 @@ func HandlerFor(reg prometheus.Gatherer, opts HandlerOpts) http.Handler {
inFlightSem = make(chan struct{}, opts.MaxRequestsInFlight) inFlightSem = make(chan struct{}, opts.MaxRequestsInFlight)
} }
h := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { h := http.HandlerFunc(func(rsp http.ResponseWriter, req *http.Request) {
if inFlightSem != nil { if inFlightSem != nil {
select { select {
case inFlightSem <- struct{}{}: // All good, carry on. case inFlightSem <- struct{}{}: // All good, carry on.
defer func() { <-inFlightSem }() defer func() { <-inFlightSem }()
default: default:
http.Error(w, fmt.Sprintf( http.Error(rsp, fmt.Sprintf(
"Limit of concurrent requests reached (%d), try again later.", opts.MaxRequestsInFlight, "Limit of concurrent requests reached (%d), try again later.", opts.MaxRequestsInFlight,
), http.StatusServiceUnavailable) ), http.StatusServiceUnavailable)
return return
} }
} }
mfs, err := reg.Gather() mfs, err := reg.Gather()
if err != nil { if err != nil {
if opts.ErrorLog != nil { if opts.ErrorLog != nil {
@ -123,26 +112,40 @@ func HandlerFor(reg prometheus.Gatherer, opts HandlerOpts) http.Handler {
panic(err) panic(err)
case ContinueOnError: case ContinueOnError:
if len(mfs) == 0 { if len(mfs) == 0 {
http.Error(w, "No metrics gathered, last error:\n\n"+err.Error(), http.StatusInternalServerError) // Still report the error if no metrics have been gathered.
httpError(rsp, err)
return return
} }
case HTTPErrorOnError: case HTTPErrorOnError:
http.Error(w, "An error has occurred during metrics gathering:\n\n"+err.Error(), http.StatusInternalServerError) httpError(rsp, err)
return return
} }
} }
contentType := expfmt.Negotiate(req.Header) contentType := expfmt.Negotiate(req.Header)
buf := getBuf() header := rsp.Header()
defer giveBuf(buf) header.Set(contentTypeHeader, string(contentType))
writer, encoding := decorateWriter(req, buf, opts.DisableCompression)
enc := expfmt.NewEncoder(writer, contentType) w := io.Writer(rsp)
if !opts.DisableCompression && gzipAccepted(req.Header) {
header.Set(contentEncodingHeader, "gzip")
gz := gzipPool.Get().(*gzip.Writer)
defer gzipPool.Put(gz)
gz.Reset(w)
defer gz.Close()
w = gz
}
enc := expfmt.NewEncoder(w, contentType)
var lastErr error var lastErr error
for _, mf := range mfs { for _, mf := range mfs {
if err := enc.Encode(mf); err != nil { if err := enc.Encode(mf); err != nil {
lastErr = err lastErr = err
if opts.ErrorLog != nil { if opts.ErrorLog != nil {
opts.ErrorLog.Println("error encoding metric family:", err) opts.ErrorLog.Println("error encoding and sending metric family:", err)
} }
switch opts.ErrorHandling { switch opts.ErrorHandling {
case PanicOnError: case PanicOnError:
@ -150,28 +153,15 @@ func HandlerFor(reg prometheus.Gatherer, opts HandlerOpts) http.Handler {
case ContinueOnError: case ContinueOnError:
// Handled later. // Handled later.
case HTTPErrorOnError: case HTTPErrorOnError:
http.Error(w, "An error has occurred during metrics encoding:\n\n"+err.Error(), http.StatusInternalServerError) httpError(rsp, err)
return return
} }
} }
} }
if closer, ok := writer.(io.Closer); ok {
closer.Close() if lastErr != nil {
httpError(rsp, lastErr)
} }
if lastErr != nil && buf.Len() == 0 {
http.Error(w, "No metrics encoded, last error:\n\n"+lastErr.Error(), http.StatusInternalServerError)
return
}
header := w.Header()
header.Set(contentTypeHeader, string(contentType))
header.Set(contentLengthHeader, fmt.Sprint(buf.Len()))
if encoding != "" {
header.Set(contentEncodingHeader, encoding)
}
if _, err := w.Write(buf.Bytes()); err != nil && opts.ErrorLog != nil {
opts.ErrorLog.Println("error while sending encoded metrics:", err)
}
// TODO(beorn7): Consider streaming serving of metrics.
}) })
if opts.Timeout <= 0 { if opts.Timeout <= 0 {
@ -292,20 +282,30 @@ type HandlerOpts struct {
Timeout time.Duration Timeout time.Duration
} }
// decorateWriter wraps a writer to handle gzip compression if requested. It // gzipAccepted returns whether the client will accept gzip-encoded content.
// returns the decorated writer and the appropriate "Content-Encoding" header func gzipAccepted(header http.Header) bool {
// (which is empty if no compression is enabled). a := header.Get(acceptEncodingHeader)
func decorateWriter(request *http.Request, writer io.Writer, compressionDisabled bool) (io.Writer, string) { parts := strings.Split(a, ",")
if compressionDisabled {
return writer, ""
}
header := request.Header.Get(acceptEncodingHeader)
parts := strings.Split(header, ",")
for _, part := range parts { for _, part := range parts {
part = strings.TrimSpace(part) part = strings.TrimSpace(part)
if part == "gzip" || strings.HasPrefix(part, "gzip;") { if part == "gzip" || strings.HasPrefix(part, "gzip;") {
return gzip.NewWriter(writer), "gzip" return true
} }
} }
return writer, "" return false
}
// httpError removes any content-encoding header and then calls http.Error with
// the provided error and http.StatusInternalServerErrer. Error contents is
// supposed to be uncompressed plain text. However, same as with a plain
// http.Error, any header settings will be void if the header has already been
// sent. The error message will still be written to the writer, but it will
// probably be of limited use.
func httpError(rsp http.ResponseWriter, err error) {
rsp.Header().Del(contentEncodingHeader)
http.Error(
rsp,
"An error has occurred while serving metrics:\n\n"+err.Error(),
http.StatusInternalServerError,
)
} }

View File

@ -103,7 +103,7 @@ func TestHandlerErrorHandling(t *testing.T) {
}) })
wantMsg := `error gathering metrics: error collecting metric Desc{fqName: "invalid_metric", help: "not helpful", constLabels: {}, variableLabels: []}: collect error wantMsg := `error gathering metrics: error collecting metric Desc{fqName: "invalid_metric", help: "not helpful", constLabels: {}, variableLabels: []}: collect error
` `
wantErrorBody := `An error has occurred during metrics gathering: wantErrorBody := `An error has occurred while serving metrics:
error collecting metric Desc{fqName: "invalid_metric", help: "not helpful", constLabels: {}, variableLabels: []}: collect error error collecting metric Desc{fqName: "invalid_metric", help: "not helpful", constLabels: {}, variableLabels: []}: collect error
` `

View File

@ -16,6 +16,9 @@ package prometheus
import ( import (
"bytes" "bytes"
"fmt" "fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime" "runtime"
"sort" "sort"
"strings" "strings"
@ -23,6 +26,7 @@ import (
"unicode/utf8" "unicode/utf8"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/prometheus/common/expfmt"
dto "github.com/prometheus/client_model/go" dto "github.com/prometheus/client_model/go"
@ -533,6 +537,38 @@ func (r *Registry) Gather() ([]*dto.MetricFamily, error) {
return internal.NormalizeMetricFamilies(metricFamiliesByName), errs.MaybeUnwrap() return internal.NormalizeMetricFamilies(metricFamiliesByName), errs.MaybeUnwrap()
} }
// WriteToTextfile calls Gather on the provided Gatherer, encodes the result in the
// Prometheus text format, and writes it to a temporary file. Upon success, the
// temporary file is renamed to the provided filename.
//
// This is intended for use with the textfile collector of the node exporter.
// Note that the node exporter expects the filename to be suffixed with ".prom".
func WriteToTextfile(filename string, g Gatherer) error {
tmp, err := ioutil.TempFile(filepath.Dir(filename), filepath.Base(filename))
if err != nil {
return err
}
defer os.Remove(tmp.Name())
mfs, err := g.Gather()
if err != nil {
return err
}
for _, mf := range mfs {
if _, err := expfmt.MetricFamilyToText(tmp, mf); err != nil {
return err
}
}
if err := tmp.Close(); err != nil {
return err
}
if err := os.Chmod(tmp.Name(), 0644); err != nil {
return err
}
return os.Rename(tmp.Name(), filename)
}
// processMetric is an internal helper method only used by the Gather method. // processMetric is an internal helper method only used by the Gather method.
func processMetric( func processMetric(
metric Metric, metric Metric,
@ -836,7 +872,13 @@ func checkMetricConsistency(
h = hashAddByte(h, separatorByte) h = hashAddByte(h, separatorByte)
// Make sure label pairs are sorted. We depend on it for the consistency // Make sure label pairs are sorted. We depend on it for the consistency
// check. // check.
sort.Sort(labelPairSorter(dtoMetric.Label)) if !sort.IsSorted(labelPairSorter(dtoMetric.Label)) {
// We cannot sort dtoMetric.Label in place as it is immutable by contract.
copiedLabels := make([]*dto.LabelPair, len(dtoMetric.Label))
copy(copiedLabels, dtoMetric.Label)
sort.Sort(labelPairSorter(copiedLabels))
dtoMetric.Label = copiedLabels
}
for _, lp := range dtoMetric.Label { for _, lp := range dtoMetric.Label {
h = hashAdd(h, lp.GetName()) h = hashAdd(h, lp.GetName())
h = hashAddByte(h, separatorByte) h = hashAddByte(h, separatorByte)
@ -867,8 +909,8 @@ func checkDescConsistency(
} }
// Is the desc consistent with the content of the metric? // Is the desc consistent with the content of the metric?
lpsFromDesc := make([]*dto.LabelPair, 0, len(dtoMetric.Label)) lpsFromDesc := make([]*dto.LabelPair, len(desc.constLabelPairs), len(dtoMetric.Label))
lpsFromDesc = append(lpsFromDesc, desc.constLabelPairs...) copy(lpsFromDesc, desc.constLabelPairs)
for _, l := range desc.variableLabels { for _, l := range desc.variableLabels {
lpsFromDesc = append(lpsFromDesc, &dto.LabelPair{ lpsFromDesc = append(lpsFromDesc, &dto.LabelPair{
Name: proto.String(l), Name: proto.String(l),

View File

@ -21,9 +21,12 @@ package prometheus_test
import ( import (
"bytes" "bytes"
"fmt"
"io/ioutil"
"math/rand" "math/rand"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os"
"sync" "sync"
"testing" "testing"
"time" "time"
@ -250,7 +253,7 @@ metric: <
}, },
} }
expectedMetricFamilyInvalidLabelValueAsText := []byte(`An error has occurred during metrics gathering: expectedMetricFamilyInvalidLabelValueAsText := []byte(`An error has occurred while serving metrics:
collected metric "name" { label:<name:"constname" value:"\377" > label:<name:"labelname" value:"different_val" > counter:<value:42 > } has a label named "constname" whose value is not utf8: "\xff" collected metric "name" { label:<name:"constname" value:"\377" > label:<name:"labelname" value:"different_val" > counter:<value:42 > } has a label named "constname" whose value is not utf8: "\xff"
`) `)
@ -299,15 +302,15 @@ complex_bucket 1
}, },
}, },
} }
bucketCollisionMsg := []byte(`An error has occurred during metrics gathering: bucketCollisionMsg := []byte(`An error has occurred while serving metrics:
collected metric named "complex_bucket" collides with previously collected histogram named "complex" collected metric named "complex_bucket" collides with previously collected histogram named "complex"
`) `)
summaryCountCollisionMsg := []byte(`An error has occurred during metrics gathering: summaryCountCollisionMsg := []byte(`An error has occurred while serving metrics:
collected metric named "complex_count" collides with previously collected summary named "complex" collected metric named "complex_count" collides with previously collected summary named "complex"
`) `)
histogramCountCollisionMsg := []byte(`An error has occurred during metrics gathering: histogramCountCollisionMsg := []byte(`An error has occurred while serving metrics:
collected metric named "complex_count" collides with previously collected histogram named "complex" collected metric named "complex_count" collides with previously collected histogram named "complex"
`) `)
@ -333,7 +336,7 @@ collected metric named "complex_count" collides with previously collected histog
}, },
}, },
} }
duplicateLabelMsg := []byte(`An error has occurred during metrics gathering: duplicateLabelMsg := []byte(`An error has occurred while serving metrics:
collected metric "broken_metric" { label:<name:"foo" value:"bar" > label:<name:"foo" value:"baz" > counter:<value:2.7 > } has two or more labels with the same name: foo collected metric "broken_metric" { label:<name:"foo" value:"bar" > label:<name:"foo" value:"baz" > counter:<value:2.7 > } has two or more labels with the same name: foo
`) `)
@ -781,6 +784,11 @@ func TestAlreadyRegistered(t *testing.T) {
// same HistogramVec is registered concurrently and the Gather method of the // same HistogramVec is registered concurrently and the Gather method of the
// registry is called concurrently. // registry is called concurrently.
func TestHistogramVecRegisterGatherConcurrency(t *testing.T) { func TestHistogramVecRegisterGatherConcurrency(t *testing.T) {
labelNames := make([]string, 16) // Need at least 13 to expose #512.
for i := range labelNames {
labelNames[i] = fmt.Sprint("label_", i)
}
var ( var (
reg = prometheus.NewPedanticRegistry() reg = prometheus.NewPedanticRegistry()
hv = prometheus.NewHistogramVec( hv = prometheus.NewHistogramVec(
@ -789,7 +797,7 @@ func TestHistogramVecRegisterGatherConcurrency(t *testing.T) {
Help: "This helps testing.", Help: "This helps testing.",
ConstLabels: prometheus.Labels{"foo": "bar"}, ConstLabels: prometheus.Labels{"foo": "bar"},
}, },
[]string{"one", "two", "three"}, labelNames,
) )
labelValues = []string{"a", "b", "c", "alpha", "beta", "gamma", "aleph", "beth", "gimel"} labelValues = []string{"a", "b", "c", "alpha", "beta", "gamma", "aleph", "beth", "gimel"}
quit = make(chan struct{}) quit = make(chan struct{})
@ -804,11 +812,11 @@ func TestHistogramVecRegisterGatherConcurrency(t *testing.T) {
return return
default: default:
obs := rand.NormFloat64()*.1 + .2 obs := rand.NormFloat64()*.1 + .2
hv.WithLabelValues( values := make([]string, 0, len(labelNames))
labelValues[rand.Intn(len(labelValues))], for range labelNames {
labelValues[rand.Intn(len(labelValues))], values = append(values, labelValues[rand.Intn(len(labelValues))])
labelValues[rand.Intn(len(labelValues))], }
).Observe(obs) hv.WithLabelValues(values...).Observe(obs)
} }
} }
} }
@ -846,7 +854,7 @@ func TestHistogramVecRegisterGatherConcurrency(t *testing.T) {
if len(g) != 1 { if len(g) != 1 {
t.Error("Gathered unexpected number of metric families:", len(g)) t.Error("Gathered unexpected number of metric families:", len(g))
} }
if len(g[0].Metric[0].Label) != 4 { if len(g[0].Metric[0].Label) != len(labelNames)+1 {
t.Error("Gathered unexpected number of label pairs:", len(g[0].Metric[0].Label)) t.Error("Gathered unexpected number of label pairs:", len(g[0].Metric[0].Label))
} }
} }
@ -871,3 +879,102 @@ func TestHistogramVecRegisterGatherConcurrency(t *testing.T) {
close(quit) close(quit)
wg.Wait() wg.Wait()
} }
func TestWriteToTextfile(t *testing.T) {
expectedOut := `# HELP test_counter test counter
# TYPE test_counter counter
test_counter{name="qux"} 1
# HELP test_gauge test gauge
# TYPE test_gauge gauge
test_gauge{name="baz"} 1.1
# HELP test_hist test histogram
# TYPE test_hist histogram
test_hist_bucket{name="bar",le="0.005"} 0
test_hist_bucket{name="bar",le="0.01"} 0
test_hist_bucket{name="bar",le="0.025"} 0
test_hist_bucket{name="bar",le="0.05"} 0
test_hist_bucket{name="bar",le="0.1"} 0
test_hist_bucket{name="bar",le="0.25"} 0
test_hist_bucket{name="bar",le="0.5"} 0
test_hist_bucket{name="bar",le="1"} 1
test_hist_bucket{name="bar",le="2.5"} 1
test_hist_bucket{name="bar",le="5"} 2
test_hist_bucket{name="bar",le="10"} 2
test_hist_bucket{name="bar",le="+Inf"} 2
test_hist_sum{name="bar"} 3.64
test_hist_count{name="bar"} 2
# HELP test_summary test summary
# TYPE test_summary summary
test_summary{name="foo",quantile="0.5"} 10
test_summary{name="foo",quantile="0.9"} 20
test_summary{name="foo",quantile="0.99"} 20
test_summary_sum{name="foo"} 30
test_summary_count{name="foo"} 2
`
registry := prometheus.NewRegistry()
summary := prometheus.NewSummaryVec(
prometheus.SummaryOpts{
Name: "test_summary",
Help: "test summary",
},
[]string{"name"},
)
histogram := prometheus.NewHistogramVec(
prometheus.HistogramOpts{
Name: "test_hist",
Help: "test histogram",
},
[]string{"name"},
)
gauge := prometheus.NewGaugeVec(
prometheus.GaugeOpts{
Name: "test_gauge",
Help: "test gauge",
},
[]string{"name"},
)
counter := prometheus.NewCounterVec(
prometheus.CounterOpts{
Name: "test_counter",
Help: "test counter",
},
[]string{"name"},
)
registry.MustRegister(summary)
registry.MustRegister(histogram)
registry.MustRegister(gauge)
registry.MustRegister(counter)
summary.With(prometheus.Labels{"name": "foo"}).Observe(10)
summary.With(prometheus.Labels{"name": "foo"}).Observe(20)
histogram.With(prometheus.Labels{"name": "bar"}).Observe(0.93)
histogram.With(prometheus.Labels{"name": "bar"}).Observe(2.71)
gauge.With(prometheus.Labels{"name": "baz"}).Set(1.1)
counter.With(prometheus.Labels{"name": "qux"}).Inc()
tmpfile, err := ioutil.TempFile("", "prom_registry_test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpfile.Name())
if err := prometheus.WriteToTextfile(tmpfile.Name(), registry); err != nil {
t.Fatal(err)
}
fileBytes, err := ioutil.ReadFile(tmpfile.Name())
if err != nil {
t.Fatal(err)
}
fileContents := string(fileBytes)
if fileContents != expectedOut {
t.Error("file contents didn't match unexpected")
}
}

View File

@ -181,7 +181,7 @@ func NewSummary(opts SummaryOpts) Summary {
func newSummary(desc *Desc, opts SummaryOpts, labelValues ...string) Summary { func newSummary(desc *Desc, opts SummaryOpts, labelValues ...string) Summary {
if len(desc.variableLabels) != len(labelValues) { if len(desc.variableLabels) != len(labelValues) {
panic(errInconsistentCardinality) panic(makeInconsistentCardinalityError(desc.fqName, desc.variableLabels, labelValues))
} }
for _, n := range desc.variableLabels { for _, n := range desc.variableLabels {

View File

@ -37,7 +37,6 @@ import (
"bytes" "bytes"
"fmt" "fmt"
"io" "io"
"reflect"
"github.com/prometheus/common/expfmt" "github.com/prometheus/common/expfmt"
@ -125,38 +124,43 @@ func CollectAndCompare(c prometheus.Collector, expected io.Reader, metricNames .
// exposition format. If any metricNames are provided, only metrics with those // exposition format. If any metricNames are provided, only metrics with those
// names are compared. // names are compared.
func GatherAndCompare(g prometheus.Gatherer, expected io.Reader, metricNames ...string) error { func GatherAndCompare(g prometheus.Gatherer, expected io.Reader, metricNames ...string) error {
metrics, err := g.Gather() got, err := g.Gather()
if err != nil { if err != nil {
return fmt.Errorf("gathering metrics failed: %s", err) return fmt.Errorf("gathering metrics failed: %s", err)
} }
if metricNames != nil { if metricNames != nil {
metrics = filterMetrics(metrics, metricNames) got = filterMetrics(got, metricNames)
} }
var tp expfmt.TextParser var tp expfmt.TextParser
expectedMetrics, err := tp.TextToMetricFamilies(expected) wantRaw, err := tp.TextToMetricFamilies(expected)
if err != nil { if err != nil {
return fmt.Errorf("parsing expected metrics failed: %s", err) return fmt.Errorf("parsing expected metrics failed: %s", err)
} }
want := internal.NormalizeMetricFamilies(wantRaw)
if !reflect.DeepEqual(metrics, internal.NormalizeMetricFamilies(expectedMetrics)) { return compare(got, want)
// Encode the gathered output to the readable text format for comparison. }
var buf1 bytes.Buffer
enc := expfmt.NewEncoder(&buf1, expfmt.FmtText) // compare encodes both provided slices of metric families into the text format,
for _, mf := range metrics { // compares their string message, and returns an error if they do not match.
// The error contains the encoded text of both the desired and the actual
// result.
func compare(got, want []*dto.MetricFamily) error {
var gotBuf, wantBuf bytes.Buffer
enc := expfmt.NewEncoder(&gotBuf, expfmt.FmtText)
for _, mf := range got {
if err := enc.Encode(mf); err != nil { if err := enc.Encode(mf); err != nil {
return fmt.Errorf("encoding result failed: %s", err) return fmt.Errorf("encoding gathered metrics failed: %s", err)
} }
} }
// Encode normalized expected metrics again to generate them in the same ordering enc = expfmt.NewEncoder(&wantBuf, expfmt.FmtText)
// the registry does to spot differences more easily. for _, mf := range want {
var buf2 bytes.Buffer
enc = expfmt.NewEncoder(&buf2, expfmt.FmtText)
for _, mf := range internal.NormalizeMetricFamilies(expectedMetrics) {
if err := enc.Encode(mf); err != nil { if err := enc.Encode(mf); err != nil {
return fmt.Errorf("encoding result failed: %s", err) return fmt.Errorf("encoding expected metrics failed: %s", err)
} }
} }
if wantBuf.String() != gotBuf.String() {
return fmt.Errorf(` return fmt.Errorf(`
metric output does not match expectation; want: metric output does not match expectation; want:
@ -165,7 +169,8 @@ metric output does not match expectation; want:
got: got:
%s %s
`, buf2.String(), buf1.String()) `, wantBuf.String(), gotBuf.String())
} }
return nil return nil
} }

View File

@ -143,6 +143,104 @@ func TestCollectAndCompare(t *testing.T) {
} }
} }
func TestCollectAndCompareNoLabel(t *testing.T) {
const metadata = `
# HELP some_total A value that represents a counter.
# TYPE some_total counter
`
c := prometheus.NewCounter(prometheus.CounterOpts{
Name: "some_total",
Help: "A value that represents a counter.",
})
c.Inc()
expected := `
some_total 1
`
if err := CollectAndCompare(c, strings.NewReader(metadata+expected), "some_total"); err != nil {
t.Errorf("unexpected collecting result:\n%s", err)
}
}
func TestCollectAndCompareHistogram(t *testing.T) {
inputs := []struct {
name string
c prometheus.Collector
metadata string
expect string
labels []string
observation float64
}{
{
name: "Testing Histogram Collector",
c: prometheus.NewHistogram(prometheus.HistogramOpts{
Name: "some_histogram",
Help: "An example of a histogram",
Buckets: []float64{1, 2, 3},
}),
metadata: `
# HELP some_histogram An example of a histogram
# TYPE some_histogram histogram
`,
expect: `
some_histogram{le="1"} 0
some_histogram{le="2"} 0
some_histogram{le="3"} 1
some_histogram_bucket{le="+Inf"} 1
some_histogram_sum 2.5
some_histogram_count 1
`,
observation: 2.5,
},
{
name: "Testing HistogramVec Collector",
c: prometheus.NewHistogramVec(prometheus.HistogramOpts{
Name: "some_histogram",
Help: "An example of a histogram",
Buckets: []float64{1, 2, 3},
}, []string{"test"}),
metadata: `
# HELP some_histogram An example of a histogram
# TYPE some_histogram histogram
`,
expect: `
some_histogram_bucket{test="test",le="1"} 0
some_histogram_bucket{test="test",le="2"} 0
some_histogram_bucket{test="test",le="3"} 1
some_histogram_bucket{test="test",le="+Inf"} 1
some_histogram_sum{test="test"} 2.5
some_histogram_count{test="test"} 1
`,
observation: 2.5,
},
}
for _, input := range inputs {
switch collector := input.c.(type) {
case prometheus.Histogram:
collector.Observe(input.observation)
case *prometheus.HistogramVec:
collector.WithLabelValues("test").Observe(input.observation)
default:
t.Fatalf("unsuported collector tested")
}
t.Run(input.name, func(t *testing.T) {
if err := CollectAndCompare(input.c, strings.NewReader(input.metadata+input.expect)); err != nil {
t.Errorf("unexpected collecting result:\n%s", err)
}
})
}
}
func TestNoMetricFilter(t *testing.T) { func TestNoMetricFilter(t *testing.T) {
const metadata = ` const metadata = `
# HELP some_total A value that represents a counter. # HELP some_total A value that represents a counter.

View File

@ -39,13 +39,16 @@ func NewTimer(o Observer) *Timer {
// ObserveDuration records the duration passed since the Timer was created with // ObserveDuration records the duration passed since the Timer was created with
// NewTimer. It calls the Observe method of the Observer provided during // NewTimer. It calls the Observe method of the Observer provided during
// construction with the duration in seconds as an argument. ObserveDuration is // construction with the duration in seconds as an argument. The observed
// usually called with a defer statement. // duration is also returned. ObserveDuration is usually called with a defer
// statement.
// //
// Note that this method is only guaranteed to never observe negative durations // Note that this method is only guaranteed to never observe negative durations
// if used with Go1.9+. // if used with Go1.9+.
func (t *Timer) ObserveDuration() { func (t *Timer) ObserveDuration() time.Duration {
d := time.Since(t.begin)
if t.observer != nil { if t.observer != nil {
t.observer.Observe(time.Since(t.begin).Seconds()) t.observer.Observe(d.Seconds())
} }
return d
} }

View File

@ -19,7 +19,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/gogo/protobuf/proto" "github.com/golang/protobuf/proto"
dto "github.com/prometheus/client_model/go" dto "github.com/prometheus/client_model/go"
) )

View File

@ -2,13 +2,17 @@ language: go
sudo: false sudo: false
go: matrix:
- "1.8" include:
- "1.9" - go: "1.8.x"
- "1.10" - go: "1.9.x"
- tip - go: "1.10.x"
- go: "1.11.x"
script: env: GO111MODULE=off
- go: "1.11.x"
env: GO111MODULE=on
- go: tip
script:
- ./.travis.gogenerate.sh - ./.travis.gogenerate.sh
- ./.travis.gofmt.sh - ./.travis.gofmt.sh
- ./.travis.govet.sh - ./.travis.govet.sh

View File

@ -1,22 +1,21 @@
Copyright (c) 2012 - 2013 Mat Ryer and Tyler Bunnell MIT License
Please consider promoting this project if you find it useful. Copyright (c) 2012-2018 Mat Ryer and Tyler Bunnell
Permission is hereby granted, free of charge, to any person Permission is hereby granted, free of charge, to any person obtaining a copy
obtaining a copy of this software and associated documentation of this software and associated documentation files (the "Software"), to deal
files (the "Software"), to deal in the Software without restriction, in the Software without restriction, including without limitation the rights
including without limitation the rights to use, copy, modify, merge, to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
publish, distribute, sublicense, and/or sell copies of the Software, copies of the Software, and to permit persons to whom the Software is
and to permit persons to whom the Software is furnished to do so, furnished to do so, subject to the following conditions:
subject to the following conditions:
The above copyright notice and this permission notice shall be included The above copyright notice and this permission notice shall be included in all
in all copies or substantial portions of the Software. copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. SOFTWARE.

View File

@ -287,8 +287,10 @@ To install Testify, use `go get`:
This will then make the following packages available to you: This will then make the following packages available to you:
github.com/stretchr/testify/assert github.com/stretchr/testify/assert
github.com/stretchr/testify/require
github.com/stretchr/testify/mock github.com/stretchr/testify/mock
github.com/stretchr/testify/http github.com/stretchr/testify/suite
github.com/stretchr/testify/http (deprecated)
Import the `testify/assert` package into your code using this template: Import the `testify/assert` package into your code using this template:
@ -319,7 +321,7 @@ To update Testify to the latest version, use `go get -u github.com/stretchr/test
Supported go versions Supported go versions
================== ==================
We support the three major Go versions, which are 1.8, 1.9 and 1.10 at the moment. We support the three major Go versions, which are 1.9, 1.10, and 1.11 at the moment.
------ ------
@ -329,3 +331,10 @@ Contributing
Please feel free to submit issues, fork the repository and send pull requests! Please feel free to submit issues, fork the repository and send pull requests!
When submitting an issue, we ask that you please include a complete test function that demonstrates the issue. Extra credit for those using Testify to write the test code that demonstrates it. When submitting an issue, we ask that you please include a complete test function that demonstrates the issue. Extra credit for those using Testify to write the test code that demonstrates it.
------
License
=======
This project is licensed under the terms of the MIT license.

View File

@ -39,7 +39,7 @@ type ValueAssertionFunc func(TestingT, interface{}, ...interface{}) bool
// for table driven tests. // for table driven tests.
type BoolAssertionFunc func(TestingT, bool, ...interface{}) bool type BoolAssertionFunc func(TestingT, bool, ...interface{}) bool
// ValuesAssertionFunc is a common function prototype when validating an error value. Can be useful // ErrorAssertionFunc is a common function prototype when validating an error value. Can be useful
// for table driven tests. // for table driven tests.
type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool type ErrorAssertionFunc func(TestingT, error, ...interface{}) bool
@ -179,7 +179,11 @@ func messageFromMsgAndArgs(msgAndArgs ...interface{}) string {
return "" return ""
} }
if len(msgAndArgs) == 1 { if len(msgAndArgs) == 1 {
return msgAndArgs[0].(string) msg := msgAndArgs[0]
if msgAsStr, ok := msg.(string); ok {
return msgAsStr
}
return fmt.Sprintf("%+v", msg)
} }
if len(msgAndArgs) > 1 { if len(msgAndArgs) > 1 {
return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...) return fmt.Sprintf(msgAndArgs[0].(string), msgAndArgs[1:]...)
@ -415,6 +419,17 @@ func NotNil(t TestingT, object interface{}, msgAndArgs ...interface{}) bool {
return Fail(t, "Expected value not to be nil.", msgAndArgs...) return Fail(t, "Expected value not to be nil.", msgAndArgs...)
} }
// containsKind checks if a specified kind in the slice of kinds.
func containsKind(kinds []reflect.Kind, kind reflect.Kind) bool {
for i := 0; i < len(kinds); i++ {
if kind == kinds[i] {
return true
}
}
return false
}
// isNil checks if a specified object is nil or not, without Failing. // isNil checks if a specified object is nil or not, without Failing.
func isNil(object interface{}) bool { func isNil(object interface{}) bool {
if object == nil { if object == nil {
@ -423,7 +438,14 @@ func isNil(object interface{}) bool {
value := reflect.ValueOf(object) value := reflect.ValueOf(object)
kind := value.Kind() kind := value.Kind()
if kind >= reflect.Chan && kind <= reflect.Slice && value.IsNil() { isNilableKind := containsKind(
[]reflect.Kind{
reflect.Chan, reflect.Func,
reflect.Interface, reflect.Map,
reflect.Ptr, reflect.Slice},
kind)
if isNilableKind && value.IsNil() {
return true return true
} }
@ -1327,7 +1349,7 @@ func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) {
} }
// diff returns a diff of both values as long as both are of the same type and // diff returns a diff of both values as long as both are of the same type and
// are a struct, map, slice or array. Otherwise it returns an empty string. // are a struct, map, slice, array or string. Otherwise it returns an empty string.
func diff(expected interface{}, actual interface{}) string { func diff(expected interface{}, actual interface{}) string {
if expected == nil || actual == nil { if expected == nil || actual == nil {
return "" return ""
@ -1345,7 +1367,7 @@ func diff(expected interface{}, actual interface{}) string {
} }
var e, a string var e, a string
if ek != reflect.String { if et != reflect.TypeOf("") {
e = spewConfig.Sdump(expected) e = spewConfig.Sdump(expected)
a = spewConfig.Sdump(actual) a = spewConfig.Sdump(actual)
} else { } else {

View File

@ -175,6 +175,8 @@ func TestIsType(t *testing.T) {
} }
type myType string
func TestEqual(t *testing.T) { func TestEqual(t *testing.T) {
mockT := new(testing.T) mockT := new(testing.T)
@ -200,6 +202,9 @@ func TestEqual(t *testing.T) {
if !Equal(mockT, uint64(123), uint64(123)) { if !Equal(mockT, uint64(123), uint64(123)) {
t.Error("Equal should return true") t.Error("Equal should return true")
} }
if !Equal(mockT, myType("1"), myType("1")) {
t.Error("Equal should return true")
}
if !Equal(mockT, &struct{}{}, &struct{}{}) { if !Equal(mockT, &struct{}{}, &struct{}{}) {
t.Error("Equal should return true (pointer equality is based on equality of underlying value)") t.Error("Equal should return true (pointer equality is based on equality of underlying value)")
} }
@ -207,6 +212,9 @@ func TestEqual(t *testing.T) {
if Equal(mockT, m["bar"], "something") { if Equal(mockT, m["bar"], "something") {
t.Error("Equal should return false") t.Error("Equal should return false")
} }
if Equal(mockT, myType("1"), myType("2")) {
t.Error("Equal should return false")
}
} }
// bufferT implements TestingT. Its implementation of Errorf writes the output that would be produced by // bufferT implements TestingT. Its implementation of Errorf writes the output that would be produced by
@ -275,6 +283,8 @@ func TestEqualFormatting(t *testing.T) {
}{ }{
{equalWant: "want", equalGot: "got", want: "\tassertions.go:\\d+: \n\t+Error Trace:\t\n\t+Error:\\s+Not equal:\\s+\n\\s+expected: \"want\"\n\\s+actual\\s+: \"got\"\n\\s+Diff:\n\\s+-+ Expected\n\\s+\\++ Actual\n\\s+@@ -1 \\+1 @@\n\\s+-want\n\\s+\\+got\n"}, {equalWant: "want", equalGot: "got", want: "\tassertions.go:\\d+: \n\t+Error Trace:\t\n\t+Error:\\s+Not equal:\\s+\n\\s+expected: \"want\"\n\\s+actual\\s+: \"got\"\n\\s+Diff:\n\\s+-+ Expected\n\\s+\\++ Actual\n\\s+@@ -1 \\+1 @@\n\\s+-want\n\\s+\\+got\n"},
{equalWant: "want", equalGot: "got", msgAndArgs: []interface{}{"hello, %v!", "world"}, want: "\tassertions.go:[0-9]+: \n\t+Error Trace:\t\n\t+Error:\\s+Not equal:\\s+\n\\s+expected: \"want\"\n\\s+actual\\s+: \"got\"\n\\s+Diff:\n\\s+-+ Expected\n\\s+\\++ Actual\n\\s+@@ -1 \\+1 @@\n\\s+-want\n\\s+\\+got\n\\s+Messages:\\s+hello, world!\n"}, {equalWant: "want", equalGot: "got", msgAndArgs: []interface{}{"hello, %v!", "world"}, want: "\tassertions.go:[0-9]+: \n\t+Error Trace:\t\n\t+Error:\\s+Not equal:\\s+\n\\s+expected: \"want\"\n\\s+actual\\s+: \"got\"\n\\s+Diff:\n\\s+-+ Expected\n\\s+\\++ Actual\n\\s+@@ -1 \\+1 @@\n\\s+-want\n\\s+\\+got\n\\s+Messages:\\s+hello, world!\n"},
{equalWant: "want", equalGot: "got", msgAndArgs: []interface{}{123}, want: "\tassertions.go:[0-9]+: \n\t+Error Trace:\t\n\t+Error:\\s+Not equal:\\s+\n\\s+expected: \"want\"\n\\s+actual\\s+: \"got\"\n\\s+Diff:\n\\s+-+ Expected\n\\s+\\++ Actual\n\\s+@@ -1 \\+1 @@\n\\s+-want\n\\s+\\+got\n\\s+Messages:\\s+123\n"},
{equalWant: "want", equalGot: "got", msgAndArgs: []interface{}{struct{ a string }{"hello"}}, want: "\tassertions.go:[0-9]+: \n\t+Error Trace:\t\n\t+Error:\\s+Not equal:\\s+\n\\s+expected: \"want\"\n\\s+actual\\s+: \"got\"\n\\s+Diff:\n\\s+-+ Expected\n\\s+\\++ Actual\n\\s+@@ -1 \\+1 @@\n\\s+-want\n\\s+\\+got\n\\s+Messages:\\s+{a:hello}\n"},
} { } {
mockT := &bufferT{} mockT := &bufferT{}
Equal(mockT, currCase.equalWant, currCase.equalGot, currCase.msgAndArgs...) Equal(mockT, currCase.equalWant, currCase.equalGot, currCase.msgAndArgs...)

7
vendor/github.com/stretchr/testify/go.mod generated vendored Normal file
View File

@ -0,0 +1,7 @@
module github.com/stretchr/testify
require (
github.com/davecgh/go-spew v1.1.0
github.com/pmezard/go-difflib v1.0.0
github.com/stretchr/objx v0.1.0
)

6
vendor/github.com/stretchr/testify/go.sum generated vendored Normal file
View File

@ -0,0 +1,6 @@
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=

View File

@ -176,6 +176,7 @@ func (c *Call) Maybe() *Call {
// Mock. // Mock.
// On("MyMethod", 1).Return(nil). // On("MyMethod", 1).Return(nil).
// On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error")) // On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error"))
//go:noinline
func (c *Call) On(methodName string, arguments ...interface{}) *Call { func (c *Call) On(methodName string, arguments ...interface{}) *Call {
return c.Parent.On(methodName, arguments...) return c.Parent.On(methodName, arguments...)
} }
@ -691,7 +692,7 @@ func (args Arguments) Diff(objects []interface{}) (string, int) {
output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher) output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher)
} else { } else {
differences++ differences++
output = fmt.Sprintf("%s\t%d: PASS: %s not matched by %s\n", output, i, actualFmt, matcher) output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher)
} }
} else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() { } else if reflect.TypeOf(expected) == reflect.TypeOf((*AnythingOfTypeArgument)(nil)).Elem() {

View File

@ -32,6 +32,7 @@ func (i *TestExampleImplementation) TheExampleMethod(a, b, c int) (int, error) {
return args.Int(0), errors.New("Whoops") return args.Int(0), errors.New("Whoops")
} }
//go:noinline
func (i *TestExampleImplementation) TheExampleMethod2(yesorno bool) { func (i *TestExampleImplementation) TheExampleMethod2(yesorno bool) {
i.Called(yesorno) i.Called(yesorno)
} }
@ -1492,6 +1493,7 @@ func unexpectedCallRegex(method, calledArg, expectedArg, diff string) string {
rMethod, calledArg, rMethod, expectedArg, diff) rMethod, calledArg, rMethod, expectedArg, diff)
} }
//go:noinline
func ConcurrencyTestMethod(m *Mock) { func ConcurrencyTestMethod(m *Mock) {
m.Called() m.Called()
} }

View File

@ -22,7 +22,7 @@ type ValueAssertionFunc func(TestingT, interface{}, ...interface{})
// for table driven tests. // for table driven tests.
type BoolAssertionFunc func(TestingT, bool, ...interface{}) type BoolAssertionFunc func(TestingT, bool, ...interface{})
// ValuesAssertionFunc is a common function prototype when validating an error value. Can be useful // ErrorAssertionFunc is a common function prototype when validating an error value. Can be useful
// for table driven tests. // for table driven tests.
type ErrorAssertionFunc func(TestingT, error, ...interface{}) type ErrorAssertionFunc func(TestingT, error, ...interface{})

View File

@ -55,10 +55,32 @@ func (suite *Suite) Assert() *assert.Assertions {
return suite.Assertions return suite.Assertions
} }
func failOnPanic(t *testing.T) {
r := recover()
if r != nil {
t.Errorf("test panicked: %v", r)
t.FailNow()
}
}
// Run provides suite functionality around golang subtests. It should be
// called in place of t.Run(name, func(t *testing.T)) in test suite code.
// The passed-in func will be executed as a subtest with a fresh instance of t.
// Provides compatibility with go test pkg -run TestSuite/TestName/SubTestName.
func (suite *Suite) Run(name string, subtest func()) bool {
oldT := suite.T()
defer suite.SetT(oldT)
return oldT.Run(name, func(t *testing.T) {
suite.SetT(t)
subtest()
})
}
// Run takes a testing suite and runs all of the tests attached // Run takes a testing suite and runs all of the tests attached
// to it. // to it.
func Run(t *testing.T, suite TestingSuite) { func Run(t *testing.T, suite TestingSuite) {
suite.SetT(t) suite.SetT(t)
defer failOnPanic(t)
if setupAllSuite, ok := suite.(SetupAllSuite); ok { if setupAllSuite, ok := suite.(SetupAllSuite); ok {
setupAllSuite.SetupSuite() setupAllSuite.SetupSuite()
@ -84,6 +106,8 @@ func Run(t *testing.T, suite TestingSuite) {
F: func(t *testing.T) { F: func(t *testing.T) {
parentT := suite.T() parentT := suite.T()
suite.SetT(t) suite.SetT(t)
defer failOnPanic(t)
if setupTestSuite, ok := suite.(SetupTestSuite); ok { if setupTestSuite, ok := suite.(SetupTestSuite); ok {
setupTestSuite.SetupTest() setupTestSuite.SetupTest()
} }

View File

@ -42,6 +42,99 @@ func (s *SuiteRequireTwice) TestRequireTwo() {
r.Equal(1, 2) r.Equal(1, 2)
} }
type panickingSuite struct {
Suite
panicInSetupSuite bool
panicInSetupTest bool
panicInBeforeTest bool
panicInTest bool
panicInAfterTest bool
panicInTearDownTest bool
panicInTearDownSuite bool
}
func (s *panickingSuite) SetupSuite() {
if s.panicInSetupSuite {
panic("oops in setup suite")
}
}
func (s *panickingSuite) SetupTest() {
if s.panicInSetupTest {
panic("oops in setup test")
}
}
func (s *panickingSuite) BeforeTest(_, _ string) {
if s.panicInBeforeTest {
panic("oops in before test")
}
}
func (s *panickingSuite) Test() {
if s.panicInTest {
panic("oops in test")
}
}
func (s *panickingSuite) AfterTest(_, _ string) {
if s.panicInAfterTest {
panic("oops in after test")
}
}
func (s *panickingSuite) TearDownTest() {
if s.panicInTearDownTest {
panic("oops in tear down test")
}
}
func (s *panickingSuite) TearDownSuite() {
if s.panicInTearDownSuite {
panic("oops in tear down suite")
}
}
func TestSuiteRecoverPanic(t *testing.T) {
ok := true
panickingTests := []testing.InternalTest{
{
Name: "TestPanicInSetupSuite",
F: func(t *testing.T) { Run(t, &panickingSuite{panicInSetupSuite: true}) },
},
{
Name: "TestPanicInSetupTest",
F: func(t *testing.T) { Run(t, &panickingSuite{panicInSetupTest: true}) },
},
{
Name: "TestPanicInBeforeTest",
F: func(t *testing.T) { Run(t, &panickingSuite{panicInBeforeTest: true}) },
},
{
Name: "TestPanicInTest",
F: func(t *testing.T) { Run(t, &panickingSuite{panicInTest: true}) },
},
{
Name: "TestPanicInAfterTest",
F: func(t *testing.T) { Run(t, &panickingSuite{panicInAfterTest: true}) },
},
{
Name: "TestPanicInTearDownTest",
F: func(t *testing.T) { Run(t, &panickingSuite{panicInTearDownTest: true}) },
},
{
Name: "TestPanicInTearDownSuite",
F: func(t *testing.T) { Run(t, &panickingSuite{panicInTearDownSuite: true}) },
},
}
require.NotPanics(t, func() {
ok = testing.RunTests(allTestsFilter, panickingTests)
})
assert.False(t, ok)
}
// This suite is intended to store values to make sure that only // This suite is intended to store values to make sure that only
// testing-suite-related methods are run. It's also a fully // testing-suite-related methods are run. It's also a fully
// functional example of a testing suite, using setup/teardown methods // functional example of a testing suite, using setup/teardown methods
@ -59,6 +152,7 @@ type SuiteTester struct {
TearDownTestRunCount int TearDownTestRunCount int
TestOneRunCount int TestOneRunCount int
TestTwoRunCount int TestTwoRunCount int
TestSubtestRunCount int
NonTestMethodRunCount int NonTestMethodRunCount int
SuiteNameBefore []string SuiteNameBefore []string
@ -153,6 +247,27 @@ func (suite *SuiteTester) NonTestMethod() {
suite.NonTestMethodRunCount++ suite.NonTestMethodRunCount++
} }
func (suite *SuiteTester) TestSubtest() {
suite.TestSubtestRunCount++
for _, t := range []struct {
testName string
}{
{"first"},
{"second"},
} {
suiteT := suite.T()
suite.Run(t.testName, func() {
// We should get a different *testing.T for subtests, so that
// go test recognizes them as proper subtests for output formatting
// and running individual subtests
subTestT := suite.T()
suite.NotEqual(subTestT, suiteT)
})
suite.Equal(suiteT, suite.T())
}
}
// TestRunSuite will be run by the 'go test' command, so within it, we // TestRunSuite will be run by the 'go test' command, so within it, we
// can run our suite using the Run(*testing.T, TestingSuite) function. // can run our suite using the Run(*testing.T, TestingSuite) function.
func TestRunSuite(t *testing.T) { func TestRunSuite(t *testing.T) {
@ -168,18 +283,20 @@ func TestRunSuite(t *testing.T) {
assert.Equal(t, suiteTester.SetupSuiteRunCount, 1) assert.Equal(t, suiteTester.SetupSuiteRunCount, 1)
assert.Equal(t, suiteTester.TearDownSuiteRunCount, 1) assert.Equal(t, suiteTester.TearDownSuiteRunCount, 1)
assert.Equal(t, len(suiteTester.SuiteNameAfter), 3) assert.Equal(t, len(suiteTester.SuiteNameAfter), 4)
assert.Equal(t, len(suiteTester.SuiteNameBefore), 3) assert.Equal(t, len(suiteTester.SuiteNameBefore), 4)
assert.Equal(t, len(suiteTester.TestNameAfter), 3) assert.Equal(t, len(suiteTester.TestNameAfter), 4)
assert.Equal(t, len(suiteTester.TestNameBefore), 3) assert.Equal(t, len(suiteTester.TestNameBefore), 4)
assert.Contains(t, suiteTester.TestNameAfter, "TestOne") assert.Contains(t, suiteTester.TestNameAfter, "TestOne")
assert.Contains(t, suiteTester.TestNameAfter, "TestTwo") assert.Contains(t, suiteTester.TestNameAfter, "TestTwo")
assert.Contains(t, suiteTester.TestNameAfter, "TestSkip") assert.Contains(t, suiteTester.TestNameAfter, "TestSkip")
assert.Contains(t, suiteTester.TestNameAfter, "TestSubtest")
assert.Contains(t, suiteTester.TestNameBefore, "TestOne") assert.Contains(t, suiteTester.TestNameBefore, "TestOne")
assert.Contains(t, suiteTester.TestNameBefore, "TestTwo") assert.Contains(t, suiteTester.TestNameBefore, "TestTwo")
assert.Contains(t, suiteTester.TestNameBefore, "TestSkip") assert.Contains(t, suiteTester.TestNameBefore, "TestSkip")
assert.Contains(t, suiteTester.TestNameBefore, "TestSubtest")
for _, suiteName := range suiteTester.SuiteNameAfter { for _, suiteName := range suiteTester.SuiteNameAfter {
assert.Equal(t, "SuiteTester", suiteName) assert.Equal(t, "SuiteTester", suiteName)
@ -197,15 +314,16 @@ func TestRunSuite(t *testing.T) {
assert.False(t, when.IsZero()) assert.False(t, when.IsZero())
} }
// There are three test methods (TestOne, TestTwo, and TestSkip), so // There are four test methods (TestOne, TestTwo, TestSkip, and TestSubtest), so
// the SetupTest and TearDownTest methods (which should be run once for // the SetupTest and TearDownTest methods (which should be run once for
// each test) should have been run three times. // each test) should have been run four times.
assert.Equal(t, suiteTester.SetupTestRunCount, 3) assert.Equal(t, suiteTester.SetupTestRunCount, 4)
assert.Equal(t, suiteTester.TearDownTestRunCount, 3) assert.Equal(t, suiteTester.TearDownTestRunCount, 4)
// Each test should have been run once. // Each test should have been run once.
assert.Equal(t, suiteTester.TestOneRunCount, 1) assert.Equal(t, suiteTester.TestOneRunCount, 1)
assert.Equal(t, suiteTester.TestTwoRunCount, 1) assert.Equal(t, suiteTester.TestTwoRunCount, 1)
assert.Equal(t, suiteTester.TestSubtestRunCount, 1)
// Methods that don't match the test method identifier shouldn't // Methods that don't match the test method identifier shouldn't
// have been run at all. // have been run at all.

3
vendor/golang.org/x/sync/AUTHORS generated vendored Normal file
View File

@ -0,0 +1,3 @@
# This source code refers to The Go Authors for copyright purposes.
# The master list of authors is in the main Go distribution,
# visible at http://tip.golang.org/AUTHORS.

26
vendor/golang.org/x/sync/CONTRIBUTING.md generated vendored Normal file
View File

@ -0,0 +1,26 @@
# Contributing to Go
Go is an open source project.
It is the work of hundreds of contributors. We appreciate your help!
## Filing issues
When [filing an issue](https://golang.org/issue/new), make sure to answer these five questions:
1. What version of Go are you using (`go version`)?
2. What operating system and processor architecture are you using?
3. What did you do?
4. What did you expect to see?
5. What did you see instead?
General questions should go to the [golang-nuts mailing list](https://groups.google.com/group/golang-nuts) instead of the issue tracker.
The gophers there will answer or ask you to file an issue if you've tripped over a bug.
## Contributing code
Please read the [Contribution Guidelines](https://golang.org/doc/contribute.html)
before sending patches.
Unless otherwise noted, the Go source files are distributed under
the BSD-style license found in the LICENSE file.

3
vendor/golang.org/x/sync/CONTRIBUTORS generated vendored Normal file
View File

@ -0,0 +1,3 @@
# This source code was written by the Go contributors.
# The master list of contributors is in the main Go distribution,
# visible at http://tip.golang.org/CONTRIBUTORS.

27
vendor/golang.org/x/sync/LICENSE generated vendored Normal file
View File

@ -0,0 +1,27 @@
Copyright (c) 2009 The Go Authors. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above
copyright notice, this list of conditions and the following disclaimer
in the documentation and/or other materials provided with the
distribution.
* Neither the name of Google Inc. nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

22
vendor/golang.org/x/sync/PATENTS generated vendored Normal file
View File

@ -0,0 +1,22 @@
Additional IP Rights Grant (Patents)
"This implementation" means the copyrightable works distributed by
Google as part of the Go project.
Google hereby grants to You a perpetual, worldwide, non-exclusive,
no-charge, royalty-free, irrevocable (except as stated in this section)
patent license to make, have made, use, offer to sell, sell, import,
transfer and otherwise run, modify and propagate the contents of this
implementation of Go, where such license applies only to those patent
claims, both currently owned or controlled by Google and acquired in
the future, licensable by Google that are necessarily infringed by this
implementation of Go. This grant does not include claims that would be
infringed only as a consequence of further modification of this
implementation. If you or your agent or exclusive licensee institute or
order or agree to the institution of patent litigation against any
entity (including a cross-claim or counterclaim in a lawsuit) alleging
that this implementation of Go or any code incorporated within this
implementation of Go constitutes direct or contributory patent
infringement, or inducement of patent infringement, then any patent
rights granted to you under this License for this implementation of Go
shall terminate as of the date such litigation is filed.

18
vendor/golang.org/x/sync/README.md generated vendored Normal file
View File

@ -0,0 +1,18 @@
# Go Sync
This repository provides Go concurrency primitives in addition to the
ones provided by the language and "sync" and "sync/atomic" packages.
## Download/Install
The easiest way to install is to run `go get -u golang.org/x/sync`. You can
also manually git clone the repository to `$GOPATH/src/golang.org/x/sync`.
## Report Issues / Send Patches
This repository uses Gerrit for code changes. To learn how to submit changes to
this repository, see https://golang.org/doc/contribute.html.
The main issue tracker for the sync repository is located at
https://github.com/golang/go/issues. Prefix your issue with "x/sync:" in the
subject line, so it is easy to find.

1
vendor/golang.org/x/sync/codereview.cfg generated vendored Normal file
View File

@ -0,0 +1 @@
issuerepo: golang/go

66
vendor/golang.org/x/sync/errgroup/errgroup.go generated vendored Normal file
View File

@ -0,0 +1,66 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package errgroup provides synchronization, error propagation, and Context
// cancelation for groups of goroutines working on subtasks of a common task.
package errgroup
import (
"context"
"sync"
)
// A Group is a collection of goroutines working on subtasks that are part of
// the same overall task.
//
// A zero Group is valid and does not cancel on error.
type Group struct {
cancel func()
wg sync.WaitGroup
errOnce sync.Once
err error
}
// WithContext returns a new Group and an associated Context derived from ctx.
//
// The derived Context is canceled the first time a function passed to Go
// returns a non-nil error or the first time Wait returns, whichever occurs
// first.
func WithContext(ctx context.Context) (*Group, context.Context) {
ctx, cancel := context.WithCancel(ctx)
return &Group{cancel: cancel}, ctx
}
// Wait blocks until all function calls from the Go method have returned, then
// returns the first non-nil error (if any) from them.
func (g *Group) Wait() error {
g.wg.Wait()
if g.cancel != nil {
g.cancel()
}
return g.err
}
// Go calls the given function in a new goroutine.
//
// The first call to return a non-nil error cancels the group; its error will be
// returned by Wait.
func (g *Group) Go(f func() error) {
g.wg.Add(1)
go func() {
defer g.wg.Done()
if err := f(); err != nil {
g.errOnce.Do(func() {
g.err = err
if g.cancel != nil {
g.cancel()
}
})
}
}()
}

View File

@ -0,0 +1,101 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package errgroup_test
import (
"context"
"crypto/md5"
"fmt"
"io/ioutil"
"log"
"os"
"path/filepath"
"golang.org/x/sync/errgroup"
)
// Pipeline demonstrates the use of a Group to implement a multi-stage
// pipeline: a version of the MD5All function with bounded parallelism from
// https://blog.golang.org/pipelines.
func ExampleGroup_pipeline() {
m, err := MD5All(context.Background(), ".")
if err != nil {
log.Fatal(err)
}
for k, sum := range m {
fmt.Printf("%s:\t%x\n", k, sum)
}
}
type result struct {
path string
sum [md5.Size]byte
}
// MD5All reads all the files in the file tree rooted at root and returns a map
// from file path to the MD5 sum of the file's contents. If the directory walk
// fails or any read operation fails, MD5All returns an error.
func MD5All(ctx context.Context, root string) (map[string][md5.Size]byte, error) {
// ctx is canceled when g.Wait() returns. When this version of MD5All returns
// - even in case of error! - we know that all of the goroutines have finished
// and the memory they were using can be garbage-collected.
g, ctx := errgroup.WithContext(ctx)
paths := make(chan string)
g.Go(func() error {
defer close(paths)
return filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
if !info.Mode().IsRegular() {
return nil
}
select {
case paths <- path:
case <-ctx.Done():
return ctx.Err()
}
return nil
})
})
// Start a fixed number of goroutines to read and digest files.
c := make(chan result)
const numDigesters = 20
for i := 0; i < numDigesters; i++ {
g.Go(func() error {
for path := range paths {
data, err := ioutil.ReadFile(path)
if err != nil {
return err
}
select {
case c <- result{path, md5.Sum(data)}:
case <-ctx.Done():
return ctx.Err()
}
}
return nil
})
}
go func() {
g.Wait()
close(c)
}()
m := make(map[string][md5.Size]byte)
for r := range c {
m[r.path] = r.sum
}
// Check whether any of the goroutines failed. Since g is accumulating the
// errors, we don't need to send them (or check for them) in the individual
// results sent on the channel.
if err := g.Wait(); err != nil {
return nil, err
}
return m, nil
}

176
vendor/golang.org/x/sync/errgroup/errgroup_test.go generated vendored Normal file
View File

@ -0,0 +1,176 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package errgroup_test
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"testing"
"golang.org/x/sync/errgroup"
)
var (
Web = fakeSearch("web")
Image = fakeSearch("image")
Video = fakeSearch("video")
)
type Result string
type Search func(ctx context.Context, query string) (Result, error)
func fakeSearch(kind string) Search {
return func(_ context.Context, query string) (Result, error) {
return Result(fmt.Sprintf("%s result for %q", kind, query)), nil
}
}
// JustErrors illustrates the use of a Group in place of a sync.WaitGroup to
// simplify goroutine counting and error handling. This example is derived from
// the sync.WaitGroup example at https://golang.org/pkg/sync/#example_WaitGroup.
func ExampleGroup_justErrors() {
var g errgroup.Group
var urls = []string{
"http://www.golang.org/",
"http://www.google.com/",
"http://www.somestupidname.com/",
}
for _, url := range urls {
// Launch a goroutine to fetch the URL.
url := url // https://golang.org/doc/faq#closures_and_goroutines
g.Go(func() error {
// Fetch the URL.
resp, err := http.Get(url)
if err == nil {
resp.Body.Close()
}
return err
})
}
// Wait for all HTTP fetches to complete.
if err := g.Wait(); err == nil {
fmt.Println("Successfully fetched all URLs.")
}
}
// Parallel illustrates the use of a Group for synchronizing a simple parallel
// task: the "Google Search 2.0" function from
// https://talks.golang.org/2012/concurrency.slide#46, augmented with a Context
// and error-handling.
func ExampleGroup_parallel() {
Google := func(ctx context.Context, query string) ([]Result, error) {
g, ctx := errgroup.WithContext(ctx)
searches := []Search{Web, Image, Video}
results := make([]Result, len(searches))
for i, search := range searches {
i, search := i, search // https://golang.org/doc/faq#closures_and_goroutines
g.Go(func() error {
result, err := search(ctx, query)
if err == nil {
results[i] = result
}
return err
})
}
if err := g.Wait(); err != nil {
return nil, err
}
return results, nil
}
results, err := Google(context.Background(), "golang")
if err != nil {
fmt.Fprintln(os.Stderr, err)
return
}
for _, result := range results {
fmt.Println(result)
}
// Output:
// web result for "golang"
// image result for "golang"
// video result for "golang"
}
func TestZeroGroup(t *testing.T) {
err1 := errors.New("errgroup_test: 1")
err2 := errors.New("errgroup_test: 2")
cases := []struct {
errs []error
}{
{errs: []error{}},
{errs: []error{nil}},
{errs: []error{err1}},
{errs: []error{err1, nil}},
{errs: []error{err1, nil, err2}},
}
for _, tc := range cases {
var g errgroup.Group
var firstErr error
for i, err := range tc.errs {
err := err
g.Go(func() error { return err })
if firstErr == nil && err != nil {
firstErr = err
}
if gErr := g.Wait(); gErr != firstErr {
t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+
"g.Wait() = %v; want %v",
g, tc.errs[:i+1], err, firstErr)
}
}
}
}
func TestWithContext(t *testing.T) {
errDoom := errors.New("group_test: doomed")
cases := []struct {
errs []error
want error
}{
{want: nil},
{errs: []error{nil}, want: nil},
{errs: []error{errDoom}, want: errDoom},
{errs: []error{errDoom, nil}, want: errDoom},
}
for _, tc := range cases {
g, ctx := errgroup.WithContext(context.Background())
for _, err := range tc.errs {
err := err
g.Go(func() error { return err })
}
if err := g.Wait(); err != tc.want {
t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+
"g.Wait() = %v; want %v",
g, tc.errs, err, tc.want)
}
canceled := false
select {
case <-ctx.Done():
canceled = true
default:
}
if !canceled {
t.Errorf("after %T.Go(func() error { return err }) for err in %v\n"+
"ctx.Done() was not closed",
g, tc.errs)
}
}
}

127
vendor/golang.org/x/sync/semaphore/semaphore.go generated vendored Normal file
View File

@ -0,0 +1,127 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package semaphore provides a weighted semaphore implementation.
package semaphore // import "golang.org/x/sync/semaphore"
import (
"container/list"
"context"
"sync"
)
type waiter struct {
n int64
ready chan<- struct{} // Closed when semaphore acquired.
}
// NewWeighted creates a new weighted semaphore with the given
// maximum combined weight for concurrent access.
func NewWeighted(n int64) *Weighted {
w := &Weighted{size: n}
return w
}
// Weighted provides a way to bound concurrent access to a resource.
// The callers can request access with a given weight.
type Weighted struct {
size int64
cur int64
mu sync.Mutex
waiters list.List
}
// Acquire acquires the semaphore with a weight of n, blocking until resources
// are available or ctx is done. On success, returns nil. On failure, returns
// ctx.Err() and leaves the semaphore unchanged.
//
// If ctx is already done, Acquire may still succeed without blocking.
func (s *Weighted) Acquire(ctx context.Context, n int64) error {
s.mu.Lock()
if s.size-s.cur >= n && s.waiters.Len() == 0 {
s.cur += n
s.mu.Unlock()
return nil
}
if n > s.size {
// Don't make other Acquire calls block on one that's doomed to fail.
s.mu.Unlock()
<-ctx.Done()
return ctx.Err()
}
ready := make(chan struct{})
w := waiter{n: n, ready: ready}
elem := s.waiters.PushBack(w)
s.mu.Unlock()
select {
case <-ctx.Done():
err := ctx.Err()
s.mu.Lock()
select {
case <-ready:
// Acquired the semaphore after we were canceled. Rather than trying to
// fix up the queue, just pretend we didn't notice the cancelation.
err = nil
default:
s.waiters.Remove(elem)
}
s.mu.Unlock()
return err
case <-ready:
return nil
}
}
// TryAcquire acquires the semaphore with a weight of n without blocking.
// On success, returns true. On failure, returns false and leaves the semaphore unchanged.
func (s *Weighted) TryAcquire(n int64) bool {
s.mu.Lock()
success := s.size-s.cur >= n && s.waiters.Len() == 0
if success {
s.cur += n
}
s.mu.Unlock()
return success
}
// Release releases the semaphore with a weight of n.
func (s *Weighted) Release(n int64) {
s.mu.Lock()
s.cur -= n
if s.cur < 0 {
s.mu.Unlock()
panic("semaphore: bad release")
}
for {
next := s.waiters.Front()
if next == nil {
break // No more waiters blocked.
}
w := next.Value.(waiter)
if s.size-s.cur < w.n {
// Not enough tokens for the next waiter. We could keep going (to try to
// find a waiter with a smaller request), but under load that could cause
// starvation for large requests; instead, we leave all remaining waiters
// blocked.
//
// Consider a semaphore used as a read-write lock, with N tokens, N
// readers, and one writer. Each reader can Acquire(1) to obtain a read
// lock. The writer can Acquire(N) to obtain a write lock, excluding all
// of the readers. If we allow the readers to jump ahead in the queue,
// the writer will starve — there is always one token available for every
// reader.
break
}
s.cur += w.n
s.waiters.Remove(next)
close(w.ready)
}
s.mu.Unlock()
}

View File

@ -0,0 +1,131 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.7
package semaphore_test
import (
"context"
"fmt"
"testing"
"golang.org/x/sync/semaphore"
)
// weighted is an interface matching a subset of *Weighted. It allows
// alternate implementations for testing and benchmarking.
type weighted interface {
Acquire(context.Context, int64) error
TryAcquire(int64) bool
Release(int64)
}
// semChan implements Weighted using a channel for
// comparing against the condition variable-based implementation.
type semChan chan struct{}
func newSemChan(n int64) semChan {
return semChan(make(chan struct{}, n))
}
func (s semChan) Acquire(_ context.Context, n int64) error {
for i := int64(0); i < n; i++ {
s <- struct{}{}
}
return nil
}
func (s semChan) TryAcquire(n int64) bool {
if int64(len(s))+n > int64(cap(s)) {
return false
}
for i := int64(0); i < n; i++ {
s <- struct{}{}
}
return true
}
func (s semChan) Release(n int64) {
for i := int64(0); i < n; i++ {
<-s
}
}
// acquireN calls Acquire(size) on sem N times and then calls Release(size) N times.
func acquireN(b *testing.B, sem weighted, size int64, N int) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < N; j++ {
sem.Acquire(context.Background(), size)
}
for j := 0; j < N; j++ {
sem.Release(size)
}
}
}
// tryAcquireN calls TryAcquire(size) on sem N times and then calls Release(size) N times.
func tryAcquireN(b *testing.B, sem weighted, size int64, N int) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
for j := 0; j < N; j++ {
if !sem.TryAcquire(size) {
b.Fatalf("TryAcquire(%v) = false, want true", size)
}
}
for j := 0; j < N; j++ {
sem.Release(size)
}
}
}
func BenchmarkNewSeq(b *testing.B) {
for _, cap := range []int64{1, 128} {
b.Run(fmt.Sprintf("Weighted-%d", cap), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = semaphore.NewWeighted(cap)
}
})
b.Run(fmt.Sprintf("semChan-%d", cap), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = newSemChan(cap)
}
})
}
}
func BenchmarkAcquireSeq(b *testing.B) {
for _, c := range []struct {
cap, size int64
N int
}{
{1, 1, 1},
{2, 1, 1},
{16, 1, 1},
{128, 1, 1},
{2, 2, 1},
{16, 2, 8},
{128, 2, 64},
{2, 1, 2},
{16, 8, 2},
{128, 64, 2},
} {
for _, w := range []struct {
name string
w weighted
}{
{"Weighted", semaphore.NewWeighted(c.cap)},
{"semChan", newSemChan(c.cap)},
} {
b.Run(fmt.Sprintf("%s-acquire-%d-%d-%d", w.name, c.cap, c.size, c.N), func(b *testing.B) {
acquireN(b, w.w, c.size, c.N)
})
b.Run(fmt.Sprintf("%s-tryAcquire-%d-%d-%d", w.name, c.cap, c.size, c.N), func(b *testing.B) {
tryAcquireN(b, w.w, c.size, c.N)
})
}
}
}

View File

@ -0,0 +1,84 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package semaphore_test
import (
"context"
"fmt"
"log"
"runtime"
"golang.org/x/sync/semaphore"
)
// Example_workerPool demonstrates how to use a semaphore to limit the number of
// goroutines working on parallel tasks.
//
// This use of a semaphore mimics a typical “worker pool” pattern, but without
// the need to explicitly shut down idle workers when the work is done.
func Example_workerPool() {
ctx := context.TODO()
var (
maxWorkers = runtime.GOMAXPROCS(0)
sem = semaphore.NewWeighted(int64(maxWorkers))
out = make([]int, 32)
)
// Compute the output using up to maxWorkers goroutines at a time.
for i := range out {
// When maxWorkers goroutines are in flight, Acquire blocks until one of the
// workers finishes.
if err := sem.Acquire(ctx, 1); err != nil {
log.Printf("Failed to acquire semaphore: %v", err)
break
}
go func(i int) {
defer sem.Release(1)
out[i] = collatzSteps(i + 1)
}(i)
}
// Acquire all of the tokens to wait for any remaining workers to finish.
//
// If you are already waiting for the workers by some other means (such as an
// errgroup.Group), you can omit this final Acquire call.
if err := sem.Acquire(ctx, int64(maxWorkers)); err != nil {
log.Printf("Failed to acquire semaphore: %v", err)
}
fmt.Println(out)
// Output:
// [0 1 7 2 5 8 16 3 19 6 14 9 9 17 17 4 12 20 20 7 7 15 15 10 23 10 111 18 18 18 106 5]
}
// collatzSteps computes the number of steps to reach 1 under the Collatz
// conjecture. (See https://en.wikipedia.org/wiki/Collatz_conjecture.)
func collatzSteps(n int) (steps int) {
if n <= 0 {
panic("nonpositive input")
}
for ; n > 1; steps++ {
if steps < 0 {
panic("too many steps")
}
if n%2 == 0 {
n /= 2
continue
}
const maxInt = int(^uint(0) >> 1)
if n > (maxInt-1)/3 {
panic("overflow")
}
n = 3*n + 1
}
return steps
}

171
vendor/golang.org/x/sync/semaphore/semaphore_test.go generated vendored Normal file
View File

@ -0,0 +1,171 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package semaphore_test
import (
"context"
"math/rand"
"runtime"
"sync"
"testing"
"time"
"golang.org/x/sync/errgroup"
"golang.org/x/sync/semaphore"
)
const maxSleep = 1 * time.Millisecond
func HammerWeighted(sem *semaphore.Weighted, n int64, loops int) {
for i := 0; i < loops; i++ {
sem.Acquire(context.Background(), n)
time.Sleep(time.Duration(rand.Int63n(int64(maxSleep/time.Nanosecond))) * time.Nanosecond)
sem.Release(n)
}
}
func TestWeighted(t *testing.T) {
t.Parallel()
n := runtime.GOMAXPROCS(0)
loops := 10000 / n
sem := semaphore.NewWeighted(int64(n))
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
i := i
go func() {
defer wg.Done()
HammerWeighted(sem, int64(i), loops)
}()
}
wg.Wait()
}
func TestWeightedPanic(t *testing.T) {
t.Parallel()
defer func() {
if recover() == nil {
t.Fatal("release of an unacquired weighted semaphore did not panic")
}
}()
w := semaphore.NewWeighted(1)
w.Release(1)
}
func TestWeightedTryAcquire(t *testing.T) {
t.Parallel()
ctx := context.Background()
sem := semaphore.NewWeighted(2)
tries := []bool{}
sem.Acquire(ctx, 1)
tries = append(tries, sem.TryAcquire(1))
tries = append(tries, sem.TryAcquire(1))
sem.Release(2)
tries = append(tries, sem.TryAcquire(1))
sem.Acquire(ctx, 1)
tries = append(tries, sem.TryAcquire(1))
want := []bool{true, false, true, false}
for i := range tries {
if tries[i] != want[i] {
t.Errorf("tries[%d]: got %t, want %t", i, tries[i], want[i])
}
}
}
func TestWeightedAcquire(t *testing.T) {
t.Parallel()
ctx := context.Background()
sem := semaphore.NewWeighted(2)
tryAcquire := func(n int64) bool {
ctx, cancel := context.WithTimeout(ctx, 10*time.Millisecond)
defer cancel()
return sem.Acquire(ctx, n) == nil
}
tries := []bool{}
sem.Acquire(ctx, 1)
tries = append(tries, tryAcquire(1))
tries = append(tries, tryAcquire(1))
sem.Release(2)
tries = append(tries, tryAcquire(1))
sem.Acquire(ctx, 1)
tries = append(tries, tryAcquire(1))
want := []bool{true, false, true, false}
for i := range tries {
if tries[i] != want[i] {
t.Errorf("tries[%d]: got %t, want %t", i, tries[i], want[i])
}
}
}
func TestWeightedDoesntBlockIfTooBig(t *testing.T) {
t.Parallel()
const n = 2
sem := semaphore.NewWeighted(n)
{
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go sem.Acquire(ctx, n+1)
}
g, ctx := errgroup.WithContext(context.Background())
for i := n * 3; i > 0; i-- {
g.Go(func() error {
err := sem.Acquire(ctx, 1)
if err == nil {
time.Sleep(1 * time.Millisecond)
sem.Release(1)
}
return err
})
}
if err := g.Wait(); err != nil {
t.Errorf("semaphore.NewWeighted(%v) failed to AcquireCtx(_, 1) with AcquireCtx(_, %v) pending", n, n+1)
}
}
// TestLargeAcquireDoesntStarve times out if a large call to Acquire starves.
// Merely returning from the test function indicates success.
func TestLargeAcquireDoesntStarve(t *testing.T) {
t.Parallel()
ctx := context.Background()
n := int64(runtime.GOMAXPROCS(0))
sem := semaphore.NewWeighted(n)
running := true
var wg sync.WaitGroup
wg.Add(int(n))
for i := n; i > 0; i-- {
sem.Acquire(ctx, 1)
go func() {
defer func() {
sem.Release(1)
wg.Done()
}()
for running {
time.Sleep(1 * time.Millisecond)
sem.Release(1)
sem.Acquire(ctx, 1)
}
}()
}
sem.Acquire(ctx, n)
running = false
sem.Release(n)
wg.Wait()
}

111
vendor/golang.org/x/sync/singleflight/singleflight.go generated vendored Normal file
View File

@ -0,0 +1,111 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package singleflight provides a duplicate function call suppression
// mechanism.
package singleflight // import "golang.org/x/sync/singleflight"
import "sync"
// call is an in-flight or completed singleflight.Do call
type call struct {
wg sync.WaitGroup
// These fields are written once before the WaitGroup is done
// and are only read after the WaitGroup is done.
val interface{}
err error
// These fields are read and written with the singleflight
// mutex held before the WaitGroup is done, and are read but
// not written after the WaitGroup is done.
dups int
chans []chan<- Result
}
// Group represents a class of work and forms a namespace in
// which units of work can be executed with duplicate suppression.
type Group struct {
mu sync.Mutex // protects m
m map[string]*call // lazily initialized
}
// Result holds the results of Do, so they can be passed
// on a channel.
type Result struct {
Val interface{}
Err error
Shared bool
}
// Do executes and returns the results of the given function, making
// sure that only one execution is in-flight for a given key at a
// time. If a duplicate comes in, the duplicate caller waits for the
// original to complete and receives the same results.
// The return value shared indicates whether v was given to multiple callers.
func (g *Group) Do(key string, fn func() (interface{}, error)) (v interface{}, err error, shared bool) {
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
g.mu.Unlock()
c.wg.Wait()
return c.val, c.err, true
}
c := new(call)
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
g.doCall(c, key, fn)
return c.val, c.err, c.dups > 0
}
// DoChan is like Do but returns a channel that will receive the
// results when they are ready.
func (g *Group) DoChan(key string, fn func() (interface{}, error)) <-chan Result {
ch := make(chan Result, 1)
g.mu.Lock()
if g.m == nil {
g.m = make(map[string]*call)
}
if c, ok := g.m[key]; ok {
c.dups++
c.chans = append(c.chans, ch)
g.mu.Unlock()
return ch
}
c := &call{chans: []chan<- Result{ch}}
c.wg.Add(1)
g.m[key] = c
g.mu.Unlock()
go g.doCall(c, key, fn)
return ch
}
// doCall handles the single call for a key.
func (g *Group) doCall(c *call, key string, fn func() (interface{}, error)) {
c.val, c.err = fn()
c.wg.Done()
g.mu.Lock()
delete(g.m, key)
for _, ch := range c.chans {
ch <- Result{c.val, c.err, c.dups > 0}
}
g.mu.Unlock()
}
// Forget tells the singleflight to forget about a key. Future calls
// to Do for this key will call the function rather than waiting for
// an earlier call to complete.
func (g *Group) Forget(key string) {
g.mu.Lock()
delete(g.m, key)
g.mu.Unlock()
}

View File

@ -0,0 +1,87 @@
// Copyright 2013 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package singleflight
import (
"errors"
"fmt"
"sync"
"sync/atomic"
"testing"
"time"
)
func TestDo(t *testing.T) {
var g Group
v, err, _ := g.Do("key", func() (interface{}, error) {
return "bar", nil
})
if got, want := fmt.Sprintf("%v (%T)", v, v), "bar (string)"; got != want {
t.Errorf("Do = %v; want %v", got, want)
}
if err != nil {
t.Errorf("Do error = %v", err)
}
}
func TestDoErr(t *testing.T) {
var g Group
someErr := errors.New("Some error")
v, err, _ := g.Do("key", func() (interface{}, error) {
return nil, someErr
})
if err != someErr {
t.Errorf("Do error = %v; want someErr %v", err, someErr)
}
if v != nil {
t.Errorf("unexpected non-nil value %#v", v)
}
}
func TestDoDupSuppress(t *testing.T) {
var g Group
var wg1, wg2 sync.WaitGroup
c := make(chan string, 1)
var calls int32
fn := func() (interface{}, error) {
if atomic.AddInt32(&calls, 1) == 1 {
// First invocation.
wg1.Done()
}
v := <-c
c <- v // pump; make available for any future calls
time.Sleep(10 * time.Millisecond) // let more goroutines enter Do
return v, nil
}
const n = 10
wg1.Add(1)
for i := 0; i < n; i++ {
wg1.Add(1)
wg2.Add(1)
go func() {
defer wg2.Done()
wg1.Done()
v, err, _ := g.Do("key", fn)
if err != nil {
t.Errorf("Do error: %v", err)
return
}
if s, _ := v.(string); s != "bar" {
t.Errorf("Do = %T %v; want %q", v, v, "bar")
}
}()
}
wg1.Wait()
// At least one goroutine is in fn now and all of them have at
// least reached the line before the Do.
c <- "bar"
wg2.Wait()
if got := atomic.LoadInt32(&calls); got <= 0 || got >= n {
t.Errorf("number of calls = %d; want over 0 and less than %d", got, n)
}
}

372
vendor/golang.org/x/sync/syncmap/map.go generated vendored Normal file
View File

@ -0,0 +1,372 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package syncmap provides a concurrent map implementation.
// It is a prototype for a proposed addition to the sync package
// in the standard library.
// (https://golang.org/issue/18177)
package syncmap
import (
"sync"
"sync/atomic"
"unsafe"
)
// Map is a concurrent map with amortized-constant-time loads, stores, and deletes.
// It is safe for multiple goroutines to call a Map's methods concurrently.
//
// The zero Map is valid and empty.
//
// A Map must not be copied after first use.
type Map struct {
mu sync.Mutex
// read contains the portion of the map's contents that are safe for
// concurrent access (with or without mu held).
//
// The read field itself is always safe to load, but must only be stored with
// mu held.
//
// Entries stored in read may be updated concurrently without mu, but updating
// a previously-expunged entry requires that the entry be copied to the dirty
// map and unexpunged with mu held.
read atomic.Value // readOnly
// dirty contains the portion of the map's contents that require mu to be
// held. To ensure that the dirty map can be promoted to the read map quickly,
// it also includes all of the non-expunged entries in the read map.
//
// Expunged entries are not stored in the dirty map. An expunged entry in the
// clean map must be unexpunged and added to the dirty map before a new value
// can be stored to it.
//
// If the dirty map is nil, the next write to the map will initialize it by
// making a shallow copy of the clean map, omitting stale entries.
dirty map[interface{}]*entry
// misses counts the number of loads since the read map was last updated that
// needed to lock mu to determine whether the key was present.
//
// Once enough misses have occurred to cover the cost of copying the dirty
// map, the dirty map will be promoted to the read map (in the unamended
// state) and the next store to the map will make a new dirty copy.
misses int
}
// readOnly is an immutable struct stored atomically in the Map.read field.
type readOnly struct {
m map[interface{}]*entry
amended bool // true if the dirty map contains some key not in m.
}
// expunged is an arbitrary pointer that marks entries which have been deleted
// from the dirty map.
var expunged = unsafe.Pointer(new(interface{}))
// An entry is a slot in the map corresponding to a particular key.
type entry struct {
// p points to the interface{} value stored for the entry.
//
// If p == nil, the entry has been deleted and m.dirty == nil.
//
// If p == expunged, the entry has been deleted, m.dirty != nil, and the entry
// is missing from m.dirty.
//
// Otherwise, the entry is valid and recorded in m.read.m[key] and, if m.dirty
// != nil, in m.dirty[key].
//
// An entry can be deleted by atomic replacement with nil: when m.dirty is
// next created, it will atomically replace nil with expunged and leave
// m.dirty[key] unset.
//
// An entry's associated value can be updated by atomic replacement, provided
// p != expunged. If p == expunged, an entry's associated value can be updated
// only after first setting m.dirty[key] = e so that lookups using the dirty
// map find the entry.
p unsafe.Pointer // *interface{}
}
func newEntry(i interface{}) *entry {
return &entry{p: unsafe.Pointer(&i)}
}
// Load returns the value stored in the map for a key, or nil if no
// value is present.
// The ok result indicates whether value was found in the map.
func (m *Map) Load(key interface{}) (value interface{}, ok bool) {
read, _ := m.read.Load().(readOnly)
e, ok := read.m[key]
if !ok && read.amended {
m.mu.Lock()
// Avoid reporting a spurious miss if m.dirty got promoted while we were
// blocked on m.mu. (If further loads of the same key will not miss, it's
// not worth copying the dirty map for this key.)
read, _ = m.read.Load().(readOnly)
e, ok = read.m[key]
if !ok && read.amended {
e, ok = m.dirty[key]
// Regardless of whether the entry was present, record a miss: this key
// will take the slow path until the dirty map is promoted to the read
// map.
m.missLocked()
}
m.mu.Unlock()
}
if !ok {
return nil, false
}
return e.load()
}
func (e *entry) load() (value interface{}, ok bool) {
p := atomic.LoadPointer(&e.p)
if p == nil || p == expunged {
return nil, false
}
return *(*interface{})(p), true
}
// Store sets the value for a key.
func (m *Map) Store(key, value interface{}) {
read, _ := m.read.Load().(readOnly)
if e, ok := read.m[key]; ok && e.tryStore(&value) {
return
}
m.mu.Lock()
read, _ = m.read.Load().(readOnly)
if e, ok := read.m[key]; ok {
if e.unexpungeLocked() {
// The entry was previously expunged, which implies that there is a
// non-nil dirty map and this entry is not in it.
m.dirty[key] = e
}
e.storeLocked(&value)
} else if e, ok := m.dirty[key]; ok {
e.storeLocked(&value)
} else {
if !read.amended {
// We're adding the first new key to the dirty map.
// Make sure it is allocated and mark the read-only map as incomplete.
m.dirtyLocked()
m.read.Store(readOnly{m: read.m, amended: true})
}
m.dirty[key] = newEntry(value)
}
m.mu.Unlock()
}
// tryStore stores a value if the entry has not been expunged.
//
// If the entry is expunged, tryStore returns false and leaves the entry
// unchanged.
func (e *entry) tryStore(i *interface{}) bool {
p := atomic.LoadPointer(&e.p)
if p == expunged {
return false
}
for {
if atomic.CompareAndSwapPointer(&e.p, p, unsafe.Pointer(i)) {
return true
}
p = atomic.LoadPointer(&e.p)
if p == expunged {
return false
}
}
}
// unexpungeLocked ensures that the entry is not marked as expunged.
//
// If the entry was previously expunged, it must be added to the dirty map
// before m.mu is unlocked.
func (e *entry) unexpungeLocked() (wasExpunged bool) {
return atomic.CompareAndSwapPointer(&e.p, expunged, nil)
}
// storeLocked unconditionally stores a value to the entry.
//
// The entry must be known not to be expunged.
func (e *entry) storeLocked(i *interface{}) {
atomic.StorePointer(&e.p, unsafe.Pointer(i))
}
// LoadOrStore returns the existing value for the key if present.
// Otherwise, it stores and returns the given value.
// The loaded result is true if the value was loaded, false if stored.
func (m *Map) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) {
// Avoid locking if it's a clean hit.
read, _ := m.read.Load().(readOnly)
if e, ok := read.m[key]; ok {
actual, loaded, ok := e.tryLoadOrStore(value)
if ok {
return actual, loaded
}
}
m.mu.Lock()
read, _ = m.read.Load().(readOnly)
if e, ok := read.m[key]; ok {
if e.unexpungeLocked() {
m.dirty[key] = e
}
actual, loaded, _ = e.tryLoadOrStore(value)
} else if e, ok := m.dirty[key]; ok {
actual, loaded, _ = e.tryLoadOrStore(value)
m.missLocked()
} else {
if !read.amended {
// We're adding the first new key to the dirty map.
// Make sure it is allocated and mark the read-only map as incomplete.
m.dirtyLocked()
m.read.Store(readOnly{m: read.m, amended: true})
}
m.dirty[key] = newEntry(value)
actual, loaded = value, false
}
m.mu.Unlock()
return actual, loaded
}
// tryLoadOrStore atomically loads or stores a value if the entry is not
// expunged.
//
// If the entry is expunged, tryLoadOrStore leaves the entry unchanged and
// returns with ok==false.
func (e *entry) tryLoadOrStore(i interface{}) (actual interface{}, loaded, ok bool) {
p := atomic.LoadPointer(&e.p)
if p == expunged {
return nil, false, false
}
if p != nil {
return *(*interface{})(p), true, true
}
// Copy the interface after the first load to make this method more amenable
// to escape analysis: if we hit the "load" path or the entry is expunged, we
// shouldn't bother heap-allocating.
ic := i
for {
if atomic.CompareAndSwapPointer(&e.p, nil, unsafe.Pointer(&ic)) {
return i, false, true
}
p = atomic.LoadPointer(&e.p)
if p == expunged {
return nil, false, false
}
if p != nil {
return *(*interface{})(p), true, true
}
}
}
// Delete deletes the value for a key.
func (m *Map) Delete(key interface{}) {
read, _ := m.read.Load().(readOnly)
e, ok := read.m[key]
if !ok && read.amended {
m.mu.Lock()
read, _ = m.read.Load().(readOnly)
e, ok = read.m[key]
if !ok && read.amended {
delete(m.dirty, key)
}
m.mu.Unlock()
}
if ok {
e.delete()
}
}
func (e *entry) delete() (hadValue bool) {
for {
p := atomic.LoadPointer(&e.p)
if p == nil || p == expunged {
return false
}
if atomic.CompareAndSwapPointer(&e.p, p, nil) {
return true
}
}
}
// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
//
// Range does not necessarily correspond to any consistent snapshot of the Map's
// contents: no key will be visited more than once, but if the value for any key
// is stored or deleted concurrently, Range may reflect any mapping for that key
// from any point during the Range call.
//
// Range may be O(N) with the number of elements in the map even if f returns
// false after a constant number of calls.
func (m *Map) Range(f func(key, value interface{}) bool) {
// We need to be able to iterate over all of the keys that were already
// present at the start of the call to Range.
// If read.amended is false, then read.m satisfies that property without
// requiring us to hold m.mu for a long time.
read, _ := m.read.Load().(readOnly)
if read.amended {
// m.dirty contains keys not in read.m. Fortunately, Range is already O(N)
// (assuming the caller does not break out early), so a call to Range
// amortizes an entire copy of the map: we can promote the dirty copy
// immediately!
m.mu.Lock()
read, _ = m.read.Load().(readOnly)
if read.amended {
read = readOnly{m: m.dirty}
m.read.Store(read)
m.dirty = nil
m.misses = 0
}
m.mu.Unlock()
}
for k, e := range read.m {
v, ok := e.load()
if !ok {
continue
}
if !f(k, v) {
break
}
}
}
func (m *Map) missLocked() {
m.misses++
if m.misses < len(m.dirty) {
return
}
m.read.Store(readOnly{m: m.dirty})
m.dirty = nil
m.misses = 0
}
func (m *Map) dirtyLocked() {
if m.dirty != nil {
return
}
read, _ := m.read.Load().(readOnly)
m.dirty = make(map[interface{}]*entry, len(read.m))
for k, e := range read.m {
if !e.tryExpungeLocked() {
m.dirty[k] = e
}
}
}
func (e *entry) tryExpungeLocked() (isExpunged bool) {
p := atomic.LoadPointer(&e.p)
for p == nil {
if atomic.CompareAndSwapPointer(&e.p, nil, expunged) {
return true
}
p = atomic.LoadPointer(&e.p)
}
return p == expunged
}

216
vendor/golang.org/x/sync/syncmap/map_bench_test.go generated vendored Normal file
View File

@ -0,0 +1,216 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package syncmap_test
import (
"fmt"
"reflect"
"sync/atomic"
"testing"
"golang.org/x/sync/syncmap"
)
type bench struct {
setup func(*testing.B, mapInterface)
perG func(b *testing.B, pb *testing.PB, i int, m mapInterface)
}
func benchMap(b *testing.B, bench bench) {
for _, m := range [...]mapInterface{&DeepCopyMap{}, &RWMutexMap{}, &syncmap.Map{}} {
b.Run(fmt.Sprintf("%T", m), func(b *testing.B) {
m = reflect.New(reflect.TypeOf(m).Elem()).Interface().(mapInterface)
if bench.setup != nil {
bench.setup(b, m)
}
b.ResetTimer()
var i int64
b.RunParallel(func(pb *testing.PB) {
id := int(atomic.AddInt64(&i, 1) - 1)
bench.perG(b, pb, id*b.N, m)
})
})
}
}
func BenchmarkLoadMostlyHits(b *testing.B) {
const hits, misses = 1023, 1
benchMap(b, bench{
setup: func(_ *testing.B, m mapInterface) {
for i := 0; i < hits; i++ {
m.LoadOrStore(i, i)
}
// Prime the map to get it into a steady state.
for i := 0; i < hits*2; i++ {
m.Load(i % hits)
}
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
m.Load(i % (hits + misses))
}
},
})
}
func BenchmarkLoadMostlyMisses(b *testing.B) {
const hits, misses = 1, 1023
benchMap(b, bench{
setup: func(_ *testing.B, m mapInterface) {
for i := 0; i < hits; i++ {
m.LoadOrStore(i, i)
}
// Prime the map to get it into a steady state.
for i := 0; i < hits*2; i++ {
m.Load(i % hits)
}
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
m.Load(i % (hits + misses))
}
},
})
}
func BenchmarkLoadOrStoreBalanced(b *testing.B) {
const hits, misses = 128, 128
benchMap(b, bench{
setup: func(b *testing.B, m mapInterface) {
if _, ok := m.(*DeepCopyMap); ok {
b.Skip("DeepCopyMap has quadratic running time.")
}
for i := 0; i < hits; i++ {
m.LoadOrStore(i, i)
}
// Prime the map to get it into a steady state.
for i := 0; i < hits*2; i++ {
m.Load(i % hits)
}
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
j := i % (hits + misses)
if j < hits {
if _, ok := m.LoadOrStore(j, i); !ok {
b.Fatalf("unexpected miss for %v", j)
}
} else {
if v, loaded := m.LoadOrStore(i, i); loaded {
b.Fatalf("failed to store %v: existing value %v", i, v)
}
}
}
},
})
}
func BenchmarkLoadOrStoreUnique(b *testing.B) {
benchMap(b, bench{
setup: func(b *testing.B, m mapInterface) {
if _, ok := m.(*DeepCopyMap); ok {
b.Skip("DeepCopyMap has quadratic running time.")
}
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
m.LoadOrStore(i, i)
}
},
})
}
func BenchmarkLoadOrStoreCollision(b *testing.B) {
benchMap(b, bench{
setup: func(_ *testing.B, m mapInterface) {
m.LoadOrStore(0, 0)
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
m.LoadOrStore(0, 0)
}
},
})
}
func BenchmarkRange(b *testing.B) {
const mapSize = 1 << 10
benchMap(b, bench{
setup: func(_ *testing.B, m mapInterface) {
for i := 0; i < mapSize; i++ {
m.Store(i, i)
}
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
m.Range(func(_, _ interface{}) bool { return true })
}
},
})
}
// BenchmarkAdversarialAlloc tests performance when we store a new value
// immediately whenever the map is promoted to clean and otherwise load a
// unique, missing key.
//
// This forces the Load calls to always acquire the map's mutex.
func BenchmarkAdversarialAlloc(b *testing.B) {
benchMap(b, bench{
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
var stores, loadsSinceStore int64
for ; pb.Next(); i++ {
m.Load(i)
if loadsSinceStore++; loadsSinceStore > stores {
m.LoadOrStore(i, stores)
loadsSinceStore = 0
stores++
}
}
},
})
}
// BenchmarkAdversarialDelete tests performance when we periodically delete
// one key and add a different one in a large map.
//
// This forces the Load calls to always acquire the map's mutex and periodically
// makes a full copy of the map despite changing only one entry.
func BenchmarkAdversarialDelete(b *testing.B) {
const mapSize = 1 << 10
benchMap(b, bench{
setup: func(_ *testing.B, m mapInterface) {
for i := 0; i < mapSize; i++ {
m.Store(i, i)
}
},
perG: func(b *testing.B, pb *testing.PB, i int, m mapInterface) {
for ; pb.Next(); i++ {
m.Load(i)
if i%mapSize == 0 {
m.Range(func(k, _ interface{}) bool {
m.Delete(k)
return false
})
m.Store(i, i)
}
}
},
})
}

151
vendor/golang.org/x/sync/syncmap/map_reference_test.go generated vendored Normal file
View File

@ -0,0 +1,151 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package syncmap_test
import (
"sync"
"sync/atomic"
)
// This file contains reference map implementations for unit-tests.
// mapInterface is the interface Map implements.
type mapInterface interface {
Load(interface{}) (interface{}, bool)
Store(key, value interface{})
LoadOrStore(key, value interface{}) (actual interface{}, loaded bool)
Delete(interface{})
Range(func(key, value interface{}) (shouldContinue bool))
}
// RWMutexMap is an implementation of mapInterface using a sync.RWMutex.
type RWMutexMap struct {
mu sync.RWMutex
dirty map[interface{}]interface{}
}
func (m *RWMutexMap) Load(key interface{}) (value interface{}, ok bool) {
m.mu.RLock()
value, ok = m.dirty[key]
m.mu.RUnlock()
return
}
func (m *RWMutexMap) Store(key, value interface{}) {
m.mu.Lock()
if m.dirty == nil {
m.dirty = make(map[interface{}]interface{})
}
m.dirty[key] = value
m.mu.Unlock()
}
func (m *RWMutexMap) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) {
m.mu.Lock()
actual, loaded = m.dirty[key]
if !loaded {
actual = value
if m.dirty == nil {
m.dirty = make(map[interface{}]interface{})
}
m.dirty[key] = value
}
m.mu.Unlock()
return actual, loaded
}
func (m *RWMutexMap) Delete(key interface{}) {
m.mu.Lock()
delete(m.dirty, key)
m.mu.Unlock()
}
func (m *RWMutexMap) Range(f func(key, value interface{}) (shouldContinue bool)) {
m.mu.RLock()
keys := make([]interface{}, 0, len(m.dirty))
for k := range m.dirty {
keys = append(keys, k)
}
m.mu.RUnlock()
for _, k := range keys {
v, ok := m.Load(k)
if !ok {
continue
}
if !f(k, v) {
break
}
}
}
// DeepCopyMap is an implementation of mapInterface using a Mutex and
// atomic.Value. It makes deep copies of the map on every write to avoid
// acquiring the Mutex in Load.
type DeepCopyMap struct {
mu sync.Mutex
clean atomic.Value
}
func (m *DeepCopyMap) Load(key interface{}) (value interface{}, ok bool) {
clean, _ := m.clean.Load().(map[interface{}]interface{})
value, ok = clean[key]
return value, ok
}
func (m *DeepCopyMap) Store(key, value interface{}) {
m.mu.Lock()
dirty := m.dirty()
dirty[key] = value
m.clean.Store(dirty)
m.mu.Unlock()
}
func (m *DeepCopyMap) LoadOrStore(key, value interface{}) (actual interface{}, loaded bool) {
clean, _ := m.clean.Load().(map[interface{}]interface{})
actual, loaded = clean[key]
if loaded {
return actual, loaded
}
m.mu.Lock()
// Reload clean in case it changed while we were waiting on m.mu.
clean, _ = m.clean.Load().(map[interface{}]interface{})
actual, loaded = clean[key]
if !loaded {
dirty := m.dirty()
dirty[key] = value
actual = value
m.clean.Store(dirty)
}
m.mu.Unlock()
return actual, loaded
}
func (m *DeepCopyMap) Delete(key interface{}) {
m.mu.Lock()
dirty := m.dirty()
delete(dirty, key)
m.clean.Store(dirty)
m.mu.Unlock()
}
func (m *DeepCopyMap) Range(f func(key, value interface{}) (shouldContinue bool)) {
clean, _ := m.clean.Load().(map[interface{}]interface{})
for k, v := range clean {
if !f(k, v) {
break
}
}
}
func (m *DeepCopyMap) dirty() map[interface{}]interface{} {
clean, _ := m.clean.Load().(map[interface{}]interface{})
dirty := make(map[interface{}]interface{}, len(clean)+1)
for k, v := range clean {
dirty[k] = v
}
return dirty
}

172
vendor/golang.org/x/sync/syncmap/map_test.go generated vendored Normal file
View File

@ -0,0 +1,172 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package syncmap_test
import (
"math/rand"
"reflect"
"runtime"
"sync"
"testing"
"testing/quick"
"golang.org/x/sync/syncmap"
)
type mapOp string
const (
opLoad = mapOp("Load")
opStore = mapOp("Store")
opLoadOrStore = mapOp("LoadOrStore")
opDelete = mapOp("Delete")
)
var mapOps = [...]mapOp{opLoad, opStore, opLoadOrStore, opDelete}
// mapCall is a quick.Generator for calls on mapInterface.
type mapCall struct {
op mapOp
k, v interface{}
}
func (c mapCall) apply(m mapInterface) (interface{}, bool) {
switch c.op {
case opLoad:
return m.Load(c.k)
case opStore:
m.Store(c.k, c.v)
return nil, false
case opLoadOrStore:
return m.LoadOrStore(c.k, c.v)
case opDelete:
m.Delete(c.k)
return nil, false
default:
panic("invalid mapOp")
}
}
type mapResult struct {
value interface{}
ok bool
}
func randValue(r *rand.Rand) interface{} {
b := make([]byte, r.Intn(4))
for i := range b {
b[i] = 'a' + byte(rand.Intn(26))
}
return string(b)
}
func (mapCall) Generate(r *rand.Rand, size int) reflect.Value {
c := mapCall{op: mapOps[rand.Intn(len(mapOps))], k: randValue(r)}
switch c.op {
case opStore, opLoadOrStore:
c.v = randValue(r)
}
return reflect.ValueOf(c)
}
func applyCalls(m mapInterface, calls []mapCall) (results []mapResult, final map[interface{}]interface{}) {
for _, c := range calls {
v, ok := c.apply(m)
results = append(results, mapResult{v, ok})
}
final = make(map[interface{}]interface{})
m.Range(func(k, v interface{}) bool {
final[k] = v
return true
})
return results, final
}
func applyMap(calls []mapCall) ([]mapResult, map[interface{}]interface{}) {
return applyCalls(new(syncmap.Map), calls)
}
func applyRWMutexMap(calls []mapCall) ([]mapResult, map[interface{}]interface{}) {
return applyCalls(new(RWMutexMap), calls)
}
func applyDeepCopyMap(calls []mapCall) ([]mapResult, map[interface{}]interface{}) {
return applyCalls(new(DeepCopyMap), calls)
}
func TestMapMatchesRWMutex(t *testing.T) {
if err := quick.CheckEqual(applyMap, applyRWMutexMap, nil); err != nil {
t.Error(err)
}
}
func TestMapMatchesDeepCopy(t *testing.T) {
if err := quick.CheckEqual(applyMap, applyDeepCopyMap, nil); err != nil {
t.Error(err)
}
}
func TestConcurrentRange(t *testing.T) {
const mapSize = 1 << 10
m := new(syncmap.Map)
for n := int64(1); n <= mapSize; n++ {
m.Store(n, int64(n))
}
done := make(chan struct{})
var wg sync.WaitGroup
defer func() {
close(done)
wg.Wait()
}()
for g := int64(runtime.GOMAXPROCS(0)); g > 0; g-- {
r := rand.New(rand.NewSource(g))
wg.Add(1)
go func(g int64) {
defer wg.Done()
for i := int64(0); ; i++ {
select {
case <-done:
return
default:
}
for n := int64(1); n < mapSize; n++ {
if r.Int63n(mapSize) == 0 {
m.Store(n, n*i*g)
} else {
m.Load(n)
}
}
}
}(g)
}
iters := 1 << 10
if testing.Short() {
iters = 16
}
for n := iters; n > 0; n-- {
seen := make(map[int64]bool, mapSize)
m.Range(func(ki, vi interface{}) bool {
k, v := ki.(int64), vi.(int64)
if v%k != 0 {
t.Fatalf("while Storing multiples of %v, Range saw value %v", k, v)
}
if seen[k] {
t.Fatalf("Range visited key %v twice", k)
}
seen[k] = true
return true
})
if len(seen) != mapSize {
t.Fatalf("Range visited %v elements of %v-element Map", len(seen), mapSize)
}
}
}

View File

@ -714,6 +714,14 @@ var unmarshalTests = []struct {
"---\nhello\n...\n}not yaml", "---\nhello\n...\n}not yaml",
"hello", "hello",
}, },
{
"a: 5\n",
&struct{ A jsonNumberT }{"5"},
},
{
"a: 5.5\n",
&struct{ A jsonNumberT }{"5.5"},
},
} }
type M map[interface{}]interface{} type M map[interface{}]interface{}

28
vendor/gopkg.in/yaml.v2/encode.go generated vendored
View File

@ -13,6 +13,19 @@ import (
"unicode/utf8" "unicode/utf8"
) )
// jsonNumber is the interface of the encoding/json.Number datatype.
// Repeating the interface here avoids a dependency on encoding/json, and also
// supports other libraries like jsoniter, which use a similar datatype with
// the same interface. Detecting this interface is useful when dealing with
// structures containing json.Number, which is a string under the hood. The
// encoder should prefer the use of Int64(), Float64() and string(), in that
// order, when encoding this type.
type jsonNumber interface {
Float64() (float64, error)
Int64() (int64, error)
String() string
}
type encoder struct { type encoder struct {
emitter yaml_emitter_t emitter yaml_emitter_t
event yaml_event_t event yaml_event_t
@ -89,6 +102,21 @@ func (e *encoder) marshal(tag string, in reflect.Value) {
} }
iface := in.Interface() iface := in.Interface()
switch m := iface.(type) { switch m := iface.(type) {
case jsonNumber:
integer, err := m.Int64()
if err == nil {
// In this case the json.Number is a valid int64
in = reflect.ValueOf(integer)
break
}
float, err := m.Float64()
if err == nil {
// In this case the json.Number is a valid float64
in = reflect.ValueOf(float)
break
}
// fallback case - no number could be obtained
in = reflect.ValueOf(m.String())
case time.Time, *time.Time: case time.Time, *time.Time:
// Although time.Time implements TextMarshaler, // Although time.Time implements TextMarshaler,
// we don't want to treat it as a string for YAML // we don't want to treat it as a string for YAML

View File

@ -15,6 +15,24 @@ import (
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
type jsonNumberT string
func (j jsonNumberT) Int64() (int64, error) {
val, err := strconv.Atoi(string(j))
if err != nil {
return 0, err
}
return int64(val), nil
}
func (j jsonNumberT) Float64() (float64, error) {
return strconv.ParseFloat(string(j), 64)
}
func (j jsonNumberT) String() string {
return string(j)
}
var marshalIntTest = 123 var marshalIntTest = 123
var marshalTests = []struct { var marshalTests = []struct {
@ -367,6 +385,18 @@ var marshalTests = []struct {
map[string]string{"a": "你好 #comment"}, map[string]string{"a": "你好 #comment"},
"a: '你好 #comment'\n", "a: '你好 #comment'\n",
}, },
{
map[string]interface{}{"a": jsonNumberT("5")},
"a: 5\n",
},
{
map[string]interface{}{"a": jsonNumberT("100.5")},
"a: 100.5\n",
},
{
map[string]interface{}{"a": jsonNumberT("bogus")},
"a: bogus\n",
},
} }
func (s *S) TestMarshal(c *C) { func (s *S) TestMarshal(c *C) {