add dockerdist

This commit is contained in:
jgsqware 2016-06-07 18:09:52 +02:00
parent a6396921db
commit 06acefc8e7
479 changed files with 169929 additions and 243455 deletions

View File

@ -0,0 +1,94 @@
package cmd
import (
"fmt"
"html/template"
"os"
"github.com/Sirupsen/logrus"
"github.com/coreos/clair/cmd/clairctl/docker"
"github.com/coreos/clair/cmd/clairctl/dockerdist"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/docker/reference"
"github.com/docker/engine-api/client"
"github.com/docker/engine-api/types"
"github.com/spf13/cobra"
)
const pullDockerTplt = `
Image: {{.Named.FullName}}
{{.V1Manifest.FSLayers | len}} layers found
{{range .V1Manifest.FSLayers}} {{.BlobSum}}
{{end}}
`
var pullDockerCmd = &cobra.Command{
Use: "pull-docker IMAGE",
Short: "Pull Docker image to Clair",
Long: `Upload a Docker image to Clair for further analysis`,
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 1 {
fmt.Printf("clairctl: \"pull\" requires a minimum of 1 argument\n")
os.Exit(1)
}
imageName := args[0]
if !docker.IsLocal {
n, manifest, err := dockerdist.DownloadManifest(imageName, true)
if err != nil {
fmt.Println(errInternalError)
logrus.Fatalf("parsing image %q: %v", imageName, err)
}
// Ensure that the manifest type is supported.
switch manifest.(type) {
case *schema1.SignedManifest:
break
default:
fmt.Println(errInternalError)
logrus.Fatalf("only v1 manifests are currently supported")
}
data := struct {
V1Manifest *schema1.SignedManifest
Named reference.Named
}{
V1Manifest: manifest.(*schema1.SignedManifest),
Named: n,
}
err = template.Must(template.New("pull").Parse(pullDockerTplt)).Execute(os.Stdout, data)
if err != nil {
fmt.Println(errInternalError)
logrus.Fatalf("rendering image: %v", err)
}
} else {
localManifest(imageName)
}
},
}
func localManifest(imageName string) {
defaultHeaders := map[string]string{"User-Agent": "engine-api-cli-1.0"}
cli, err := client.NewClient("unix:///var/run/docker.sock", "v1.22", nil, defaultHeaders)
if err != nil {
panic(err)
}
t := types.ImageListOptions{MatchName: imageName}
cli.ImageList(t)
// histories, err := cli.ImageHistory(context.Background(), imageName)
if err != nil {
panic(err)
}
// for _, history := range histories {
// fmt.Println(history)
// }
}
func init() {
RootCmd.AddCommand(pullDockerCmd)
pullDockerCmd.Flags().BoolVarP(&docker.IsLocal, "local", "l", false, "Use local images")
}

View File

@ -0,0 +1,83 @@
package cmd
import (
"fmt"
"os"
"github.com/Sirupsen/logrus"
"github.com/coreos/clair/cmd/clairctl/docker"
"github.com/coreos/clair/cmd/clairctl/dockerdist"
"github.com/docker/distribution/manifest/schema1"
"github.com/spf13/cobra"
)
var pushDockerCmd = &cobra.Command{
Use: "push-docker IMAGE",
Short: "Push Docker image to Clair",
Long: `Upload a Docker image to Clair for further analysis`,
Run: func(cmd *cobra.Command, args []string) {
if len(args) != 1 {
fmt.Printf("clairctl: \"push\" requires a minimum of 1 argument\n")
os.Exit(1)
}
startLocalServer()
imageName := args[0]
var image docker.Image
if !docker.IsLocal {
image, manifest, err := dockerdist.DownloadManifest(imageName, true)
if err != nil {
fmt.Println(errInternalError)
logrus.Fatalf("parsing local image %q: %v", imageName, err)
}
// Ensure that the manifest type is supported.
switch manifest.(type) {
case *schema1.SignedManifest:
break
default:
fmt.Println(errInternalError)
logrus.Fatalf("only v1 manifests are currently supported")
}
v1manifest := manifest.(*schema1.SignedManifest)
if err := dockerdist.Push(image, *v1manifest); err != nil {
if err != nil {
fmt.Println(errInternalError)
logrus.Fatalf("pushing image %q: %v", imageName, err)
}
}
} else {
var err error
image, err = docker.Parse(imageName)
if err != nil {
fmt.Println(errInternalError)
logrus.Fatalf("parsing local image %q: %v", imageName, err)
}
err = docker.Prepare(&image)
logrus.Debugf("prepared image layers: %d", len(image.FsLayers))
if err != nil {
fmt.Println(errInternalError)
logrus.Fatalf("preparing local image %q from history: %v", imageName, err)
}
logrus.Info("Pushing Image [OLD WAY] should be deprecated")
if err := docker.Push(image); err != nil {
if err != nil {
fmt.Println(errInternalError)
logrus.Fatalf("pushing image %q: %v", imageName, err)
}
}
}
fmt.Printf("%v has been pushed to Clair\n", imageName)
},
}
func init() {
RootCmd.AddCommand(pushDockerCmd)
pushDockerCmd.Flags().BoolVarP(&docker.IsLocal, "local", "l", false, "Use local images")
}

View File

@ -13,55 +13,31 @@
# limitations under the License.
# The values specified here are the default values that Clair uses if no configuration file is specified or if the keys are not defined.
---
database:
# PostgreSQL Connection string
# http://www.postgresql.org/docs/9.4/static/libpq-connect.html
source: postgresql://postgres:5432?sslmode=disable&user=postgres&password=root
clair:
database:
# PostgreSQL Connection string
# http://www.postgresql.org/docs/9.4/static/libpq-connect.html
source: postgresql://postgres:5432?sslmode=disable&user=postgres&password=root
# Number of elements kept in the cache
# Values unlikely to change (e.g. namespaces) are cached in order to save prevent needless roundtrips to the database.
cacheSize: 16384
# Number of elements kept in the cache
# Values unlikely to change (e.g. namespaces) are cached in order to save prevent needless roundtrips to the database.
cacheSize: 16384
api:
# API server port
port: 6060
api:
# API server port
port: 6060
# Health server port
# This is an unencrypted endpoint useful for load balancers to check to healthiness of the clair server.
healthport: 6061
# Health server port
# This is an unencrypted endpoint useful for load balancers to check to healthiness of the clair server.
healthport: 6061
# Deadline before an API request will respond with a 503
timeout: 900s
# Deadline before an API request will respond with a 503
timeout: 900s
# 32-bit URL-safe base64 key used to encrypt pagination tokens
# If one is not provided, it will be generated.
# Multiple clair instances in the same cluster need the same value.
paginationKey:
# Optional PKI configuration
# If you want to easily generate client certificates and CAs, try the following projects:
# https://github.com/coreos/etcd-ca
# https://github.com/cloudflare/cfssl
cafile:
keyfile:
certfile:
updater:
# Frequency the database will be updated with vulnerabilities from the default data sources
# The value 0 disables the updater entirely.
interval: 2h
notifier:
# Number of attempts before the notification is marked as failed to be sent
attempts: 3
# Duration before a failed notification is retried
renotifyInterval: 2h
http:
# Optional endpoint that will receive notifications via POST requests
endpoint:
# 32-bit URL-safe base64 key used to encrypt pagination tokens
# If one is not provided, it will be generated.
# Multiple clair instances in the same cluster need the same value.
paginationKey:
# Optional PKI configuration
# If you want to easily generate client certificates and CAs, try the following projects:
@ -71,3 +47,31 @@ notifier:
cafile:
keyfile:
certfile:
updater:
# Frequency the database will be updated with vulnerabilities from the default data sources
# The value 0 disables the updater entirely.
interval: 0
notifier:
# Number of attempts before the notification is marked as failed to be sent
attempts: 3
# Duration before a failed notification is retried
renotifyInterval: 2h
http:
# Optional endpoint that will receive notifications via POST requests
endpoint:
# Optional PKI configuration
# If you want to easily generate client certificates and CAs, try the following projects:
# https://github.com/cloudflare/cfssl
# https://github.com/coreos/etcd-ca
servername:
cafile:
keyfile:
certfile:
# Optional HTTP Proxy: must be a valid URL (including the scheme).
proxy:

View File

@ -28,7 +28,7 @@ services:
- REGISTRY_AUTH_TOKEN_ROOTCERTBUNDLE=/ssl/server.pem
clair:
image: quay.io/coreos/clair:v1.0.0-rc1
image: quay.io/coreos/clair:v1.2.2
volumes:
- /tmp:/tmp
- ./config:/config

View File

@ -0,0 +1,185 @@
// Copyright 2016 CoreOS, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package dockerdist provides helper methods for retrieving and parsing a
// information from a remote Docker repository.
package dockerdist
import (
"errors"
"log"
"net/url"
distlib "github.com/docker/distribution"
"github.com/docker/distribution/digest"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/distribution/registry/client"
"github.com/docker/docker/cliconfig"
"github.com/docker/docker/distribution"
"github.com/docker/docker/reference"
"github.com/docker/docker/registry"
"github.com/docker/engine-api/types"
"github.com/docker/go-connections/tlsconfig"
"github.com/spf13/viper"
"golang.org/x/net/context"
)
// getRepositoryClient returns a client for performing registry operations against the given named
// image.
func getRepositoryClient(image reference.Named, insecure bool, scopes ...string) (distlib.Repository, error) {
// Lookup the index information for the name.
indexInfo, err := registry.ParseSearchIndexInfo(image.String())
if err != nil {
return nil, err
}
// Retrieve the user's Docker configuration file (if any).
configFile, err := cliconfig.Load(cliconfig.ConfigDir())
if err != nil {
return nil, err
}
// Resolve the authentication information for the registry specified, via the config file.
authConfig := registry.ResolveAuthConfig(configFile.AuthConfigs, indexInfo)
repoInfo := &registry.RepositoryInfo{
image,
indexInfo,
false,
}
metaHeaders := map[string][]string{}
tlsConfig := tlsconfig.ServerDefault
tlsConfig.InsecureSkipVerify = viper.GetBool("auth.insecureSkipVerify")
url, err := url.Parse("https://" + image.Hostname())
if insecure {
url, err = url.Parse("http://" + image.Hostname())
}
if err != nil {
return nil, err
}
endpoint := registry.APIEndpoint{
URL: url,
Version: registry.APIVersion2,
Official: false,
TrimHostname: true,
TLSConfig: &tlsConfig,
}
ctx := context.Background()
repo, _, err := distribution.NewV2Repository(ctx, repoInfo, endpoint, metaHeaders, &authConfig, scopes...)
return repo, err
}
// getDigest returns the digest for the given image.
func getDigest(ctx context.Context, repo distlib.Repository, image reference.Named) (digest.Digest, error) {
if withDigest, ok := image.(reference.Canonical); ok {
return withDigest.Digest(), nil
}
// Get TagService.
tagSvc := repo.Tags(ctx)
// Get Tag name.
tag := "latest"
if withTag, ok := image.(reference.NamedTagged); ok {
tag = withTag.Tag()
}
// Get Tag's Descriptor.
descriptor, err := tagSvc.Get(ctx, tag)
if err != nil {
// Docker returns an UnexpectedHTTPResponseError if it cannot parse the JSON body of an
// unexpected error. Unfortunately, HEAD requests *by definition* don't have bodies, so
// Docker will return this error for non-200 HEAD requests. We therefore have to hack
// around it... *sigh*.
if _, ok := err.(*client.UnexpectedHTTPResponseError); ok {
return "", errors.New("Received error when trying to fetch the specified tag: it might not exist or you do not have access")
}
return "", err
}
return descriptor.Digest, nil
}
// GetAuthCredentials returns the auth credentials (if any found) for the given repository, as found
// in the user's docker config.
func GetAuthCredentials(image string) (types.AuthConfig, error) {
// Lookup the index information for the name.
indexInfo, err := registry.ParseSearchIndexInfo(image)
if err != nil {
return types.AuthConfig{}, err
}
// Retrieve the user's Docker configuration file (if any).
configFile, err := cliconfig.Load(cliconfig.ConfigDir())
if err != nil {
return types.AuthConfig{}, err
}
// Resolve the authentication information for the registry specified, via the config file.
return registry.ResolveAuthConfig(configFile.AuthConfigs, indexInfo), nil
}
// DownloadManifest the manifest for the given image, using the given credentials.
func DownloadManifest(image string, insecure bool) (reference.Named, distlib.Manifest, error) {
// Parse the image name as a docker image reference.
named, err := reference.ParseNamed(image)
if err != nil {
return nil, nil, err
}
// Create a reference to a repository client for the repo.
repo, err := getRepositoryClient(named, insecure, "pull")
if err != nil {
return nil, nil, err
}
// Get the digest.
ctx := context.Background()
digest, err := getDigest(ctx, repo, named)
if err != nil {
return nil, nil, err
}
// Retrieve the manifest for the tag.
log.Printf("Downloading manifest for image %v", image)
manSvc, err := repo.Manifests(ctx)
if err != nil {
return nil, nil, err
}
manifest, err := manSvc.Get(ctx, digest)
if err != nil {
return nil, nil, err
}
// Verify the manifest if it's signed.
switch manifest.(type) {
case *schema1.SignedManifest:
_, verr := schema1.Verify(manifest.(*schema1.SignedManifest))
if verr != nil {
return nil, nil, verr
}
default:
log.Printf("Could not verify manifest for image %v: not signed", image)
}
return named, manifest, nil
}

View File

@ -0,0 +1,101 @@
package dockerdist
import (
"fmt"
"strings"
"github.com/Sirupsen/logrus"
"github.com/coreos/clair/api/v1"
"github.com/coreos/clair/cmd/clairctl/clair"
"github.com/coreos/clair/cmd/clairctl/config"
"github.com/coreos/clair/cmd/clairctl/docker"
"github.com/coreos/clair/cmd/clairctl/xstrings"
"github.com/docker/distribution/manifest/schema1"
"github.com/docker/docker/reference"
)
var registryMapping map[string]string
//Push image to Clair for analysis
func Push(image reference.Named, manifest schema1.SignedManifest) error {
layerCount := len(manifest.FSLayers)
parentID := ""
if layerCount == 0 {
logrus.Warningln("there is no layer to push")
}
localIP, err := config.LocalServerIP()
if err != nil {
return err
}
hURL := fmt.Sprintf("http://%v/v2", localIP)
if docker.IsLocal {
hURL += "/local"
logrus.Infof("using %v as local url", hURL)
}
for index, layer := range manifest.FSLayers {
lUID := xstrings.Substr(layer.BlobSum.String(), 0, 12)
logrus.Infof("Pushing Layer %d/%d [%v]", index+1, layerCount, lUID)
insertRegistryMapping(layer.BlobSum.String(), image.Hostname())
payload := v1.LayerEnvelope{Layer: &v1.Layer{
Name: layer.BlobSum.String(),
Path: blobsURI(image.Hostname(), image.RemoteName(), layer.BlobSum.String()),
ParentName: parentID,
Format: "Docker",
}}
//FIXME Update to TLS
//FIXME use local with new push
// if IsLocal {
// payload.Layer.Name = layer.History
// payload.Layer.Path += "/layer.tar"
// }
payload.Layer.Path = strings.Replace(payload.Layer.Path, image.Hostname(), hURL, 1)
if err := clair.Push(payload); err != nil {
logrus.Infof("adding layer %d/%d [%v]: %v", index+1, layerCount, lUID, err)
if err != clair.ErrUnanalizedLayer {
return err
}
parentID = ""
} else {
parentID = payload.Layer.Name
}
}
// if IsLocal {
// if err := cleanLocal(); err != nil {
// return err
// }
// }
return nil
}
func blobsURI(registry string, name string, digest string) string {
return strings.Join([]string{registry, name, "blobs", digest}, "/")
}
func insertRegistryMapping(layerDigest string, registryURI string) {
if strings.Contains(registryURI, "docker") {
registryURI = "https://" + registryURI + "/v2"
} else {
registryURI = "http://" + registryURI + "/v2"
}
logrus.Debugf("Saving %s[%s]", layerDigest, registryURI)
registryMapping[layerDigest] = registryURI
}
//GetRegistryMapping return the registryURI corresponding to the layerID passed as parameter
func GetRegistryMapping(layerDigest string) (string, error) {
registryURI, present := registryMapping[layerDigest]
if !present {
return "", fmt.Errorf("%v mapping not found", layerDigest)
}
return registryURI, nil
}
func init() {
registryMapping = map[string]string{}
}

View File

@ -13,6 +13,7 @@ import (
"github.com/Sirupsen/logrus"
"github.com/coreos/clair/cmd/clairctl/docker"
"github.com/coreos/clair/cmd/clairctl/docker/httpclient"
"github.com/coreos/clair/cmd/clairctl/dockerdist"
"github.com/spf13/viper"
)
@ -70,8 +71,12 @@ func newSingleHostReverseProxy() *httputil.ReverseProxy {
if !validID.MatchString(u) {
logrus.Errorf("cannot parse url: %v", u)
}
host, _ := docker.GetRegistryMapping(validID.FindStringSubmatch(u)[1])
var host string
if docker.IsLocal {
host, _ = docker.GetRegistryMapping(validID.FindStringSubmatch(u)[1])
} else {
host, _ = dockerdist.GetRegistryMapping(validID.FindStringSubmatch(u)[1])
}
out, _ := url.Parse(host)
request.URL.Scheme = out.Scheme
request.URL.Host = out.Host

188
glide.lock generated
View File

@ -1,49 +1,144 @@
hash: be31158bfcdaebfaf87d53e967e7b7e26d8b93a7aea2263abc1a2055041c36e9
updated: 2016-06-07T17:25:46.628558614+02:00
hash: 0cd6a528aefc39458f3a6d7be2e882452b5d5355d84564e04c762b62a862cc8f
updated: 2016-06-07T18:45:29.977471975+02:00
imports:
- name: bitbucket.org/liamstask/goose
version: ""
version: 8488cc47d90c8a502b1c41a462a6d9cc8ee0a895
subpackages:
- lib/goose
- name: github.com/Azure/go-ansiterm
version: 388960b655244e76e24c75f48631564eaefade62
subpackages:
- winterm
- name: github.com/beorn7/perks
version: ""
version: b965b613227fddccbfffe13eae360ed3fa822f8d
subpackages:
- quantile
- name: github.com/BurntSushi/toml
version: f0aeabca5a127c4078abb8c8d64298b147264b55
- name: github.com/codegangsta/negroni
version: ""
version: c7477ad8e330bef55bf1ebe300cf8aa67c492d1b
- name: github.com/coreos/go-systemd
version: ""
version: 4f14f6deef2da87e4aa59e6c1c1f3e02ba44c5e1
subpackages:
- journal
- name: github.com/coreos/pkg
version: ""
version: 2c77715c4df99b5420ffcae14ead08f52104065d
subpackages:
- capnslog
- timeutil
- name: github.com/davecgh/go-spew
version: ""
version: 5215b55f46b2b919f50a1df0eaa5886afe4e3b3d
subpackages:
- spew
- name: github.com/docker/distribution
version: c8dff1bb5764aa6ea48d261b813ef3b0d2e25989
subpackages:
- manifest/schema1
- digest
- registry/client
- context
- manifest
- reference
- registry/api/errcode
- registry/api/v2
- registry/client/transport
- registry/storage/cache
- registry/storage/cache/memory
- manifest/manifestlist
- manifest/schema2
- registry/client/auth
- uuid
- name: github.com/docker/docker
version: 508a17baba3c39496008fc5b5e3fe890b8a1b31b
subpackages:
- api/server/httputils
- reference
- cliconfig
- distribution
- registry
- pkg/homedir
- api
- distribution/metadata
- distribution/xfer
- dockerversion
- image
- image/v1
- layer
- pkg/ioutils
- pkg/progress
- pkg/stringid
- opts
- pkg/httputils
- pkg/mflag
- pkg/tarsum
- pkg/system
- pkg/version
- pkg/archive
- pkg/parsers/kernel
- pkg/useragent
- daemon/graphdriver
- pkg/idtools
- pkg/longpath
- pkg/random
- pkg/jsonmessage
- pkg/fileutils
- pkg/pools
- pkg/promise
- pkg/chrootarchive
- pkg/plugins
- pkg/jsonlog
- pkg/term
- pkg/reexec
- pkg/plugins/transport
- pkg/term/windows
- name: github.com/docker/engine-api
version: f9f614e3ccca1554dcf7ffe91dab96ebc72dc24d
subpackages:
- client
- types
- client/transport
- client/transport/cancellable
- types/container
- types/filters
- types/network
- types/reference
- types/registry
- types/time
- types/blkiodev
- types/strslice
- types/versions
- name: github.com/docker/go-connections
version: f9f614e3ccca1554dcf7ffe91dab96ebc72dc24d
subpackages:
- tlsconfig
- sockets
- nat
- name: github.com/docker/go-units
version: 09dda9d4b0d748c57c14048906d3d094a58ec0c9
- name: github.com/docker/libtrust
version: 9cbd2a1374f46905c68a4eb3694a130610adc62a
- name: github.com/fatih/color
version: 1b35f289c47d5c73c398cea8e006b7bcb6234a96
- name: github.com/fernet/fernet-go
version: ""
version: 1b2437bc582b3cfbb341ee5a29f8ef5b42912ff2
- name: github.com/fsnotify/fsnotify
version: 30411dbcefb7a1da7e84f75530ad3abe4011b4f8
- name: github.com/go-sql-driver/mysql
version: ""
version: d512f204a577a4ab037a1816604c48c9c13210be
- name: github.com/golang/protobuf
version: ""
version: 5fc2294e655b78ed8a02082d37808d46c17d7e64
subpackages:
- proto
- name: github.com/gorilla/context
version: 14f550f51af52180c2eefed15e5fd18d63c0a64a
- name: github.com/gorilla/mux
version: e444e69cbd2e2e3e0749a2f3c717cec491552bbf
- name: github.com/guregu/null
version: ""
version: 79c5bd36b615db4c06132321189f579c8a5fca98
subpackages:
- zero
- name: github.com/hashicorp/golang-lru
version: ""
version: 5c7531c003d8bf158b0fe5063649a2f41a822146
subpackages:
- simplelru
- name: github.com/hashicorp/hcl
@ -58,91 +153,106 @@ imports:
- json/scanner
- json/token
- name: github.com/inconshreveable/mousetrap
version: 76626ae9c91c4f2a10f34cad8ce83ea42c93bb75
- name: github.com/julienschmidt/httprouter
version: ""
- name: github.com/julienschmidt/httprouter
version: 21439ef4d70ba4f3e2a5ed9249e7b03af4019b40
- name: github.com/kr/text
version: 7cafcd837844e784b526369c9bce262804aebc60
- name: github.com/kylelemons/go-gypsy
version: ""
version: 42fc2c7ee9b8bd0ff636cd2d7a8c0a49491044c5
subpackages:
- yaml
- name: github.com/lib/pq
version: ""
version: 11fc39a580a008f1f39bb3d11d984fb34ed778d9
subpackages:
- oid
- name: github.com/magiconair/properties
version: c265cfa48dda6474e208715ca93e987829f572f8
- name: github.com/mattn/go-sqlite3
version: ""
version: 5510da399572b4962c020184bb291120c0a412e2
- name: github.com/matttproud/golang_protobuf_extensions
version: ""
version: d0c3fe89de86839aecf2e0579c40ba3bb336a453
subpackages:
- pbutil
- name: github.com/Microsoft/go-winio
version: 4f1a71750d95a5a8a46c40a67ffbed8129c2f138
- name: github.com/mitchellh/mapstructure
version: d2dd0262208475919e1a362f675cfc0e7c10e905
- name: github.com/opencontainers/runc
version: 6b4ebb033a807004b695adf05f75f73b83728c38
subpackages:
- libcontainer/user
- name: github.com/pborman/uuid
version: ""
version: dee7705ef7b324f27ceb85a121c61f2c2e8ce988
- name: github.com/pmezard/go-difflib
version: ""
version: e8554b8641db39598be7f6342874b958f12ae1d4
subpackages:
- difflib
- name: github.com/prometheus/client_golang
version: ""
version: 67994f177195311c3ea3d4407ed0175e34a4256f
subpackages:
- prometheus
- name: github.com/prometheus/client_model
version: ""
version: fa8ad6fec33561be4280a8f0514318c79d7f6cb6
subpackages:
- go
- name: github.com/prometheus/common
version: ""
version: dba5e39d4516169e840def50e507ef5f21b985f9
subpackages:
- expfmt
- internal/bitbucket.org/ww/goautoneg
- model
- name: github.com/prometheus/procfs
version: ""
version: 406e5b7bfd8201a36e2bb5f7bdae0b03380c2ce8
- name: github.com/shiena/ansicolor
version: a422bbe96644373c5753384a59d678f7d261ff10
version: f9f614e3ccca1554dcf7ffe91dab96ebc72dc24d
- name: github.com/Sirupsen/logrus
version: f3cfb454f4c209e6668c95216c4744b8fddb2356
version: f9f614e3ccca1554dcf7ffe91dab96ebc72dc24d
- name: github.com/spf13/cast
version: 27b586b42e29bec072fe7379259cc719e1289da6
- name: github.com/spf13/cobra
version: 1238ba19d24b0b9ceee2094e1cb31947d45c3e86
version: f9f614e3ccca1554dcf7ffe91dab96ebc72dc24d
- name: github.com/spf13/jwalterweatherman
version: 33c24e77fb80341fe7130ee7c594256ff08ccc46
- name: github.com/spf13/pflag
version: cb88ea77998c3f024757528e3305022ab50b43be
- name: github.com/spf13/viper
version: c1ccc378a054ea8d4e38d8c67f6938d4760b53dd
version: f9f614e3ccca1554dcf7ffe91dab96ebc72dc24d
- name: github.com/stretchr/testify
version: ""
version: 5b9da39b66e8e994455c2525c4421c8cc00a7f93
subpackages:
- assert
- name: github.com/tylerb/graceful
version: ""
version: 84177357ab104029f9237abcb52339a7b80760ef
- name: github.com/vbatts/tar-split
version: 226f7c74905f1fcc08ac128b517a1d65a1948eb9
subpackages:
- tar/asm
- tar/storage
- archive/tar
- name: github.com/ziutek/mymysql
version: ""
version: 75ce5fbba34b1912a3641adbd58cf317d7315821
subpackages:
- godrv
- mysql
- native
- name: golang.org/x/crypto
version: 89d9e62992539701a49a19c52ebb33e84cbbe80f
subpackages:
- bcrypt
- ssh/terminal
- blowfish
- name: golang.org/x/net
version: ""
subpackages:
- bcrypt
- blowfish
- ssh/terminal
- name: golang.org/x/net
version: 1d7a0b2100da090d8b02afcfb42f97e2c77e71a4
subpackages:
- netutil
- context
- proxy
- name: golang.org/x/sys
version: 076b546753157f758b316e59bcb51e6807c04057
subpackages:
- unix
- windows
- name: gopkg.in/yaml.v2
version: ""
version: f7716cbe52baa25d2e9b0d0da546fcf909fc16b4
devImports: []

View File

@ -93,4 +93,11 @@ import:
version: f7716cbe52baa25d2e9b0d0da546fcf909fc16b4
- package: github.com/fatih/color
version: ^0.1.0
- package: github.com/kr/text
- package: github.com/kr/text
- package: github.com/Azure/go-ansiterm
subpackages:
- winterm
- package: github.com/docker/docker
version: ^1.11.2
subpackages:
- api/server/httputils

1
vendor/github.com/Azure/go-ansiterm generated vendored Submodule

@ -0,0 +1 @@
Subproject commit 388960b655244e76e24c75f48631564eaefade62

1
vendor/github.com/Microsoft/go-winio generated vendored Submodule

@ -0,0 +1 @@
Subproject commit 4f1a71750d95a5a8a46c40a67ffbed8129c2f138

View File

@ -1,20 +0,0 @@
Copyright (C) 2013 Blake Mizerany
Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View File

@ -1,15 +0,0 @@
language: go
sudo: false
go:
- 1.2.2
- 1.3.3
- 1.4
- 1.5.4
- 1.6.2
- master
matrix:
allow_failures:
- go: master

View File

@ -1,35 +0,0 @@
# Change Log
**ATTN**: This project uses [semantic versioning](http://semver.org/).
## [Unreleased]
### Added
- `Recovery.ErrorHandlerFunc` for custom error handling during recovery
### Fixed
- `Written()` correct returns `false` if no response header has been written
-
### Changed
- Set default status to `0` in the case that no handler writes status -- was
previously `200` (in 0.2.0, before that it was `0` so this reestablishes that
behavior)
## [0.2.0] - 2016-05-10
### Added
- Support for variadic handlers in `New()`
- Added `Negroni.Handlers()` to fetch all of the handlers for a given chain
- Allowed size in `Recovery` handler was bumped to 8k
- `Negroni.UseFunc` to push another handler onto the chain
### Changed
- Set the status before calling `beforeFuncs` so the information is available to them
- Set default status to `200` in the case that no handler writes status -- was previously `0`
- Panic if `nil` handler is given to `negroni.Use`
## 0.1.0 - 2013-07-22
### Added
- Initial implementation.
[Unreleased]: https://github.com/urfave/negroni/compare/v0.2.0...HEAD
[0.2.0]: https://github.com/urfave/negroni/compare/v0.1.0...v0.2.0

View File

@ -1,40 +1,24 @@
# Negroni
[![GoDoc](https://godoc.org/github.com/urfave/negroni?status.svg)](http://godoc.org/github.com/urfave/negroni)
[![Build Status](https://travis-ci.org/urfave/negroni.svg?branch=master)](https://travis-ci.org/urfave/negroni)
[![codebeat](https://codebeat.co/badges/47d320b1-209e-45e8-bd99-9094bc5111e2)](https://codebeat.co/projects/github-com-urfave-negroni)
# Negroni [![GoDoc](https://godoc.org/github.com/codegangsta/negroni?status.svg)](http://godoc.org/github.com/codegangsta/negroni) [![wercker status](https://app.wercker.com/status/13688a4a94b82d84a0b8d038c4965b61/s "wercker status")](https://app.wercker.com/project/bykey/13688a4a94b82d84a0b8d038c4965b61)
**Notice:** This is the library formally known as
`github.com/codegangsta/negroni` -- Github will automatically redirect requests
to this repository, but we recommend updating your references for clarity.
Negroni is an idiomatic approach to web middleware in Go. It is tiny, non-intrusive, and encourages use of `net/http` Handlers.
Negroni is an idiomatic approach to web middleware in Go. It is tiny,
non-intrusive, and encourages use of `net/http` Handlers.
If you like the idea of [Martini](http://github.com/go-martini/martini), but you think it contains too much magic, then Negroni is a great fit.
If you like the idea of [Martini](https://github.com/go-martini/martini), but
you think it contains too much magic, then Negroni is a great fit.
Language Translations:
* [German (de_DE)](translations/README_de_de.md)
* [Português Brasileiro (pt_BR)](translations/README_pt_br.md)
* [简体中文 (zh_cn)](translations/README_zh_cn.md)
* [繁體中文 (zh_tw)](translations/README_zh_tw.md)
* [日本語 (ja_JP)](translations/README_ja_JP.md)
## Getting Started
After installing Go and setting up your
[GOPATH](http://golang.org/doc/code.html#GOPATH), create your first `.go` file.
We'll call it `server.go`.
After installing Go and setting up your [GOPATH](http://golang.org/doc/code.html#GOPATH), create your first `.go` file. We'll call it `server.go`.
<!-- { "interrupt": true } -->
``` go
~~~ go
package main
import (
"fmt"
"github.com/codegangsta/negroni"
"net/http"
"github.com/urfave/negroni"
"fmt"
)
func main() {
@ -43,40 +27,34 @@ func main() {
fmt.Fprintf(w, "Welcome to the home page!")
})
n := negroni.Classic() // Includes some default middlewares
n := negroni.Classic()
n.UseHandler(mux)
http.ListenAndServe(":3000", n)
n.Run(":3000")
}
```
~~~
Then install the Negroni package (**NOTE**: &gt;= **go 1.1** is required):
```
go get github.com/urfave/negroni
```
Then install the Negroni package (**go 1.1** and greater is required):
~~~
go get github.com/codegangsta/negroni
~~~
Then run your server:
```
~~~
go run server.go
```
~~~
You will now have a Go `net/http` webserver running on `localhost:3000`.
You will now have a Go net/http webserver running on `localhost:3000`.
## Need Help?
If you have a question or feature request, [go ask the mailing list](https://groups.google.com/forum/#!forum/negroni-users). The GitHub issues for Negroni will be used exclusively for bug reports and pull requests.
## Is Negroni a Framework?
Negroni is **not** a framework. It is a middleware-focused library that is
designed to work directly with net/http.
Negroni is **not** a framework. It is a library that is designed to work directly with net/http.
## Routing?
Negroni is BYOR (Bring your own Router). The Go community already has a number of great http routers available, Negroni tries to play well with all of them by fully supporting `net/http`. For instance, integrating with [Gorilla Mux](http://github.com/gorilla/mux) looks like so:
Negroni is BYOR (Bring your own Router). The Go community already has a number
of great http routers available, and Negroni tries to play well with all of them
by fully supporting `net/http`. For instance, integrating with [Gorilla Mux]
looks like so:
``` go
~~~ go
router := mux.NewRouter()
router.HandleFunc("/", HomeHandler)
@ -86,54 +64,47 @@ n.Use(Middleware3)
// router goes last
n.UseHandler(router)
http.ListenAndServe(":3001", n)
```
n.Run(":3000")
~~~
## `negroni.Classic()`
`negroni.Classic()` provides some default middleware that is useful for most applications:
`negroni.Classic()` provides some default middleware that is useful for most
applications:
* [`negroni.Recovery`](#recovery) - Panic Recovery Middleware.
* [`negroni.Logger`](#logger) - Request/Response Logger Middleware.
* [`negroni.Static`](#static) - Static File serving under the "public"
directory.
* `negroni.Recovery` - Panic Recovery Middleware.
* `negroni.Logging` - Request/Response Logging Middleware.
* `negroni.Static` - Static File serving under the "public" directory.
This makes it really easy to get started with some useful features from Negroni.
## Handlers
Negroni provides a bidirectional middleware flow. This is done through the `negroni.Handler` interface:
Negroni provides a bidirectional middleware flow. This is done through the
`negroni.Handler` interface:
``` go
~~~ go
type Handler interface {
ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
}
```
~~~
If a middleware hasn't already written to the `ResponseWriter`, it should call
the next `http.HandlerFunc` in the chain to yield to the next middleware
handler. This can be used for great good:
If a middleware hasn't already written to the ResponseWriter, it should call the next `http.HandlerFunc` in the chain to yield to the next middleware handler. This can be used for great good:
``` go
~~~ go
func MyMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
// do some stuff before
next(rw, r)
// do some stuff after
}
```
~~~
And you can map it to the handler chain with the `Use` function:
``` go
~~~ go
n := negroni.New()
n.Use(negroni.HandlerFunc(MyMiddleware))
```
~~~
You can also map plain old `http.Handler`s:
``` go
~~~ go
n := negroni.New()
mux := http.NewServeMux()
@ -141,301 +112,70 @@ mux := http.NewServeMux()
n.UseHandler(mux)
http.ListenAndServe(":3000", n)
```
n.Run(":3000")
~~~
## `Run()`
Negroni has a convenience function called `Run`. `Run` takes an addr string identical to [http.ListenAndServe](http://golang.org/pkg/net/http#ListenAndServe).
Negroni has a convenience function called `Run`. `Run` takes an addr string
identical to [`http.ListenAndServe`](https://godoc.org/net/http#ListenAndServe).
<!-- { "interrupt": true } -->
``` go
package main
import (
"github.com/urfave/negroni"
)
func main() {
n := negroni.Classic()
n.Run(":8080")
}
```
In general, you will want to use `net/http` methods and pass `negroni` as a
`Handler`, as this is more flexible, e.g.:
<!-- { "interrupt": true } -->
``` go
package main
import (
"fmt"
"log"
"net/http"
"time"
"github.com/urfave/negroni"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, "Welcome to the home page!")
})
n := negroni.Classic() // Includes some default middlewares
n.UseHandler(mux)
s := &http.Server{
Addr: ":8080",
Handler: n,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
MaxHeaderBytes: 1 << 20,
}
log.Fatal(s.ListenAndServe())
}
```
~~~ go
n := negroni.Classic()
// ...
log.Fatal(http.ListenAndServe(":8080", n))
~~~
## Route Specific Middleware
If you have a route group of routes that need specific middleware to be executed, you can simply create a new Negroni instance and use it as your route handler.
If you have a route group of routes that need specific middleware to be
executed, you can simply create a new Negroni instance and use it as your route
handler.
``` go
~~~ go
router := mux.NewRouter()
adminRoutes := mux.NewRouter()
// add admin routes here
// Create a new negroni for the admin middleware
router.PathPrefix("/admin").Handler(negroni.New(
router.Handle("/admin", negroni.New(
Middleware1,
Middleware2,
negroni.Wrap(adminRoutes),
))
```
If you are using [Gorilla Mux], here is an example using a subrouter:
``` go
router := mux.NewRouter()
subRouter := mux.NewRouter().PathPrefix("/subpath").Subrouter().StrictSlash(true)
subRouter.HandleFunc("/", someSubpathHandler) // "/subpath/"
subRouter.HandleFunc("/:id", someSubpathHandler) // "/subpath/:id"
// "/subpath" is necessary to ensure the subRouter and main router linkup
router.PathPrefix("/subpath").Handler(negroni.New(
Middleware1,
Middleware2,
negroni.Wrap(subRouter),
))
```
## Bundled Middleware
### Static
This middleware will serve files on the filesystem. If the files do not exist,
it proxies the request to the next middleware. If you want the requests for
non-existent files to return a `404 File Not Found` to the user you should look
at using [http.FileServer](https://golang.org/pkg/net/http/#FileServer) as
a handler.
Example:
<!-- { "interrupt": true } -->
``` go
package main
import (
"fmt"
"net/http"
"github.com/urfave/negroni"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, "Welcome to the home page!")
})
// Example of using a http.FileServer if you want "server-like" rather than "middleware" behavior
// mux.Handle("/public", http.FileServer(http.Dir("/home/public")))
n := negroni.New()
n.Use(negroni.NewStatic(http.Dir("/tmp")))
n.UseHandler(mux)
http.ListenAndServe(":3002", n)
}
```
Will serve files from the `/tmp` directory first, but proxy calls to the next
handler if the request does not match a file on the filesystem.
### Recovery
This middleware catches `panic`s and responds with a `500` response code. If
any other middleware has written a response code or body, this middleware will
fail to properly send a 500 to the client, as the client has already received
the HTTP response code. Additionally, an `ErrorHandlerFunc` can be attached
to report 500's to an error reporting service such as Sentry or Airbrake.
Example:
<!-- { "interrupt": true } -->
``` go
package main
import (
"net/http"
"github.com/urfave/negroni"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
panic("oh no")
})
n := negroni.New()
n.Use(negroni.NewRecovery())
n.UseHandler(mux)
http.ListenAndServe(":3003", n)
}
```
Will return a `500 Internal Server Error` to each request. It will also log the
stack traces as well as print the stack trace to the requester if `PrintStack`
is set to `true` (the default).
Example with error handler:
``` go
package main
import (
"net/http"
"github.com/urfave/negroni"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
panic("oh no")
})
n := negroni.New()
recovery := negroni.NewRecovery()
recovery.ErrorHandlerFunc = reportToSentry
n.Use(recovery)
n.UseHandler(mux)
http.ListenAndServe(":3003", n)
}
func reportToSentry(error interface{}) {
// write code here to report error to Sentry
}
```
## Logger
This middleware logs each incoming request and response.
Example:
<!-- { "interrupt": true } -->
``` go
package main
import (
"fmt"
"net/http"
"github.com/urfave/negroni"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, "Welcome to the home page!")
})
n := negroni.New()
n.Use(negroni.NewLogger())
n.UseHandler(mux)
http.ListenAndServe(":3004", n)
}
```
Will print a log similar to:
```
[negroni] Started GET /
[negroni] Completed 200 OK in 145.446µs
```
on each request.
~~~
## Third Party Middleware
Here is a current list of Negroni compatible middlware. Feel free to put up a PR
linking your middleware if you have built one:
Here is a current list of Negroni compatible middlware. Feel free to put up a PR linking your middleware if you have built one:
| Middleware | Author | Description |
| -----------|--------|-------------|
| [binding](https://github.com/mholt/binding) | [Matt Holt](https://github.com/mholt) | Data binding from HTTP requests into structs |
| [cloudwatch](https://github.com/cvillecsteele/negroni-cloudwatch) | [Colin Steele](https://github.com/cvillecsteele) | AWS cloudwatch metrics middleware |
| [cors](https://github.com/rs/cors) | [Olivier Poitrey](https://github.com/rs) | [Cross Origin Resource Sharing](http://www.w3.org/TR/cors/) (CORS) support |
| [delay](https://github.com/jeffbmartinez/delay) | [Jeff Martinez](https://github.com/jeffbmartinez) | Add delays/latency to endpoints. Useful when testing effects of high latency |
| [gorelic](https://github.com/jingweno/negroni-gorelic) | [Jingwen Owen Ou](https://github.com/jingweno) | New Relic agent for Go runtime |
| [Graceful](https://github.com/tylerb/graceful) | [Tyler Bunnell](https://github.com/tylerb) | Graceful HTTP Shutdown |
| [gzip](https://github.com/phyber/negroni-gzip) | [phyber](https://github.com/phyber) | GZIP response compression |
| [JWT Middleware](https://github.com/auth0/go-jwt-middleware) | [Auth0](https://github.com/auth0) | Middleware checks for a JWT on the `Authorization` header on incoming requests and decodes it|
| [logrus](https://github.com/meatballhat/negroni-logrus) | [Dan Buch](https://github.com/meatballhat) | Logrus-based logger |
| [oauth2](https://github.com/goincremental/negroni-oauth2) | [David Bochenski](https://github.com/bochenski) | oAuth2 middleware |
| [onthefly](https://github.com/xyproto/onthefly) | [Alexander Rødseth](https://github.com/xyproto) | Generate TinySVG, HTML and CSS on the fly |
| [permissions2](https://github.com/xyproto/permissions2) | [Alexander Rødseth](https://github.com/xyproto) | Cookies, users and permissions |
| [prometheus](https://github.com/zbindenren/negroni-prometheus) | [Rene Zbinden](https://github.com/zbindenren) | Easily create metrics endpoint for the [prometheus](http://prometheus.io) instrumentation tool |
| [render](https://github.com/unrolled/render) | [Cory Jacobsen](https://github.com/unrolled) | Render JSON, XML and HTML templates |
| [RestGate](https://github.com/pjebs/restgate) | [Prasanga Siripala](https://github.com/pjebs) | Secure authentication for REST API endpoints |
| [Graceful](https://github.com/stretchr/graceful) | [Tyler Bunnell](https://github.com/tylerb) | Graceful HTTP Shutdown |
| [secure](https://github.com/unrolled/secure) | [Cory Jacobsen](https://github.com/unrolled) | Middleware that implements a few quick security wins |
| [JWT Middleware](https://github.com/auth0/go-jwt-middleware) | [Auth0](https://github.com/auth0) | Middleware checks for a JWT on the `Authorization` header on incoming requests and decodes it|
| [binding](https://github.com/mholt/binding) | [Matt Holt](https://github.com/mholt) | Data binding from HTTP requests into structs |
| [logrus](https://github.com/meatballhat/negroni-logrus) | [Dan Buch](https://github.com/meatballhat) | Logrus-based logger |
| [render](https://github.com/unrolled/render) | [Cory Jacobsen](https://github.com/unrolled) | Render JSON, XML and HTML templates |
| [gorelic](https://github.com/jingweno/negroni-gorelic) | [Jingwen Owen Ou](https://github.com/jingweno) | New Relic agent for Go runtime |
| [gzip](https://github.com/phyber/negroni-gzip) | [phyber](https://github.com/phyber) | GZIP response compression |
| [oauth2](https://github.com/goincremental/negroni-oauth2) | [David Bochenski](https://github.com/bochenski) | oAuth2 middleware |
| [sessions](https://github.com/goincremental/negroni-sessions) | [David Bochenski](https://github.com/bochenski) | Session Management |
| [stats](https://github.com/thoas/stats) | [Florent Messa](https://github.com/thoas) | Store information about your web application (response time, etc.) |
| [VanGoH](https://github.com/auroratechnologies/vangoh) | [Taylor Wrobel](https://github.com/twrobel3) | Configurable [AWS-Style](http://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) HMAC authentication middleware |
| [permissions2](https://github.com/xyproto/permissions2) | [Alexander Rødseth](https://github.com/xyproto) | Cookies, users and permissions |
| [onthefly](https://github.com/xyproto/onthefly) | [Alexander Rødseth](https://github.com/xyproto) | Generate TinySVG, HTML and CSS on the fly |
| [cors](https://github.com/rs/cors) | [Olivier Poitrey](https://github.com/rs) | [Cross Origin Resource Sharing](http://www.w3.org/TR/cors/) (CORS) support |
| [xrequestid](https://github.com/pilu/xrequestid) | [Andrea Franz](https://github.com/pilu) | Middleware that assigns a random X-Request-Id header to each request |
| [mgo session](https://github.com/joeljames/nigroni-mgo-session) | [Joel James](https://github.com/joeljames) | Middleware that handles creating and closing mgo sessions per request |
| [VanGoH](https://github.com/auroratechnologies/vangoh) | [Taylor Wrobel](https://github.com/twrobel3) | Configurable [AWS-Style](http://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) HMAC authentication middleware |
| [stats](https://github.com/thoas/stats) | [Florent Messa](https://github.com/thoas) | Store information about your web application (response time, etc.) |
## Examples
[Alexander Rødseth](https://github.com/xyproto) created
[mooseware](https://github.com/xyproto/mooseware), a skeleton for writing a
Negroni middleware handler.
[Alexander Rødseth](https://github.com/xyproto) created [mooseware](https://github.com/xyproto/mooseware), a skeleton for writing a Negroni middleware handler.
## Live code reload?
[gin](https://github.com/urfave/gin) and
[fresh](https://github.com/pilu/fresh) both live reload negroni apps.
[gin](https://github.com/codegangsta/gin) and [fresh](https://github.com/pilu/fresh) both live reload negroni apps.
## Essential Reading for Beginners of Go & Negroni
* [Using a Context to pass information from middleware to end handler](http://elithrar.github.io/article/map-string-interface/)
* [Understanding middleware](https://mattstauffer.co/blog/laravel-5.0-middleware-filter-style)
* [Understanding middleware](http://mattstauffer.co/blog/laravel-5.0-middleware-replacing-filters)
## About
Negroni is obsessively designed by none other than the [Code
Gangsta](https://codegangsta.io/)
[Gorilla Mux]: https://github.com/gorilla/mux
[`http.FileSystem`]: https://godoc.org/net/http#FileSystem
Negroni is obsessively designed by none other than the [Code Gangsta](http://codegangsta.io/)

View File

@ -2,12 +2,12 @@
//
// If you like the idea of Martini, but you think it contains too much magic, then Negroni is a great fit.
//
// For a full guide visit http://github.com/urfave/negroni
// For a full guide visit http://github.com/codegangsta/negroni
//
// package main
//
// import (
// "github.com/urfave/negroni"
// "github.com/codegangsta/negroni"
// "net/http"
// "fmt"
// )

View File

@ -75,10 +75,6 @@ func (n *Negroni) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
// Use adds a Handler onto the middleware stack. Handlers are invoked in the order they are added to a Negroni.
func (n *Negroni) Use(handler Handler) {
if handler == nil {
panic("handler cannot be nil")
}
n.handlers = append(n.handlers, handler)
n.middleware = build(n.handlers)
}

View File

@ -51,7 +51,7 @@ func TestNegroniServeHTTP(t *testing.T) {
expect(t, response.Code, http.StatusBadRequest)
}
// Ensures that a Negroni middleware chain
// Ensures that a Negroni middleware chain
// can correctly return all of its handlers.
func TestHandlers(t *testing.T) {
response := httptest.NewRecorder()
@ -63,7 +63,7 @@ func TestHandlers(t *testing.T) {
rw.WriteHeader(http.StatusOK)
}))
// Expects the length of handlers to be exactly 1
// Expects the length of handlers to be exactly 1
// after adding exactly one handler to the middleware chain
handlers = n.Handlers()
expect(t, 1, len(handlers))
@ -72,16 +72,4 @@ func TestHandlers(t *testing.T) {
// exactly the same as the one that was registered earlier
handlers[0].ServeHTTP(response, (*http.Request)(nil), nil)
expect(t, response.Code, http.StatusOK)
}
func TestNegroni_Use_Nil(t *testing.T) {
defer func() {
err := recover()
if err == nil {
t.Errorf("Expected negroni.Use(nil) to panic, but it did not")
}
}()
n := New()
n.Use(nil)
}
}

View File

@ -10,11 +10,10 @@ import (
// Recovery is a Negroni middleware that recovers from any panics and writes a 500 if there was one.
type Recovery struct {
Logger *log.Logger
PrintStack bool
ErrorHandlerFunc func(interface{})
StackAll bool
StackSize int
Logger *log.Logger
PrintStack bool
StackAll bool
StackSize int
}
// NewRecovery returns a new instance of Recovery
@ -30,10 +29,6 @@ func NewRecovery() *Recovery {
func (rec *Recovery) ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
defer func() {
if err := recover(); err != nil {
if rw.Header().Get("Content-Type") == "" {
rw.Header().Set("Content-Type", "text/plain; charset=utf-8")
}
rw.WriteHeader(http.StatusInternalServerError)
stack := make([]byte, rec.StackSize)
stack = stack[:runtime.Stack(stack, rec.StackAll)]
@ -44,10 +39,6 @@ func (rec *Recovery) ServeHTTP(rw http.ResponseWriter, r *http.Request, next htt
if rec.PrintStack {
fmt.Fprintf(rw, f, err, stack)
}
if rec.ErrorHandlerFunc != nil {
rec.ErrorHandlerFunc(err)
}
}
}()

View File

@ -22,24 +22,7 @@ func TestRecovery(t *testing.T) {
panic("here is a panic!")
}))
n.ServeHTTP(recorder, (*http.Request)(nil))
expect(t, recorder.Header().Get("Content-Type"), "text/plain; charset=utf-8")
expect(t, recorder.Code, http.StatusInternalServerError)
refute(t, recorder.Body.Len(), 0)
refute(t, len(buff.String()), 0)
}
func TestRecovery_noContentTypeOverwrite(t *testing.T) {
recorder := httptest.NewRecorder()
rec := NewRecovery()
rec.Logger = log.New(bytes.NewBuffer([]byte{}), "[negroni] ", 0)
n := New()
n.Use(rec)
n.UseHandler(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) {
res.Header().Set("Content-Type", "application/javascript; charset=utf-8")
panic("here is a panic!")
}))
n.ServeHTTP(recorder, (*http.Request)(nil))
expect(t, recorder.Header().Get("Content-Type"), "application/javascript; charset=utf-8")
}

View File

@ -13,8 +13,7 @@ import (
type ResponseWriter interface {
http.ResponseWriter
http.Flusher
// Status returns the status code of the response or 200 if the response has
// not been written (as this is the default response code in net/http)
// Status returns the status code of the response or 0 if the response has not been written.
Status() int
// Written returns whether or not the ResponseWriter has been written.
Written() bool
@ -29,9 +28,7 @@ type beforeFunc func(ResponseWriter)
// NewResponseWriter creates a ResponseWriter that wraps an http.ResponseWriter
func NewResponseWriter(rw http.ResponseWriter) ResponseWriter {
return &responseWriter{
ResponseWriter: rw,
}
return &responseWriter{rw, 0, 0, nil}
}
type responseWriter struct {

View File

@ -46,28 +46,6 @@ func (h *hijackableResponse) Hijack() (net.Conn, *bufio.ReadWriter, error) {
return nil, nil, nil
}
func TestResponseWriterBeforeWrite(t *testing.T) {
rec := httptest.NewRecorder()
rw := NewResponseWriter(rec)
expect(t, rw.Status(), 0)
expect(t, rw.Written(), false)
}
func TestResponseWriterBeforeFuncHasAccessToStatus(t *testing.T) {
var status int
rec := httptest.NewRecorder()
rw := NewResponseWriter(rec)
rw.Before(func(w ResponseWriter) {
status = w.Status()
})
rw.WriteHeader(http.StatusCreated)
expect(t, status, http.StatusCreated)
}
func TestResponseWriterWritingString(t *testing.T) {
rec := httptest.NewRecorder()
rw := NewResponseWriter(rec)

View File

@ -6,11 +6,7 @@ import (
"strings"
)
// Static is a middleware handler that serves static files in the given
// directory/filesystem. If the file does not exist on the filesystem, it
// passes along to the next middleware in the chain. If you desire "fileserver"
// type behavior where it returns a 404 for unfound files, you should consider
// using http.FileServer from the Go stdlib.
// Static is a middleware handler that serves static files in the given directory/filesystem.
type Static struct {
// Dir is the directory to serve static files from
Dir http.FileSystem

View File

@ -1,177 +0,0 @@
# Negroni [![GoDoc](https://godoc.org/github.com/urfave/negroni?status.svg)](http://godoc.org/github.com/urfave/negroni) [![wercker status](https://app.wercker.com/status/13688a4a94b82d84a0b8d038c4965b61/s "wercker status")](https://app.wercker.com/project/bykey/13688a4a94b82d84a0b8d038c4965b61)
Negroni ist ein Ansatz für eine idiomatische Middleware in Go. Sie ist klein, nicht-intrusiv und unterstützt die Nutzung von `net/http` Handlern.
Wenn Dir die Idee hinter [Martini](http://github.com/go-martini/martini) gefällt, aber Du denkst, es stecke zu viel Magie darin, dann ist Negroni eine passende Alternative.
## Wo fange ich an?
Nachdem Du Go installiert und den [GOPATH](http://golang.org/doc/code.html#GOPATH) eingerichtet hast, erstelle eine `.go`-Datei. Nennen wir sie `server.go`.
~~~ go
package main
import (
"github.com/urfave/negroni"
"net/http"
"fmt"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, "Willkommen auf der Homepage!")
})
n := negroni.Classic()
n.UseHandler(mux)
n.Run(":3000")
}
~~~
Installiere nun das Negroni Package (**go 1.1** und höher werden vorausgesetzt):
~~~
go get github.com/urfave/negroni
~~~
Dann starte Deinen Server:
~~~
go run server.go
~~~
Nun läuft ein `net/http`-Webserver von Go unter `localhost:3000`.
## Hilfe benötigt?
Wenn Du eine Frage hast oder Dir ein bestimmte Funktion wünscht, nutze die [Mailing Liste](https://groups.google.com/forum/#!forum/negroni-users). Issues auf Github werden ausschließlich für Bug Reports und Pull Requests genutzt.
## Ist Negroni ein Framework?
Negroni ist **kein** Framework. Es ist eine Bibliothek, geschaffen, um kompatibel mit `net/http` zu sein.
## Routing?
Negroni ist BYOR (Bring your own Router - Nutze Deinen eigenen Router). Die Go-Community verfügt bereits über eine Vielzahl von großartigen Routern. Negroni versucht möglichst alle zu unterstützen, indem es `net/http` vollständig unterstützt. Beispielsweise sieht eine Implementation mit [Gorilla Mux](http://github.com/gorilla/mux) folgendermaßen aus:
~~~ go
router := mux.NewRouter()
router.HandleFunc("/", HomeHandler)
n := negroni.New(Middleware1, Middleware2)
// Oder nutze eine Middleware mit der Use()-Funktion
n.Use(Middleware3)
// Der Router kommt als letztes
n.UseHandler(router)
n.Run(":3000")
~~~
## `negroni.Classic()`
`negroni.Classic()` stellt einige Standard-Middlewares bereit, die für die meisten Anwendungen von Nutzen ist:
* `negroni.Recovery` - Middleware für Panic Recovery .
* `negroni.Logging` - Anfrage/Rückmeldungs-Logging-Middleware.
* `negroni.Static` - Ausliefern von statischen Dateien unter dem "public" Verzeichnis.
Dies macht es wirklich einfach, mit den nützlichen Funktionen von Negroni zu starten.
## Handlers
Negroni stellt einen bidirektionalen Middleware-Flow bereit. Dies wird durch das `negroni.Handler`-Interface erreicht:
~~~ go
type Handler interface {
ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
}
~~~
Wenn eine Middleware nicht bereits den ResponseWriter genutzt hat, sollte sie die nächste `http.HandlerFunc` in der Verkettung von Middlewares aufrufen und diese ausführen. Das kann von großem Nutzen sein:
~~~ go
func MyMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
// Mache etwas vor dem Aufruf
next(rw, r)
// Mache etwas nach dem Aufruf
}
~~~
Und Du kannst eine Middleware durch die `Use`-Funktion der Verkettung von Middlewares zuordnen.
~~~ go
n := negroni.New()
n.Use(negroni.HandlerFunc(MyMiddleware))
~~~
Stattdessen kannst Du auch herkömmliche `http.Handler` zuordnen:
~~~ go
n := negroni.New()
mux := http.NewServeMux()
// Ordne Deine Routen zu
n.UseHandler(mux)
n.Run(":3000")
~~~
## `Run()`
Negroni hat eine nützliche Funktion namens `Run`. `Run` übernimmt eine Zeichenkette `addr` ähnlich wie [http.ListenAndServe](http://golang.org/pkg/net/http#ListenAndServe).
~~~ go
n := negroni.Classic()
// ...
log.Fatal(http.ListenAndServe(":8080", n))
~~~
## Routenspezifische Middleware
Wenn Du eine Gruppe von Routen hast, welche alle die gleiche Middleware ausführen müssen, kannst Du einfach eine neue Negroni-Instanz erstellen und sie als Route-Handler nutzen:
~~~ go
router := mux.NewRouter()
adminRoutes := mux.NewRouter()
// Füge die Admin-Routen hier hinzu
// Erstelle eine neue Negroni-Instanz für die Admin-Middleware
router.Handle("/admin", negroni.New(
Middleware1,
Middleware2,
negroni.Wrap(adminRoutes),
))
~~~
## Middlewares von Dritten
Hier ist eine aktuelle Liste von Middlewares, die kompatible mit Negroni sind. Tue Dir keinen Zwang an, Dich einzutragen, wenn Du selbst eine Middleware programmiert hast:
| Middleware | Autor | Beschreibung |
| -----------|--------|-------------|
| [RestGate](https://github.com/pjebs/restgate) | [Prasanga Siripala](https://github.com/pjebs) | Sichere Authentifikation für Endpunkte einer REST API |
| [Graceful](https://github.com/stretchr/graceful) | [Tyler Bunnell](https://github.com/tylerb) | Graceful HTTP Shutdown |
| [secure](https://github.com/unrolled/secure) | [Cory Jacobsen](https://github.com/unrolled) | Eine Middleware mit ein paar nützlichen Sicherheitseinstellungen |
| [JWT Middleware](https://github.com/auth0/go-jwt-middleware) | [Auth0](https://github.com/auth0) | Eine Middleware die nach JWTs im `Authorization`-Feld des Header sucht und sie dekodiert.|
| [binding](https://github.com/mholt/binding) | [Matt Holt](https://github.com/mholt) | Data Binding von HTTP-Anfragen in Structs |
| [logrus](https://github.com/meatballhat/negroni-logrus) | [Dan Buch](https://github.com/meatballhat) | Logrus-basierender Logger |
| [render](https://github.com/unrolled/render) | [Cory Jacobsen](https://github.com/unrolled) | Rendere JSON, XML und HTML Vorlagen |
| [gorelic](https://github.com/jingweno/negroni-gorelic) | [Jingwen Owen Ou](https://github.com/jingweno) | New Relic Agent für die Go-Echtzeitumgebung |
| [gzip](https://github.com/phyber/negroni-gzip) | [phyber](https://github.com/phyber) | Kompression von HTTP-Rückmeldungen via GZIP |
| [oauth2](https://github.com/goincremental/negroni-oauth2) | [David Bochenski](https://github.com/bochenski) | oAuth2 Middleware |
| [sessions](https://github.com/goincremental/negroni-sessions) | [David Bochenski](https://github.com/bochenski) | Session Management |
| [permissions2](https://github.com/xyproto/permissions2) | [Alexander Rødseth](https://github.com/xyproto) | Cookies, Benutzer und Berechtigungen |
| [onthefly](https://github.com/xyproto/onthefly) | [Alexander Rødseth](https://github.com/xyproto) | Generiere TinySVG, HTML und CSS spontan |
| [cors](https://github.com/rs/cors) | [Olivier Poitrey](https://github.com/rs) | [Cross Origin Resource Sharing](http://www.w3.org/TR/cors/) (CORS) Unterstützung |
| [xrequestid](https://github.com/pilu/xrequestid) | [Andrea Franz](https://github.com/pilu) | Eine Middleware die zufällige X-Request-Id-Header jedem Request anfügt |
| [VanGoH](https://github.com/auroratechnologies/vangoh) | [Taylor Wrobel](https://github.com/twrobel3) | Configurable [AWS-Style](http://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) HMAC-basierte Middleware zur Authentifikation |
| [stats](https://github.com/thoas/stats) | [Florent Messa](https://github.com/thoas) | Speichere wichtige Informationen über Deine Webanwendung (Reaktionszeit, etc.) |
## Beispiele
[Alexander Rødseth](https://github.com/xyproto) programmierte [mooseware](https://github.com/xyproto/mooseware), ein Grundgerüst zum Erstellen von Negroni Middleware-Handerln.
## Aktualisieren in Echtzeit?
[gin](https://github.com/urfave/gin) und [fresh](https://github.com/pilu/fresh) aktualisieren Deine Negroni-Anwendung automatisch.
## Unverzichbare Informationen für Go- & Negronineulinge
* [Nutze einen Kontext zum Übertragen von Middlewareinformationen an Handler (Englisch)](http://elithrar.github.io/article/map-string-interface/)
* [Middlewares verstehen (Englisch)](http://mattstauffer.co/blog/laravel-5.0-middleware-replacing-filters)
## Über das Projekt
Negroni wurde obsseziv von Niemand gerigeren als dem [Code Gangsta](http://codegangsta.io/) entwickelt.

View File

@ -1,374 +0,0 @@
# Negroni [![GoDoc](https://godoc.org/github.com/urfave/negroni?status.svg)](http://godoc.org/github.com/urfave/negroni) [![wercker status](https://app.wercker.com/status/13688a4a94b82d84a0b8d038c4965b61/s "wercker status")](https://app.wercker.com/project/bykey/13688a4a94b82d84a0b8d038c4965b61) [![codebeat](https://codebeat.co/badges/47d320b1-209e-45e8-bd99-9094bc5111e2)](https://codebeat.co/projects/github-com-urfave-negroni)
NegroniはGoによるWeb ミドルウェアへの慣用的なアプローチです。
軽量で押し付けがましい作法は無く、また`net/http`ハンドラの使用を推奨しています。
[Martini](https://github.com/go-martini/martini) の思想は気に入っているが、多くの魔法を含みすぎていると感じている方に、このNegroni はよく馴染むでしょう。
## はじめに
Goをインストールし、[GOPATH](http://golang.org/doc/code.html#GOPATH)の設定を行った後、.goファイルを作りましょう。これをserver.goとします。
<!-- { "interrupt": true } -->
``` go
package main
import (
"fmt"
"net/http"
"github.com/urfave/negroni"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, "Welcome to the home page!")
})
n := negroni.Classic() // Includes some default middlewares
n.UseHandler(mux)
http.ListenAndServe(":3000", n)
}
```
Negroni パッケージをインストールします (**NOTE**: &gt;= **go 1.1** 以上のバージョンが必要です):
```
go get github.com/urfave/negroni
```
インストールが完了したら、サーバーを起動しましょう。
```
go run server.go
```
すると、Go標準パッケージの `net/http` によるWebサーバーが`localhost:3000` で起動します。
## Negroni はWeb Application Framework ですか?
Negroni はrevel やmartini のような**フレームワークではありません**。 Negroniは `net/http`と直接結びついて動作する、ミドルウェアにフォーカスされたライブラリです。
## ルーティングの機能はありますか?
Negroni にルーティングの機能はありません。Goコミュニティには既に幾つかの優れたルーティングのライブラリが存在しており、Negroni は`net/http`と互換性のあるライブラリと協調して動作するように設計されています。例えば、[Gorilla Mux]と連携すると以下のようになります。
``` go
router := mux.NewRouter()
router.HandleFunc("/", HomeHandler)
n := negroni.New(Middleware1, Middleware2)
// Or use a middleware with the Use() function
n.Use(Middleware3)
// router goes last
n.UseHandler(router)
http.ListenAndServe(":3001", n)
```
## `negroni.Classic()`
`negroni.Classic()` は、多くのアプリケーションで役に立つミドルウェアを提供します
* [`negroni.Recovery`](#recovery) - Panic Recovery Middleware.
* [`negroni.Logger`](#logger) - Request/Response Logger Middleware.
* [`negroni.Static`](#static) - "public"ディレクトリの静的ファイルの処理
これらはNegroni の便利な機能を利用し始めるのをとても簡単にしてくれます。
## ハンドラ
Negroniは双方向のミドルウェアのフローを提供します。これは、`negroni.Handler` インターフェースを通じて行われます。
``` go
type Handler interface {
ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
}
```
ミドルウェアが既にResponseWriter に書き込み処理を行っていない場合、次のミドルウェア・ハンドラを動かすために、チェーン内の次のhttp.HandlerFuncを呼び出す必要があります。
``` go
func MyMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
// 前処理
next(rw, r)
// 後処理
}
```
この時、`MyMiddleware`を`Use` 関数によってハンドラチェーンに割り当てることができます。
``` go
n := negroni.New()
n.Use(negroni.HandlerFunc(MyMiddleware))
```
また、標準パッケージに備わっている`http.Handler`をハンドラチェーンに割り当てることもできます。
``` go
n := negroni.New()
mux := http.NewServeMux()
// ルーティングの処理
n.UseHandler(mux)
http.ListenAndServe(":3000", n)
```
## `Run()`
`Run` はアドレスの文字列を受け取り、[`http.ListenAndServe`](https://godoc.org/net/http#ListenAndServe)と同様にサーバーを起動します。
<!-- { "interrupt": true } -->
``` go
package main
import (
"github.com/urfave/negroni"
)
func main() {
n := negroni.Classic()
n.Run(":8080")
}
```
通常、`net/http` を使用し、ハンドラとして`Negroni`を渡すことになります。
以下により柔軟性のあるサンプルを示します。
<!-- { "interrupt": true } -->
``` go
package main
import (
"fmt"
"log"
"net/http"
"time"
"github.com/urfave/negroni"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, "Welcome to the home page!")
})
n := negroni.Classic()
n.UseHandler(mux)
s := &http.Server{
Addr: ":8080",
Handler: n,
ReadTimeout: 10 * time.Second,
WriteTimeout: 10 * time.Second,
MaxHeaderBytes: 1 << 20,
}
log.Fatal(s.ListenAndServe())
}
```
## Route Specific Middleware
あるルーティンググループにおいて、実行する必要のあるミドルウェアがある場合、
新しいNegroni のインスタンスを作成し、ルーティングハンドラとして使用することができます。
``` go
router := mux.NewRouter()
adminRoutes := mux.NewRouter()
// admin 関連のルーティングをココに記述
// admin のミドルウェアとして、Negroni インスタンスを作成
router.PathPrefix("/admin").Handler(negroni.New(
Middleware1,
Middleware2,
negroni.Wrap(adminRoutes),
))
```
もし[Gorilla Mux]を利用する場合、サブルーターを使うサンプルは以下の通りです。
``` go
router := mux.NewRouter()
subRouter := mux.NewRouter().PathPrefix("/subpath").Subrouter().StrictSlash(true)
subRouter.HandleFunc("/", someSubpathHandler) // "/subpath/"
subRouter.HandleFunc("/:id", someSubpathHandler) // "/subpath/:id"
// "/subpath" is necessary to ensure the subRouter and main router linkup
router.PathPrefix("/subpath").Handler(negroni.New(
Middleware1,
Middleware2,
negroni.Wrap(subRouter),
))
```
## Bundled Middleware
### Static
このミドルウェアは、ファイルシステム上のファイルをサーバーからクライアントに送信します。もし指定されたファイルが存在しない場合、次のミドルウェアにリクエストの処理を依頼します。存在しないファイルへのアクセスに対して`404 Not Found`を返したい場合、 [http.FileServer](https://golang.org/pkg/net/http/#FileServer) をハンドラとして利用すべきです。
Example:
<!-- { "interrupt": true } -->
``` go
package main
import (
"fmt"
"net/http"
"github.com/urfave/negroni"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, "Welcome to the home page!")
})
// Example of using a http.FileServer if you want "server-like" rather than "middleware" behavior
// mux.Handle("/public", http.FileServer(http.Dir("/home/public")))
n := negroni.New()
n.Use(negroni.NewStatic(http.Dir("/tmp")))
n.UseHandler(mux)
http.ListenAndServe(":3002", n)
}
```
まず、`/tmp` ディレクトリからファイルをクライアントに送ろうとしますが、指定されたファイルがファイルシステム上に存在しない場合、プロキシは次のハンドラを呼び出します。
### Recovery
このミドルウェアは、`panic`を受け取ると、`500 internal Server Error` をレスポンスします。
もし他のミドルウェアが既に応答処理を行い、クライアントにHTTP レスポンスが帰っている場合、このミドルウェアは失敗します。
Example:
<!-- { "interrupt": true } -->
``` go
package main
import (
"net/http"
"github.com/urfave/negroni"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
panic("oh no")
})
n := negroni.New()
n.Use(negroni.NewRecovery())
n.UseHandler(mux)
http.ListenAndServe(":3003", n)
}
```
上記のコードは、 `500 Internal Server Error` を各リクエストごとに返します。
また、スタックトレースをログに出力するだけでなく、PrintStack がtrue に設定されている場合、クライアントにスタックトレースを出力します。デフォルトでtrue に設定されています。)
## Logger
このミドルウェアは、送られてきた全てのリクエストとレスポンスを記録します。
Example:
<!-- { "interrupt": true } -->
``` go
package main
import (
"fmt"
"net/http"
"github.com/urfave/negroni"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, "Welcome to the home page!")
})
n := negroni.New()
n.Use(negroni.NewLogger())
n.UseHandler(mux)
http.ListenAndServe(":3004", n)
}
```
各リクエストごとに、以下のようなログが出力されます。
```
[negroni] Started GET /
[negroni] Completed 200 OK in 145.446µs
```
## サードパーティ製のミドルウェア
Negroni と互換性のあるミドルウェアの一覧です。あなたが作ったミドルウェアをここに追加してもらっても構いませんPRを送って下さい
**注意**: ここの一覧は古くなっている可能性があります。英語版のREADME.md を適宜参照して下さい。
| ミドルウェア名 | 作者 | 概要 |
| -----------|--------|-------------|
| [binding](https://github.com/mholt/binding) | [Matt Holt](https://github.com/mholt) | Data binding from HTTP requests into structs |
| [cloudwatch](https://github.com/cvillecsteele/negroni-cloudwatch) | [Colin Steele](https://github.com/cvillecsteele) | AWS cloudwatch metrics middleware |
| [cors](https://github.com/rs/cors) | [Olivier Poitrey](https://github.com/rs) | [Cross Origin Resource Sharing](http://www.w3.org/TR/cors/) (CORS) support |
| [delay](https://github.com/jeffbmartinez/delay) | [Jeff Martinez](https://github.com/jeffbmartinez) | Add delays/latency to endpoints. Useful when testing effects of high latency |
| [gorelic](https://github.com/jingweno/negroni-gorelic) | [Jingwen Owen Ou](https://github.com/jingweno) | New Relic agent for Go runtime |
| [Graceful](https://github.com/tylerb/graceful) | [Tyler Bunnell](https://github.com/tylerb) | Graceful HTTP Shutdown |
| [gzip](https://github.com/phyber/negroni-gzip) | [phyber](https://github.com/phyber) | GZIP response compression |
| [JWT Middleware](https://github.com/auth0/go-jwt-middleware) | [Auth0](https://github.com/auth0) | Middleware checks for a JWT on the `Authorization` header on incoming requests and decodes it|
| [logrus](https://github.com/meatballhat/negroni-logrus) | [Dan Buch](https://github.com/meatballhat) | Logrus-based logger |
| [oauth2](https://github.com/goincremental/negroni-oauth2) | [David Bochenski](https://github.com/bochenski) | oAuth2 middleware |
| [onthefly](https://github.com/xyproto/onthefly) | [Alexander Rødseth](https://github.com/xyproto) | Generate TinySVG, HTML and CSS on the fly |
| [permissions2](https://github.com/xyproto/permissions2) | [Alexander Rødseth](https://github.com/xyproto) | Cookies, users and permissions |
| [prometheus](https://github.com/zbindenren/negroni-prometheus) | [Rene Zbinden](https://github.com/zbindenren) | Easily create metrics endpoint for the [prometheus](http://prometheus.io) instrumentation tool |
| [render](https://github.com/unrolled/render) | [Cory Jacobsen](https://github.com/unrolled) | Render JSON, XML and HTML templates |
| [RestGate](https://github.com/pjebs/restgate) | [Prasanga Siripala](https://github.com/pjebs) | Secure authentication for REST API endpoints |
| [secure](https://github.com/unrolled/secure) | [Cory Jacobsen](https://github.com/unrolled) | Middleware that implements a few quick security wins |
| [sessions](https://github.com/goincremental/negroni-sessions) | [David Bochenski](https://github.com/bochenski) | Session Management |
| [stats](https://github.com/thoas/stats) | [Florent Messa](https://github.com/thoas) | Store information about your web application (response time, etc.) |
| [VanGoH](https://github.com/auroratechnologies/vangoh) | [Taylor Wrobel](https://github.com/twrobel3) | Configurable [AWS-Style](http://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) HMAC authentication middleware |
| [xrequestid](https://github.com/pilu/xrequestid) | [Andrea Franz](https://github.com/pilu) | Middleware that assigns a random X-Request-Id header to each request |
## Examples
[Alexander Rødseth](https://github.com/xyproto) created
[mooseware](https://github.com/xyproto/mooseware), a skeleton for writing a
Negroni middleware handler.
## Live code reload?
[gin](https://github.com/urfave/gin) and
[fresh](https://github.com/pilu/fresh) both live reload negroni apps.
## Go や Negroni の初心者にオススメの参考資料(英語)
* [Using a Context to pass information from middleware to end handler](http://elithrar.github.io/article/map-string-interface/)
* [Understanding middleware](https://mattstauffer.co/blog/laravel-5.0-middleware-filter-style)
## Negroni について
Negroni は他ならぬ[Code
Gangsta](https://codegangsta.io/)によって異常なまでにデザインされた素晴らしいライブラリです。
[Gorilla Mux]: https://github.com/gorilla/mux
[`http.FileSystem`]: https://godoc.org/net/http#FileSystem

View File

@ -1,4 +1,4 @@
# Negroni [![GoDoc](https://godoc.org/github.com/urfave/negroni?status.svg)](http://godoc.org/github.com/urfave/negroni) [![wercker status](https://app.wercker.com/status/13688a4a94b82d84a0b8d038c4965b61/s "wercker status")](https://app.wercker.com/project/bykey/13688a4a94b82d84a0b8d038c4965b61)
# Negroni [![GoDoc](https://godoc.org/github.com/codegangsta/negroni?status.svg)](http://godoc.org/github.com/codegangsta/negroni) [![wercker status](https://app.wercker.com/status/13688a4a94b82d84a0b8d038c4965b61/s "wercker status")](https://app.wercker.com/project/bykey/13688a4a94b82d84a0b8d038c4965b61)
Negroni é uma abordagem idiomática para middleware web em Go. É pequeno, não intrusivo, e incentiva uso da biblioteca `net/http`.
@ -12,7 +12,7 @@ Depois de instalar Go e definir seu [GOPATH](http://golang.org/doc/code.html#GOP
package main
import (
"github.com/urfave/negroni"
"github.com/codegangsta/negroni"
"net/http"
"fmt"
)
@ -31,7 +31,7 @@ func main() {
Depois instale o pacote Negroni (**go 1.1** ou superior)
~~~
go get github.com/urfave/negroni
go get github.com/codegangsta/negroni
~~~
Depois execute seu servidor:
@ -159,7 +159,7 @@ Aqui está uma lista atual de Middleware Compatíveis com Negroni. Sinta se livr
[Alexander Rødseth](https://github.com/xyproto) criou [mooseware](https://github.com/xyproto/mooseware), uma estrutura para escrever um handler middleware Negroni.
## Servidor com autoreload?
[gin](https://github.com/urfave/gin) e [fresh](https://github.com/pilu/fresh) são aplicativos para autoreload do Negroni.
[gin](https://github.com/codegangsta/gin) e [fresh](https://github.com/pilu/fresh) são aplicativos para autoreload do Negroni.
## Leitura Essencial para Iniciantes em Go & Negroni
* [Usando um contexto para passar informação de um middleware para o manipulador final](http://elithrar.github.io/article/map-string-interface/)

View File

@ -1,182 +0,0 @@
# Negroni [![GoDoc](https://godoc.org/github.com/urfave/negroni?status.svg)](http://godoc.org/github.com/urfave/negroni) [![wercker status](https://app.wercker.com/status/13688a4a94b82d84a0b8d038c4965b61/s "wercker status")](https://app.wercker.com/project/bykey/13688a4a94b82d84a0b8d038c4965b61)
在Go语言里Negroni 是一个很地道的 web 中间件,它是微型,非嵌入式,并鼓励使用原生 `net/http` 处理器的库。
如果你用过并喜欢 [Martini](http://github.com/go-martini/martini) 框架,但又不想框架中有太多魔幻性的特征,那 Negroni 就是你的菜了,相信它非常适合你。
语言翻译:
* [Português Brasileiro (pt_BR)](translations/README_pt_br.md)
* [简体中文 (zh_CN)](translations/README_zh_cn.md)
## 入门指导
当安装了 Go 语言并设置好了 [GOPATH](http://golang.org/doc/code.html#GOPATH) 后,新建你第一个`.go` 文件,我们叫它 `server.go` 吧。
~~~ go
package main
import (
"github.com/urfave/negroni"
"net/http"
"fmt"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, "Welcome to the home page!")
})
n := negroni.Classic()
n.UseHandler(mux)
n.Run(":3000")
}
~~~
然后安装 Negroni 包(它依赖 **Go 1.1** 或更高的版本):
~~~
go get github.com/urfave/negroni
~~~
然后运行刚建好的 server.go 文件:
~~~
go run server.go
~~~
这时一个 Go `net/http` Web 服务器就跑在 `localhost:3000` 上,使用浏览器打开 `localhost:3000` 可以看到输出结果。
## 需要你的贡献
如果你有问题或新想法,请到[邮件群组](https://groups.google.com/forum/#!forum/negroni-users)里反馈GitHub issues 是专门给提交 bug 报告和 pull 请求用途的,欢迎你的参与。
## Negroni 是一个框架吗?
Negroni **不**是一个框架,它是为了方便使用 `net/http` 而设计的一个库而已。
## 路由呢?
Negroni 没有带路由功能,使用 Negroni 时,需要找一个适合你的路由。不过好在 Go 社区里已经有相当多可用的路由Negroni 更喜欢和那些完全支持 `net/http` 库的路由组合使用,比如,结合 [Gorilla Mux](http://github.com/gorilla/mux) 使用像这样:
~~~ go
router := mux.NewRouter()
router.HandleFunc("/", HomeHandler)
n := negroni.New(Middleware1, Middleware2)
// Or use a middleware with the Use() function
n.Use(Middleware3)
// router goes last
n.UseHandler(router)
n.Run(":3000")
~~~
## `negroni.Classic()` 经典实例
`negroni.Classic()` 提供一些默认的中间件,这些中间件在多数应用都很有用。
* `negroni.Recovery` - 异常(恐慌)恢复中间件
* `negroni.Logging` - 请求 / 响应 log 日志中间件
* `negroni.Static` - 静态文件处理中间件,默认目录在 "public" 下.
`negroni.Classic()` 让你一开始就非常容易上手 Negroni ,并使用它那些通用的功能。
## Handlers (处理器)
Negroni 提供双向的中间件机制,这个特征很棒,都是得益于 `negroni.Handler` 这个接口。
~~~ go
type Handler interface {
ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
}
~~~
如果一个中间件没有写入 ResponseWriter 响应,它会在中间件链里调用下一个 `http.HandlerFunc` 执行下去, 它可以这么优雅的使用。如下:
~~~ go
func MyMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
// do some stuff before
next(rw, r)
// do some stuff after
}
~~~
你也可以用 `Use` 函数把这些 `http.Handler` 处理器引进到处理器链上来:
~~~ go
n := negroni.New()
n.Use(negroni.HandlerFunc(MyMiddleware))
~~~
你还可以使用 `http.Handler`(s) 把 `http.Handler` 处理器引进来。
~~~ go
n := negroni.New()
mux := http.NewServeMux()
// map your routes
n.UseHandler(mux)
n.Run(":3000")
~~~
## `Run()`
Negroni 提供一个很好用的函数叫 `Run` ,把地址字符串传人该函数,即可实现很地道的 [http.ListenAndServe](http://golang.org/pkg/net/http#ListenAndServe) 函数功能了。
~~~ go
n := negroni.Classic()
// ...
log.Fatal(http.ListenAndServe(":8080", n))
~~~
## 特定路由中间件
如果你需要群组路由功能,需要借助特定的路由中间件完成,做法很简单,只需建立一个新 Negroni 实例,传人路由处理器里即可。
~~~ go
router := mux.NewRouter()
adminRoutes := mux.NewRouter()
// add admin routes here
// Create a new negroni for the admin middleware
router.Handle("/admin", negroni.New(
Middleware1,
Middleware2,
negroni.Wrap(adminRoutes),
))
~~~
## 第三方中间件
以下的兼容 Negroni 的中间件列表,如果你也有兼容 Negroni 的中间件,可以提交到这个列表来交换链接,我们很乐意做这样有益的事情。
| 中间件 | 作者 | 描述 |
| -------------|------------|-------------|
| [RestGate](https://github.com/pjebs/restgate) | [Prasanga Siripala](https://github.com/pjebs) | REST API 接口的安全认证 |
| [Graceful](https://github.com/stretchr/graceful) | [Tyler Bunnell](https://github.com/tylerb) | 优雅关闭 HTTP 的中间件 |
| [secure](https://github.com/unrolled/secure) | [Cory Jacobsen](https://github.com/unrolled) | Middleware that implements a few quick security wins |
| [JWT Middleware](https://github.com/auth0/go-jwt-middleware) | [Auth0](https://github.com/auth0) | Middleware checks for a JWT on the `Authorization` header on incoming requests and decodes it|
| [binding](https://github.com/mholt/binding) | [Matt Holt](https://github.com/mholt) | HTTP 请求数据注入到 structs 实体|
| [logrus](https://github.com/meatballhat/negroni-logrus) | [Dan Buch](https://github.com/meatballhat) | 基于 Logrus-based logger 日志 |
| [render](https://github.com/unrolled/render) | [Cory Jacobsen](https://github.com/unrolled) | 渲染 JSON, XML and HTML 中间件 |
| [gorelic](https://github.com/jingweno/negroni-gorelic) | [Jingwen Owen Ou](https://github.com/jingweno) | New Relic agent for Go runtime |
| [gzip](https://github.com/phyber/negroni-gzip) | [phyber](https://github.com/phyber) | 响应流 GZIP 压缩 |
| [oauth2](https://github.com/goincremental/negroni-oauth2) | [David Bochenski](https://github.com/bochenski) | oAuth2 中间件 |
| [sessions](https://github.com/goincremental/negroni-sessions) | [David Bochenski](https://github.com/bochenski) | Session 会话管理 |
| [permissions2](https://github.com/xyproto/permissions2) | [Alexander Rødseth](https://github.com/xyproto) | Cookies 用户和权限 |
| [onthefly](https://github.com/xyproto/onthefly) | [Alexander Rødseth](https://github.com/xyproto) | 快速生成 TinySVG HTML and CSS 中间件 |
| [cors](https://github.com/rs/cors) | [Olivier Poitrey](https://github.com/rs) | [Cross Origin Resource Sharing](http://www.w3.org/TR/cors/) (CORS) support |
| [xrequestid](https://github.com/pilu/xrequestid) | [Andrea Franz](https://github.com/pilu) | 给每个请求指定一个随机 X-Request-Id 头的中间件 |
| [VanGoH](https://github.com/auroratechnologies/vangoh) | [Taylor Wrobel](https://github.com/twrobel3) | Configurable [AWS-Style](http://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) 基于 HMAC 鉴权认证的中间件 |
| [stats](https://github.com/thoas/stats) | [Florent Messa](https://github.com/thoas) | 检测 web 应用当前运行状态信息 (响应时间等等。) |
## 范例
[Alexander Rødseth](https://github.com/xyproto) 创建的 [mooseware](https://github.com/xyproto/mooseware) 是一个写兼容 Negroni 中间件的处理器骨架的范例。
## 即时编译
[gin](https://github.com/urfave/gin) 和 [fresh](https://github.com/pilu/fresh) 这两个应用是即时编译的 Negroni 工具,推荐用户开发的时候使用。
## Go & Negroni 初学者必读推荐
* [在中间件中使用上下文把消息传递给后端处理器](http://elithrar.github.io/article/map-string-interface/)
* [了解中间件](http://mattstauffer.co/blog/laravel-5.0-middleware-replacing-filters)
## 关于
Negroni 由 [Code Gangsta](http://codegangsta.io/) 主导设计开发完成

View File

@ -1,177 +0,0 @@
# Negroni(尼格龍尼) [![GoDoc](https://godoc.org/github.com/urfave/negroni?status.svg)](http://godoc.org/github.com/urfave/negroni) [![wercker status](https://app.wercker.com/status/13688a4a94b82d84a0b8d038c4965b61/s "wercker status")](https://app.wercker.com/project/bykey/13688a4a94b82d84a0b8d038c4965b61)
尼格龍尼符合Go的web 中介器特性. 精簡、非侵入式、鼓勵使用 `net/http` Handler.
如果你喜歡[Martini](http://github.com/go-martini/martini),但覺得這其中包太多神奇的功能,那麼尼格龍尼會是你的最佳選擇。
## 入門
安裝完Go且設好[GOPATH](http://golang.org/doc/code.html#GOPATH),建立你的第一個`.go`檔。可以命名為`server.go`。
~~~ go
package main
import (
"github.com/urfave/negroni"
"net/http"
"fmt"
)
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, req *http.Request) {
fmt.Fprintf(w, "Welcome to the home page!")
})
n := negroni.Classic()
n.UseHandler(mux)
n.Run(":3000")
}
~~~
安裝尼格龍尼套件 (最低需求為**go 1.1**或更高版本):
~~~
go get github.com/urfave/negroni
~~~
執行伺服器:
~~~
go run server.go
~~~
你現在起了一個Go的net/http網頁伺服器在`localhost:3000`.
## 有問題?
如果你有問題或新功能建議,[到這郵件群組討論](https://groups.google.com/forum/#!forum/negroni-users)。尼格龍尼在GitHub上的issues專欄是專門用來回報bug跟pull requests。
## 尼格龍尼是個framework嗎?
尼格龍尼**不是**framework是個設計用來直接使用net/http的library。
## 路由?
尼格龍尼是BYOR (Bring your own Router帶給你自訂路由)。在Go社群已經有大量可用的http路由器, 尼格龍尼試著做好完全支援`net/http`,例如與[Gorilla Mux](http://github.com/gorilla/mux)整合:
~~~ go
router := mux.NewRouter()
router.HandleFunc("/", HomeHandler)
n := negroni.New(中介器1, 中介器2)
// Or use a 中介器 with the Use() function
n.Use(中介器3)
// router goes last
n.UseHandler(router)
n.Run(":3000")
~~~
## `negroni.Classic()`
`negroni.Classic()` 提供一些好用的預設中介器:
* `negroni.Recovery` - Panic 還原中介器
* `negroni.Logging` - Request/Response 紀錄中介器
* `negroni.Static` - 在"public"目錄下的靜態檔案服務
尼格龍尼的這些功能讓你開發變得很簡單。
## 處理器(Handlers)
尼格龍尼提供一個雙向中介器的機制,介面為`negroni.Handler`:
~~~ go
type Handler interface {
ServeHTTP(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc)
}
~~~
如果中介器沒有寫入ResponseWriter會呼叫通道裡面的下個`http.HandlerFunc`讓給中介處理器。可以被用來做良好的應用:
~~~ go
func MyMiddleware(rw http.ResponseWriter, r *http.Request, next http.HandlerFunc) {
// 在這之前做一些事
next(rw, r)
// 在這之後做一些事
}
~~~
然後你可以透過`Use`函數對應到處理器的通道:
~~~ go
n := negroni.New()
n.Use(negroni.HandlerFunc(MyMiddleware))
~~~
你也可以應原始的舊`http.Handler`:
~~~ go
n := negroni.New()
mux := http.NewServeMux()
// map your routes
n.UseHandler(mux)
n.Run(":3000")
~~~
## `Run()`
尼格龍尼有一個很好用的函數`Run``Run`接收addr字串辨識[http.ListenAndServe](http://golang.org/pkg/net/http#ListenAndServe)。
~~~ go
n := negroni.Classic()
// ...
log.Fatal(http.ListenAndServe(":8080", n))
~~~
## 路由特有中介器
如果你有一群路由需要執行特別的中介器,你可以簡單的建立一個新的尼格龍尼實體當作路由處理器。
~~~ go
router := mux.NewRouter()
adminRoutes := mux.NewRouter()
// add admin routes here
// 為管理中介器建立一個新的尼格龍尼
router.Handle("/admin", negroni.New(
Middleware1,
Middleware2,
negroni.Wrap(adminRoutes),
))
~~~
## 第三方中介器
以下為目前尼格龍尼兼容的中介器清單。如果你自己做了一個中介器請自由放入你的中介器互換連結:
| 中介器 | 作者 | 說明 |
| -----------|--------|-------------|
| [RestGate](https://github.com/pjebs/restgate) | [Prasanga Siripala](https://github.com/pjebs) | REST API入口的安全認證 |
| [Graceful](https://github.com/stretchr/graceful) | [Tyler Bunnell](https://github.com/tylerb) | 優雅的HTTP關機 |
| [secure](https://github.com/unrolled/secure) | [Cory Jacobsen](https://github.com/unrolled) | 檢疫安全功能的中介器 |
| [JWT Middleware](https://github.com/auth0/go-jwt-middleware) | [Auth0](https://github.com/auth0) | 檢查JWT的中介器用來解析傳入請求的`Authorization` header |
| [binding](https://github.com/mholt/binding) | [Matt Holt](https://github.com/mholt) | 將HTTP請求轉到structs的資料綁定 |
| [logrus](https://github.com/meatballhat/negroni-logrus) | [Dan Buch](https://github.com/meatballhat) | 基於Logrus的紀錄器 |
| [render](https://github.com/unrolled/render) | [Cory Jacobsen](https://github.com/unrolled) | 渲染JSON、XML、HTML的樣板 |
| [gorelic](https://github.com/jingweno/negroni-gorelic) | [Jingwen Owen Ou](https://github.com/jingweno) | Go執行中的New Relic agent |
| [gzip](https://github.com/phyber/negroni-gzip) | [phyber](https://github.com/phyber) | GZIP回應壓縮 |
| [oauth2](https://github.com/goincremental/negroni-oauth2) | [David Bochenski](https://github.com/bochenski) | oAuth2中介器 |
| [sessions](https://github.com/goincremental/negroni-sessions) | [David Bochenski](https://github.com/bochenski) | Session管理 |
| [permissions2](https://github.com/xyproto/permissions2) | [Alexander Rødseth](https://github.com/xyproto) | Cookies與使用者權限 |
| [onthefly](https://github.com/xyproto/onthefly) | [Alexander Rødseth](https://github.com/xyproto) | 快速產生TinySVG、HTM、CSS |
| [cors](https://github.com/rs/cors) | [Olivier Poitrey](https://github.com/rs) | [Cross Origin Resource Sharing](http://www.w3.org/TR/cors/) 支援(CORS) |
| [xrequestid](https://github.com/pilu/xrequestid) | [Andrea Franz](https://github.com/pilu) | 在每個request指定一個隨機X-Request-Id header的中介器 |
| [VanGoH](https://github.com/auroratechnologies/vangoh) | [Taylor Wrobel](https://github.com/twrobel3) | Configurable [AWS-Style](http://docs.aws.amazon.com/AmazonS3/latest/dev/RESTAuthentication.html) HMAC 授權中介器 |
| [stats](https://github.com/thoas/stats) | [Florent Messa](https://github.com/thoas) | 儲存關於你的網頁應用資訊 (回應時間之類) |
## 範例
[mooseware](https://github.com/xyproto/mooseware)是用來寫尼格龍尼中介處理器的骨架,由[Alexander Rødseth](https://github.com/xyproto)建立。
## 即時程式重載?
[gin](https://github.com/urfave/gin)和[fresh](https://github.com/pilu/fresh)兩個尼格龍尼即時重載的應用。
## Go & 尼格龍尼初學者必讀
* [使用Context將資訊從中介器送到處理器](http://elithrar.github.io/article/map-string-interface/)
* [理解中介器](http://mattstauffer.co/blog/laravel-5.0-middleware-replacing-filters)
## 關於
尼格龍尼正是[Code Gangsta](http://codegangsta.io/)的執著設計。
譯者: Festum Qin (Festum@G.PL)

View File

@ -86,7 +86,7 @@ type Conn struct {
// New establishes a connection to the system bus and authenticates.
// Callers should call Close() when done with the connection.
func New() (*Conn, error) {
return NewConnection(func() (*dbus.Conn, error) {
return newConnection(func() (*dbus.Conn, error) {
return dbusAuthHelloConnection(dbus.SystemBusPrivate)
})
}
@ -95,7 +95,7 @@ func New() (*Conn, error) {
// authenticates. This can be used to connect to systemd user instances.
// Callers should call Close() when done with the connection.
func NewUserConnection() (*Conn, error) {
return NewConnection(func() (*dbus.Conn, error) {
return newConnection(func() (*dbus.Conn, error) {
return dbusAuthHelloConnection(dbus.SessionBusPrivate)
})
}
@ -104,7 +104,7 @@ func NewUserConnection() (*Conn, error) {
// This can be used for communicating with systemd without a dbus daemon.
// Callers should call Close() when done with the connection.
func NewSystemdConnection() (*Conn, error) {
return NewConnection(func() (*dbus.Conn, error) {
return newConnection(func() (*dbus.Conn, error) {
// We skip Hello when talking directly to systemd.
return dbusAuthConnection(func() (*dbus.Conn, error) {
return dbus.Dial("unix:path=/run/systemd/private")
@ -118,18 +118,13 @@ func (c *Conn) Close() {
c.sigconn.Close()
}
// NewConnection establishes a connection to a bus using a caller-supplied function.
// This allows connecting to remote buses through a user-supplied mechanism.
// The supplied function may be called multiple times, and should return independent connections.
// The returned connection must be fully initialised: the org.freedesktop.DBus.Hello call must have succeeded,
// and any authentication should be handled by the function.
func NewConnection(dialBus func() (*dbus.Conn, error)) (*Conn, error) {
sysconn, err := dialBus()
func newConnection(createBus func() (*dbus.Conn, error)) (*Conn, error) {
sysconn, err := createBus()
if err != nil {
return nil, err
}
sigconn, err := dialBus()
sigconn, err := createBus()
if err != nil {
sysconn.Close()
return nil, err

View File

@ -199,11 +199,6 @@ func (c *Conn) GetUnitProperty(unit string, propertyName string) (*Property, err
return c.getProperty(unit, "org.freedesktop.systemd1.Unit", propertyName)
}
// GetServiceProperty returns property for given service name and property name
func (c *Conn) GetServiceProperty(service string, propertyName string) (*Property, error) {
return c.getProperty(service, "org.freedesktop.systemd1.Service", propertyName)
}
// GetUnitTypeProperties returns the extra properties for a unit, specific to the unit type.
// Valid values for unitType: Service, Socket, Target, Device, Mount, Automount, Snapshot, Timer, Swap, Path, Slice, Scope
// return "dbus.Error: Unknown interface" if the unitType is not the correct type of the unit
@ -239,11 +234,12 @@ type UnitStatus struct {
JobPath dbus.ObjectPath // The job object path
}
type storeFunc func(retvalues ...interface{}) error
func (c *Conn) listUnitsInternal(f storeFunc) ([]UnitStatus, error) {
// ListUnits returns an array with all currently loaded units. Note that
// units may be known by multiple names at the same time, and hence there might
// be more unit names loaded than actual units behind them.
func (c *Conn) ListUnits() ([]UnitStatus, error) {
result := make([][]interface{}, 0)
err := f(&result)
err := c.sysobj.Call("org.freedesktop.systemd1.Manager.ListUnits", 0).Store(&result)
if err != nil {
return nil, err
}
@ -267,43 +263,15 @@ func (c *Conn) listUnitsInternal(f storeFunc) ([]UnitStatus, error) {
return status, nil
}
// ListUnits returns an array with all currently loaded units. Note that
// units may be known by multiple names at the same time, and hence there might
// be more unit names loaded than actual units behind them.
func (c *Conn) ListUnits() ([]UnitStatus, error) {
return c.listUnitsInternal(c.sysobj.Call("org.freedesktop.systemd1.Manager.ListUnits", 0).Store)
}
// ListUnitsFiltered returns an array with units filtered by state.
// It takes a list of units' statuses to filter.
func (c *Conn) ListUnitsFiltered(states []string) ([]UnitStatus, error) {
return c.listUnitsInternal(c.sysobj.Call("org.freedesktop.systemd1.Manager.ListUnitsFiltered", 0, states).Store)
}
// ListUnitsByPatterns returns an array with units.
// It takes a list of units' statuses and names to filter.
// Note that units may be known by multiple names at the same time,
// and hence there might be more unit names loaded than actual units behind them.
func (c *Conn) ListUnitsByPatterns(states []string, patterns []string) ([]UnitStatus, error) {
return c.listUnitsInternal(c.sysobj.Call("org.freedesktop.systemd1.Manager.ListUnitsByPatterns", 0, states, patterns).Store)
}
// ListUnitsByNames returns an array with units. It takes a list of units'
// names and returns an UnitStatus array. Comparing to ListUnitsByPatterns
// method, this method returns statuses even for inactive or non-existing
// units. Input array should contain exact unit names, but not patterns.
func (c *Conn) ListUnitsByNames(units []string) ([]UnitStatus, error) {
return c.listUnitsInternal(c.sysobj.Call("org.freedesktop.systemd1.Manager.ListUnitsByNames", 0, units).Store)
}
type UnitFile struct {
Path string
Type string
}
func (c *Conn) listUnitFilesInternal(f storeFunc) ([]UnitFile, error) {
// ListUnitFiles returns an array of all available units on disk.
func (c *Conn) ListUnitFiles() ([]UnitFile, error) {
result := make([][]interface{}, 0)
err := f(&result)
err := c.sysobj.Call("org.freedesktop.systemd1.Manager.ListUnitFiles", 0).Store(&result)
if err != nil {
return nil, err
}
@ -327,16 +295,6 @@ func (c *Conn) listUnitFilesInternal(f storeFunc) ([]UnitFile, error) {
return files, nil
}
// ListUnitFiles returns an array of all available units on disk.
func (c *Conn) ListUnitFiles() ([]UnitFile, error) {
return c.listUnitFilesInternal(c.sysobj.Call("org.freedesktop.systemd1.Manager.ListUnitFiles", 0).Store)
}
// ListUnitFilesByPatterns returns an array of all available units on disk matched the patterns.
func (c *Conn) ListUnitFilesByPatterns(states []string, patterns []string) ([]UnitFile, error) {
return c.listUnitFilesInternal(c.sysobj.Call("org.freedesktop.systemd1.Manager.ListUnitFilesByPatterns", 0, states, patterns).Store)
}
type LinkUnitFileChange EnableUnitFileChange
// LinkUnitFiles() links unit files (that are located outside of the

View File

@ -18,7 +18,6 @@ import (
"fmt"
"math/rand"
"os"
"path"
"path/filepath"
"reflect"
"testing"
@ -71,24 +70,6 @@ func linkUnit(target string, conn *Conn, t *testing.T) {
}
}
func getUnitStatus(units []UnitStatus, name string) *UnitStatus {
for _, u := range units {
if u.Name == name {
return &u
}
}
return nil
}
func getUnitFile(units []UnitFile, name string) *UnitFile {
for _, u := range units {
if path.Base(u.Path) == name {
return &u
}
}
return nil
}
// Ensure that basic unit starting and stopping works.
func TestStartStopUnit(t *testing.T) {
target := "start-stop.service"
@ -111,11 +92,18 @@ func TestStartStopUnit(t *testing.T) {
units, err := conn.ListUnits()
unit := getUnitStatus(units, target)
var unit *UnitStatus
for _, u := range units {
if u.Name == target {
unit = &u
}
}
if unit == nil {
t.Fatalf("Test unit not found in list")
} else if unit.ActiveState != "active" {
}
if unit.ActiveState != "active" {
t.Fatalf("Test unit not active")
}
@ -130,169 +118,18 @@ func TestStartStopUnit(t *testing.T) {
units, err = conn.ListUnits()
unit = getUnitStatus(units, target)
unit = nil
for _, u := range units {
if u.Name == target {
unit = &u
}
}
if unit != nil {
t.Fatalf("Test unit found in list, should be stopped")
}
}
// Ensure that ListUnitsByNames works.
func TestListUnitsByNames(t *testing.T) {
target1 := "systemd-journald.service"
target2 := "unexisting.service"
conn := setupConn(t)
units, err := conn.ListUnitsByNames([]string{target1, target2})
if err != nil {
t.Skip(err)
}
unit := getUnitStatus(units, target1)
if unit == nil {
t.Fatalf("%s unit not found in list", target1)
} else if unit.ActiveState != "active" {
t.Fatalf("%s unit should be active but it is %s", target1, unit.ActiveState)
}
unit = getUnitStatus(units, target2)
if unit == nil {
t.Fatalf("Unexisting test unit not found in list")
} else if unit.ActiveState != "inactive" {
t.Fatalf("Test unit should be inactive")
}
}
// Ensure that ListUnitsByPatterns works.
func TestListUnitsByPatterns(t *testing.T) {
target1 := "systemd-journald.service"
target2 := "unexisting.service"
conn := setupConn(t)
units, err := conn.ListUnitsByPatterns([]string{}, []string{"systemd-journald*", target2})
if err != nil {
t.Skip(err)
}
unit := getUnitStatus(units, target1)
if unit == nil {
t.Fatalf("%s unit not found in list", target1)
} else if unit.ActiveState != "active" {
t.Fatalf("Test unit should be active")
}
unit = getUnitStatus(units, target2)
if unit != nil {
t.Fatalf("Unexisting test unit found in list")
}
}
// Ensure that ListUnitsFiltered works.
func TestListUnitsFiltered(t *testing.T) {
target := "systemd-journald.service"
conn := setupConn(t)
units, err := conn.ListUnitsFiltered([]string{"active"})
if err != nil {
t.Fatal(err)
}
unit := getUnitStatus(units, target)
if unit == nil {
t.Fatalf("%s unit not found in list", target)
} else if unit.ActiveState != "active" {
t.Fatalf("Test unit should be active")
}
units, err = conn.ListUnitsFiltered([]string{"inactive"})
if err != nil {
t.Fatal(err)
}
unit = getUnitStatus(units, target)
if unit != nil {
t.Fatalf("Inactive unit should not be found in list")
}
}
// Ensure that ListUnitFilesByPatterns works.
func TestListUnitFilesByPatterns(t *testing.T) {
target1 := "systemd-journald.service"
target2 := "exit.target"
conn := setupConn(t)
units, err := conn.ListUnitFilesByPatterns([]string{"static"}, []string{"systemd-journald*", target2})
if err != nil {
t.Skip(err)
}
unit := getUnitFile(units, target1)
if unit == nil {
t.Fatalf("%s unit not found in list", target1)
} else if unit.Type != "static" {
t.Fatalf("Test unit file should be static")
}
units, err = conn.ListUnitFilesByPatterns([]string{"disabled"}, []string{"systemd-journald*", target2})
if err != nil {
t.Fatal(err)
}
unit = getUnitFile(units, target2)
if unit == nil {
t.Fatalf("%s unit not found in list", target2)
} else if unit.Type != "disabled" {
t.Fatalf("%s unit file should be disabled", target2)
}
}
func TestListUnitFiles(t *testing.T) {
target1 := "systemd-journald.service"
target2 := "exit.target"
conn := setupConn(t)
units, err := conn.ListUnitFiles()
if err != nil {
t.Fatal(err)
}
unit := getUnitFile(units, target1)
if unit == nil {
t.Fatalf("%s unit not found in list", target1)
} else if unit.Type != "static" {
t.Fatalf("Test unit file should be static")
}
unit = getUnitFile(units, target2)
if unit == nil {
t.Fatalf("%s unit not found in list", target2)
} else if unit.Type != "disabled" {
t.Fatalf("%s unit file should be disabled", target2)
}
}
// Enables a unit and then immediately tears it down
func TestEnableDisableUnit(t *testing.T) {
target := "enable-disable.service"
@ -393,28 +230,6 @@ func TestGetUnitPropertiesRejectsInvalidName(t *testing.T) {
}
}
// TestGetServiceProperty reads the `systemd-udevd.service` which should exist
// on all systemd systems and ensures that one of its property is valid.
func TestGetServiceProperty(t *testing.T) {
conn := setupConn(t)
service := "systemd-udevd.service"
prop, err := conn.GetServiceProperty(service, "Type")
if err != nil {
t.Fatal(err)
}
if prop.Name != "Type" {
t.Fatal("unexpected property name")
}
value := prop.Value.Value().(string)
if value != "notify" {
t.Fatal("unexpected property value")
}
}
// TestSetUnitProperties changes a cgroup setting on the `tmp.mount`
// which should exist on all systemd systems and ensures that the
// property was set.
@ -461,11 +276,18 @@ func TestStartStopTransientUnit(t *testing.T) {
units, err := conn.ListUnits()
unit := getUnitStatus(units, target)
var unit *UnitStatus
for _, u := range units {
if u.Name == target {
unit = &u
}
}
if unit == nil {
t.Fatalf("Test unit not found in list")
} else if unit.ActiveState != "active" {
}
if unit.ActiveState != "active" {
t.Fatalf("Test unit not active")
}
@ -480,7 +302,12 @@ func TestStartStopTransientUnit(t *testing.T) {
units, err = conn.ListUnits()
unit = getUnitStatus(units, target)
unit = nil
for _, u := range units {
if u.Name == target {
unit = &u
}
}
if unit != nil {
t.Fatalf("Test unit found in list, should be stopped")

View File

@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// +build ignore
// Activation example used by the activation unit tests.
package main

View File

@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// +build ignore
package main
import (

View File

@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// +build ignore
// Activation example used by the activation unit tests.
package main

View File

@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// +build ignore
// Activation example used by the activation unit tests.
package main

View File

@ -90,7 +90,6 @@ func Send(message string, priority Priority, vars map[string]string) error {
if err != nil {
return journalError(err.Error())
}
defer file.Close()
_, err = io.Copy(file, data)
if err != nil {
return journalError(err.Error())

View File

@ -24,236 +24,31 @@
// [1] http://www.freedesktop.org/software/systemd/man/sd-journal.html
package sdjournal
// #include <systemd/sd-journal.h>
// #include <stdlib.h>
// #include <syslog.h>
//
// int
// my_sd_journal_open(void *f, sd_journal **ret, int flags)
// {
// int (*sd_journal_open)(sd_journal **, int);
//
// sd_journal_open = f;
// return sd_journal_open(ret, flags);
// }
//
// int
// my_sd_journal_open_directory(void *f, sd_journal **ret, const char *path, int flags)
// {
// int (*sd_journal_open_directory)(sd_journal **, const char *, int);
//
// sd_journal_open_directory = f;
// return sd_journal_open_directory(ret, path, flags);
// }
//
// void
// my_sd_journal_close(void *f, sd_journal *j)
// {
// int (*sd_journal_close)(sd_journal *);
//
// sd_journal_close = f;
// sd_journal_close(j);
// }
//
// int
// my_sd_journal_get_usage(void *f, sd_journal *j, uint64_t *bytes)
// {
// int (*sd_journal_get_usage)(sd_journal *, uint64_t *);
//
// sd_journal_get_usage = f;
// return sd_journal_get_usage(j, bytes);
// }
//
// int
// my_sd_journal_add_match(void *f, sd_journal *j, const void *data, size_t size)
// {
// int (*sd_journal_add_match)(sd_journal *, const void *, size_t);
//
// sd_journal_add_match = f;
// return sd_journal_add_match(j, data, size);
// }
//
// int
// my_sd_journal_add_disjunction(void *f, sd_journal *j)
// {
// int (*sd_journal_add_disjunction)(sd_journal *);
//
// sd_journal_add_disjunction = f;
// return sd_journal_add_disjunction(j);
// }
//
// int
// my_sd_journal_add_conjunction(void *f, sd_journal *j)
// {
// int (*sd_journal_add_conjunction)(sd_journal *);
//
// sd_journal_add_conjunction = f;
// return sd_journal_add_conjunction(j);
// }
//
// void
// my_sd_journal_flush_matches(void *f, sd_journal *j)
// {
// int (*sd_journal_flush_matches)(sd_journal *);
//
// sd_journal_flush_matches = f;
// sd_journal_flush_matches(j);
// }
//
// int
// my_sd_journal_next(void *f, sd_journal *j)
// {
// int (*sd_journal_next)(sd_journal *);
//
// sd_journal_next = f;
// return sd_journal_next(j);
// }
//
// int
// my_sd_journal_next_skip(void *f, sd_journal *j, uint64_t skip)
// {
// int (*sd_journal_next_skip)(sd_journal *, uint64_t);
//
// sd_journal_next_skip = f;
// return sd_journal_next_skip(j, skip);
// }
//
// int
// my_sd_journal_previous(void *f, sd_journal *j)
// {
// int (*sd_journal_previous)(sd_journal *);
//
// sd_journal_previous = f;
// return sd_journal_previous(j);
// }
//
// int
// my_sd_journal_previous_skip(void *f, sd_journal *j, uint64_t skip)
// {
// int (*sd_journal_previous_skip)(sd_journal *, uint64_t);
//
// sd_journal_previous_skip = f;
// return sd_journal_previous_skip(j, skip);
// }
//
// int
// my_sd_journal_get_data(void *f, sd_journal *j, const char *field, const void **data, size_t *length)
// {
// int (*sd_journal_get_data)(sd_journal *, const char *, const void **, size_t *);
//
// sd_journal_get_data = f;
// return sd_journal_get_data(j, field, data, length);
// }
//
// int
// my_sd_journal_set_data_threshold(void *f, sd_journal *j, size_t sz)
// {
// int (*sd_journal_set_data_threshold)(sd_journal *, size_t);
//
// sd_journal_set_data_threshold = f;
// return sd_journal_set_data_threshold(j, sz);
// }
//
// int
// my_sd_journal_get_cursor(void *f, sd_journal *j, char **cursor)
// {
// int (*sd_journal_get_cursor)(sd_journal *, char **);
//
// sd_journal_get_cursor = f;
// return sd_journal_get_cursor(j, cursor);
// }
//
// int
// my_sd_journal_test_cursor(void *f, sd_journal *j, const char *cursor)
// {
// int (*sd_journal_test_cursor)(sd_journal *, const char *);
//
// sd_journal_test_cursor = f;
// return sd_journal_test_cursor(j, cursor);
// }
//
// int
// my_sd_journal_get_realtime_usec(void *f, sd_journal *j, uint64_t *usec)
// {
// int (*sd_journal_get_realtime_usec)(sd_journal *, uint64_t *);
//
// sd_journal_get_realtime_usec = f;
// return sd_journal_get_realtime_usec(j, usec);
// }
//
// int
// my_sd_journal_seek_head(void *f, sd_journal *j)
// {
// int (*sd_journal_seek_head)(sd_journal *);
//
// sd_journal_seek_head = f;
// return sd_journal_seek_head(j);
// }
//
// int
// my_sd_journal_seek_tail(void *f, sd_journal *j)
// {
// int (*sd_journal_seek_tail)(sd_journal *);
//
// sd_journal_seek_tail = f;
// return sd_journal_seek_tail(j);
// }
//
//
// int
// my_sd_journal_seek_cursor(void *f, sd_journal *j, const char *cursor)
// {
// int (*sd_journal_seek_cursor)(sd_journal *, const char *);
//
// sd_journal_seek_cursor = f;
// return sd_journal_seek_cursor(j, cursor);
// }
//
// int
// my_sd_journal_seek_realtime_usec(void *f, sd_journal *j, uint64_t usec)
// {
// int (*sd_journal_seek_realtime_usec)(sd_journal *, uint64_t);
//
// sd_journal_seek_realtime_usec = f;
// return sd_journal_seek_realtime_usec(j, usec);
// }
//
// int
// my_sd_journal_wait(void *f, sd_journal *j, uint64_t timeout_usec)
// {
// int (*sd_journal_wait)(sd_journal *, uint64_t);
//
// sd_journal_wait = f;
// return sd_journal_wait(j, timeout_usec);
// }
//
/*
#cgo pkg-config: libsystemd
#include <systemd/sd-journal.h>
#include <stdlib.h>
#include <syslog.h>
*/
import "C"
import (
"fmt"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
"unsafe"
"github.com/coreos/pkg/dlopen"
)
var libsystemdFunctions = map[string]unsafe.Pointer{}
// Journal entry field strings which correspond to:
// http://www.freedesktop.org/software/systemd/man/systemd.journal-fields.html
const (
SD_JOURNAL_FIELD_SYSTEMD_UNIT = "_SYSTEMD_UNIT"
SD_JOURNAL_FIELD_SYSLOG_IDENTIFIER = "SYSLOG_IDENTIFIER"
SD_JOURNAL_FIELD_MESSAGE = "MESSAGE"
SD_JOURNAL_FIELD_PID = "_PID"
SD_JOURNAL_FIELD_UID = "_UID"
SD_JOURNAL_FIELD_GID = "_GID"
SD_JOURNAL_FIELD_HOSTNAME = "_HOSTNAME"
SD_JOURNAL_FIELD_MACHINE_ID = "_MACHINE_ID"
SD_JOURNAL_FIELD_TRANSPORT = "_TRANSPORT"
SD_JOURNAL_FIELD_SYSTEMD_UNIT = "_SYSTEMD_UNIT"
SD_JOURNAL_FIELD_MESSAGE = "MESSAGE"
SD_JOURNAL_FIELD_PID = "_PID"
SD_JOURNAL_FIELD_UID = "_UID"
SD_JOURNAL_FIELD_GID = "_GID"
SD_JOURNAL_FIELD_HOSTNAME = "_HOSTNAME"
SD_JOURNAL_FIELD_MACHINE_ID = "_MACHINE_ID"
)
// Journal event constants
@ -271,21 +66,10 @@ const (
IndefiniteWait time.Duration = 1<<63 - 1
)
var libsystemdNames = []string{
// systemd < 209
"libsystemd-journal.so.0",
"libsystemd-journal.so",
// systemd >= 209 merged libsystemd-journal into libsystemd proper
"libsystemd.so.0",
"libsystemd.so",
}
// Journal is a Go wrapper of an sd_journal structure.
type Journal struct {
cjournal *C.sd_journal
mu sync.Mutex
lib *dlopen.LibHandle
}
// Match is a convenience wrapper to describe filters supplied to AddMatch.
@ -299,50 +83,13 @@ func (m *Match) String() string {
return m.Field + "=" + m.Value
}
func (j *Journal) getFunction(name string) (unsafe.Pointer, error) {
j.mu.Lock()
defer j.mu.Unlock()
f, ok := libsystemdFunctions[name]
if !ok {
var err error
f, err = j.lib.GetSymbolPointer(name)
if err != nil {
return nil, err
}
libsystemdFunctions[name] = f
}
return f, nil
}
// NewJournal returns a new Journal instance pointing to the local journal
func NewJournal() (j *Journal, err error) {
h, err := dlopen.GetHandle(libsystemdNames)
if err != nil {
return nil, err
}
defer func() {
if err == nil {
return
}
err2 := h.Close()
if err2 != nil {
err = fmt.Errorf(`%q and "error closing handle: %v"`, err, err2)
}
}()
j = &Journal{lib: h}
sd_journal_open, err := j.getFunction("sd_journal_open")
if err != nil {
return nil, err
}
r := C.my_sd_journal_open(sd_journal_open, &j.cjournal, C.SD_JOURNAL_LOCAL_ONLY)
func NewJournal() (*Journal, error) {
j := &Journal{}
r := C.sd_journal_open(&j.cjournal, C.SD_JOURNAL_LOCAL_ONLY)
if r < 0 {
return nil, fmt.Errorf("failed to open journal: %d", syscall.Errno(-r))
return nil, fmt.Errorf("failed to open journal: %d", r)
}
return j, nil
@ -351,29 +98,8 @@ func NewJournal() (j *Journal, err error) {
// NewJournalFromDir returns a new Journal instance pointing to a journal residing
// in a given directory. The supplied path may be relative or absolute; if
// relative, it will be converted to an absolute path before being opened.
func NewJournalFromDir(path string) (j *Journal, err error) {
h, err := dlopen.GetHandle(libsystemdNames)
if err != nil {
return nil, err
}
defer func() {
if err == nil {
return
}
err2 := h.Close()
if err2 != nil {
err = fmt.Errorf(`%q and "error closing handle: %v"`, err, err2)
}
}()
path, err = filepath.Abs(path)
if err != nil {
return nil, err
}
j = &Journal{lib: h}
sd_journal_open_directory, err := j.getFunction("sd_journal_open_directory")
func NewJournalFromDir(path string) (*Journal, error) {
path, err := filepath.Abs(path)
if err != nil {
return nil, err
}
@ -381,9 +107,10 @@ func NewJournalFromDir(path string) (j *Journal, err error) {
p := C.CString(path)
defer C.free(unsafe.Pointer(p))
r := C.my_sd_journal_open_directory(sd_journal_open_directory, &j.cjournal, p, 0)
j := &Journal{}
r := C.sd_journal_open_directory(&j.cjournal, p, 0)
if r < 0 {
return nil, fmt.Errorf("failed to open journal in directory %q: %d", path, syscall.Errno(-r))
return nil, fmt.Errorf("failed to open journal in directory %q: %d", path, r)
}
return j, nil
@ -391,34 +118,24 @@ func NewJournalFromDir(path string) (j *Journal, err error) {
// Close closes a journal opened with NewJournal.
func (j *Journal) Close() error {
sd_journal_close, err := j.getFunction("sd_journal_close")
if err != nil {
return err
}
j.mu.Lock()
C.my_sd_journal_close(sd_journal_close, j.cjournal)
C.sd_journal_close(j.cjournal)
j.mu.Unlock()
return j.lib.Close()
return nil
}
// AddMatch adds a match by which to filter the entries of the journal.
func (j *Journal) AddMatch(match string) error {
sd_journal_add_match, err := j.getFunction("sd_journal_add_match")
if err != nil {
return err
}
m := C.CString(match)
defer C.free(unsafe.Pointer(m))
j.mu.Lock()
r := C.my_sd_journal_add_match(sd_journal_add_match, j.cjournal, unsafe.Pointer(m), C.size_t(len(match)))
r := C.sd_journal_add_match(j.cjournal, unsafe.Pointer(m), C.size_t(len(match)))
j.mu.Unlock()
if r < 0 {
return fmt.Errorf("failed to add match: %d", syscall.Errno(-r))
return fmt.Errorf("failed to add match: %d", r)
}
return nil
@ -426,17 +143,12 @@ func (j *Journal) AddMatch(match string) error {
// AddDisjunction inserts a logical OR in the match list.
func (j *Journal) AddDisjunction() error {
sd_journal_add_disjunction, err := j.getFunction("sd_journal_add_disjunction")
if err != nil {
return err
}
j.mu.Lock()
r := C.my_sd_journal_add_disjunction(sd_journal_add_disjunction, j.cjournal)
r := C.sd_journal_add_disjunction(j.cjournal)
j.mu.Unlock()
if r < 0 {
return fmt.Errorf("failed to add a disjunction in the match list: %d", syscall.Errno(-r))
return fmt.Errorf("failed to add a disjunction in the match list: %d", r)
}
return nil
@ -444,17 +156,12 @@ func (j *Journal) AddDisjunction() error {
// AddConjunction inserts a logical AND in the match list.
func (j *Journal) AddConjunction() error {
sd_journal_add_conjunction, err := j.getFunction("sd_journal_add_conjunction")
if err != nil {
return err
}
j.mu.Lock()
r := C.my_sd_journal_add_conjunction(sd_journal_add_conjunction, j.cjournal)
r := C.sd_journal_add_conjunction(j.cjournal)
j.mu.Unlock()
if r < 0 {
return fmt.Errorf("failed to add a conjunction in the match list: %d", syscall.Errno(-r))
return fmt.Errorf("failed to add a conjunction in the match list: %d", r)
}
return nil
@ -462,29 +169,19 @@ func (j *Journal) AddConjunction() error {
// FlushMatches flushes all matches, disjunctions and conjunctions.
func (j *Journal) FlushMatches() {
sd_journal_flush_matches, err := j.getFunction("sd_journal_flush_matches")
if err != nil {
return
}
j.mu.Lock()
C.my_sd_journal_flush_matches(sd_journal_flush_matches, j.cjournal)
C.sd_journal_flush_matches(j.cjournal)
j.mu.Unlock()
}
// Next advances the read pointer into the journal by one entry.
func (j *Journal) Next() (int, error) {
sd_journal_next, err := j.getFunction("sd_journal_next")
if err != nil {
return -1, err
}
j.mu.Lock()
r := C.my_sd_journal_next(sd_journal_next, j.cjournal)
r := C.sd_journal_next(j.cjournal)
j.mu.Unlock()
if r < 0 {
return int(r), fmt.Errorf("failed to iterate journal: %d", syscall.Errno(-r))
return int(r), fmt.Errorf("failed to iterate journal: %d", r)
}
return int(r), nil
@ -493,17 +190,12 @@ func (j *Journal) Next() (int, error) {
// NextSkip advances the read pointer by multiple entries at once,
// as specified by the skip parameter.
func (j *Journal) NextSkip(skip uint64) (uint64, error) {
sd_journal_next_skip, err := j.getFunction("sd_journal_next_skip")
if err != nil {
return 0, err
}
j.mu.Lock()
r := C.my_sd_journal_next_skip(sd_journal_next_skip, j.cjournal, C.uint64_t(skip))
r := C.sd_journal_next_skip(j.cjournal, C.uint64_t(skip))
j.mu.Unlock()
if r < 0 {
return uint64(r), fmt.Errorf("failed to iterate journal: %d", syscall.Errno(-r))
return uint64(r), fmt.Errorf("failed to iterate journal: %d", r)
}
return uint64(r), nil
@ -511,17 +203,12 @@ func (j *Journal) NextSkip(skip uint64) (uint64, error) {
// Previous sets the read pointer into the journal back by one entry.
func (j *Journal) Previous() (uint64, error) {
sd_journal_previous, err := j.getFunction("sd_journal_previous")
if err != nil {
return 0, err
}
j.mu.Lock()
r := C.my_sd_journal_previous(sd_journal_previous, j.cjournal)
r := C.sd_journal_previous(j.cjournal)
j.mu.Unlock()
if r < 0 {
return uint64(r), fmt.Errorf("failed to iterate journal: %d", syscall.Errno(-r))
return uint64(r), fmt.Errorf("failed to iterate journal: %d", r)
}
return uint64(r), nil
@ -530,17 +217,12 @@ func (j *Journal) Previous() (uint64, error) {
// PreviousSkip sets back the read pointer by multiple entries at once,
// as specified by the skip parameter.
func (j *Journal) PreviousSkip(skip uint64) (uint64, error) {
sd_journal_previous_skip, err := j.getFunction("sd_journal_previous_skip")
if err != nil {
return 0, err
}
j.mu.Lock()
r := C.my_sd_journal_previous_skip(sd_journal_previous_skip, j.cjournal, C.uint64_t(skip))
r := C.sd_journal_previous_skip(j.cjournal, C.uint64_t(skip))
j.mu.Unlock()
if r < 0 {
return uint64(r), fmt.Errorf("failed to iterate journal: %d", syscall.Errno(-r))
return uint64(r), fmt.Errorf("failed to iterate journal: %d", r)
}
return uint64(r), nil
@ -549,11 +231,6 @@ func (j *Journal) PreviousSkip(skip uint64) (uint64, error) {
// GetData gets the data object associated with a specific field from the
// current journal entry.
func (j *Journal) GetData(field string) (string, error) {
sd_journal_get_data, err := j.getFunction("sd_journal_get_data")
if err != nil {
return "", err
}
f := C.CString(field)
defer C.free(unsafe.Pointer(f))
@ -561,11 +238,11 @@ func (j *Journal) GetData(field string) (string, error) {
var l C.size_t
j.mu.Lock()
r := C.my_sd_journal_get_data(sd_journal_get_data, j.cjournal, f, &d, &l)
r := C.sd_journal_get_data(j.cjournal, f, &d, &l)
j.mu.Unlock()
if r < 0 {
return "", fmt.Errorf("failed to read message: %d", syscall.Errno(-r))
return "", fmt.Errorf("failed to read message: %d", r)
}
msg := C.GoStringN((*C.char)(d), C.int(l))
@ -588,17 +265,12 @@ func (j *Journal) GetDataValue(field string) (string, error) {
// turned off by setting it to 0, so that the library always returns the
// complete data objects.
func (j *Journal) SetDataThreshold(threshold uint64) error {
sd_journal_set_data_threshold, err := j.getFunction("sd_journal_set_data_threshold")
if err != nil {
return err
}
j.mu.Lock()
r := C.my_sd_journal_set_data_threshold(sd_journal_set_data_threshold, j.cjournal, C.size_t(threshold))
r := C.sd_journal_set_data_threshold(j.cjournal, C.size_t(threshold))
j.mu.Unlock()
if r < 0 {
return fmt.Errorf("failed to set data threshold: %d", syscall.Errno(-r))
return fmt.Errorf("failed to set data threshold: %d", r)
}
return nil
@ -609,99 +281,26 @@ func (j *Journal) SetDataThreshold(threshold uint64) error {
func (j *Journal) GetRealtimeUsec() (uint64, error) {
var usec C.uint64_t
sd_journal_get_realtime_usec, err := j.getFunction("sd_journal_get_realtime_usec")
if err != nil {
return 0, err
}
j.mu.Lock()
r := C.my_sd_journal_get_realtime_usec(sd_journal_get_realtime_usec, j.cjournal, &usec)
r := C.sd_journal_get_realtime_usec(j.cjournal, &usec)
j.mu.Unlock()
if r < 0 {
return 0, fmt.Errorf("error getting timestamp for entry: %d", syscall.Errno(-r))
return 0, fmt.Errorf("error getting timestamp for entry: %d", r)
}
return uint64(usec), nil
}
// GetCursor gets the cursor of the current journal entry.
func (j *Journal) GetCursor() (string, error) {
var d *C.char
sd_journal_get_cursor, err := j.getFunction("sd_journal_get_cursor")
if err != nil {
return "", err
}
j.mu.Lock()
r := C.my_sd_journal_get_cursor(sd_journal_get_cursor, j.cjournal, &d)
j.mu.Unlock()
if r < 0 {
return "", fmt.Errorf("failed to get cursor: %d", syscall.Errno(-r))
}
cursor := C.GoString(d)
return cursor, nil
}
// TestCursor checks whether the current position in the journal matches the
// specified cursor
func (j *Journal) TestCursor(cursor string) error {
sd_journal_test_cursor, err := j.getFunction("sd_journal_test_cursor")
if err != nil {
return err
}
c := C.CString(cursor)
defer C.free(unsafe.Pointer(c))
j.mu.Lock()
r := C.my_sd_journal_test_cursor(sd_journal_test_cursor, j.cjournal, c)
j.mu.Unlock()
if r < 0 {
return fmt.Errorf("failed to test to cursor %q: %d", cursor, syscall.Errno(-r))
}
return nil
}
// SeekHead seeks to the beginning of the journal, i.e. the oldest available
// entry.
func (j *Journal) SeekHead() error {
sd_journal_seek_head, err := j.getFunction("sd_journal_seek_head")
if err != nil {
return err
}
j.mu.Lock()
r := C.my_sd_journal_seek_head(sd_journal_seek_head, j.cjournal)
j.mu.Unlock()
if r < 0 {
return fmt.Errorf("failed to seek to head of journal: %d", syscall.Errno(-r))
}
return nil
}
// SeekTail may be used to seek to the end of the journal, i.e. the most recent
// available entry.
func (j *Journal) SeekTail() error {
sd_journal_seek_tail, err := j.getFunction("sd_journal_seek_tail")
if err != nil {
return err
}
j.mu.Lock()
r := C.my_sd_journal_seek_tail(sd_journal_seek_tail, j.cjournal)
r := C.sd_journal_seek_tail(j.cjournal)
j.mu.Unlock()
if r < 0 {
return fmt.Errorf("failed to seek to tail of journal: %d", syscall.Errno(-r))
return fmt.Errorf("failed to seek to tail of journal: %d", r)
}
return nil
@ -710,38 +309,12 @@ func (j *Journal) SeekTail() error {
// SeekRealtimeUsec seeks to the entry with the specified realtime (wallclock)
// timestamp, i.e. CLOCK_REALTIME.
func (j *Journal) SeekRealtimeUsec(usec uint64) error {
sd_journal_seek_realtime_usec, err := j.getFunction("sd_journal_seek_realtime_usec")
if err != nil {
return err
}
j.mu.Lock()
r := C.my_sd_journal_seek_realtime_usec(sd_journal_seek_realtime_usec, j.cjournal, C.uint64_t(usec))
r := C.sd_journal_seek_realtime_usec(j.cjournal, C.uint64_t(usec))
j.mu.Unlock()
if r < 0 {
return fmt.Errorf("failed to seek to %d: %d", usec, syscall.Errno(-r))
}
return nil
}
// SeekCursor seeks to a concrete journal cursor.
func (j *Journal) SeekCursor(cursor string) error {
sd_journal_seek_cursor, err := j.getFunction("sd_journal_seek_cursor")
if err != nil {
return err
}
c := C.CString(cursor)
defer C.free(unsafe.Pointer(c))
j.mu.Lock()
r := C.my_sd_journal_seek_cursor(sd_journal_seek_cursor, j.cjournal, c)
j.mu.Unlock()
if r < 0 {
return fmt.Errorf("failed to seek to cursor %q: %d", cursor, syscall.Errno(-r))
return fmt.Errorf("failed to seek to %d: %d", usec, r)
}
return nil
@ -753,12 +326,6 @@ func (j *Journal) SeekCursor(cursor string) error {
// wait indefinitely for a journal change.
func (j *Journal) Wait(timeout time.Duration) int {
var to uint64
sd_journal_wait, err := j.getFunction("sd_journal_wait")
if err != nil {
return -1
}
if timeout == IndefiniteWait {
// sd_journal_wait(3) calls for a (uint64_t) -1 to be passed to signify
// indefinite wait, but using a -1 overflows our C.uint64_t, so we use an
@ -768,7 +335,7 @@ func (j *Journal) Wait(timeout time.Duration) int {
to = uint64(time.Now().Add(timeout).Unix() / 1000)
}
j.mu.Lock()
r := C.my_sd_journal_wait(sd_journal_wait, j.cjournal, C.uint64_t(to))
r := C.sd_journal_wait(j.cjournal, C.uint64_t(to))
j.mu.Unlock()
return int(r)
@ -777,18 +344,12 @@ func (j *Journal) Wait(timeout time.Duration) int {
// GetUsage returns the journal disk space usage, in bytes.
func (j *Journal) GetUsage() (uint64, error) {
var out C.uint64_t
sd_journal_get_usage, err := j.getFunction("sd_journal_get_usage")
if err != nil {
return 0, err
}
j.mu.Lock()
r := C.my_sd_journal_get_usage(sd_journal_get_usage, j.cjournal, &out)
r := C.sd_journal_get_usage(j.cjournal, &out)
j.mu.Unlock()
if r < 0 {
return 0, fmt.Errorf("failed to get journal disk space usage: %d", syscall.Errno(-r))
return 0, fmt.Errorf("failed to get journal disk space usage: %d", r)
}
return uint64(out), nil

87
vendor/github.com/coreos/go-systemd/sdjournal/journal_test.go generated vendored Executable file → Normal file
View File

@ -16,10 +16,6 @@
package sdjournal
import (
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"testing"
"time"
@ -92,86 +88,3 @@ func TestJournalGetUsage(t *testing.T) {
t.Fatalf("Error getting journal size: %s", err)
}
}
func TestJournalCursorGetSeekAndTest(t *testing.T) {
j, err := NewJournal()
if err != nil {
t.Fatalf("Error opening journal: %s", err)
}
if j == nil {
t.Fatal("Got a nil journal")
}
defer j.Close()
waitAndNext := func(j *Journal) error {
r := j.Wait(time.Duration(1) * time.Second)
if r < 0 {
return errors.New("Error waiting to journal")
}
n, err := j.Next()
if err != nil {
return fmt.Errorf("Error reading to journal: %s", err)
}
if n == 0 {
return fmt.Errorf("Error reading to journal: %s", io.EOF)
}
return nil
}
err = journal.Print(journal.PriInfo, "test message for cursor %s", time.Now())
if err != nil {
t.Fatalf("Error writing to journal: %s", err)
}
if err = waitAndNext(j); err != nil {
t.Fatalf(err.Error())
}
c, err := j.GetCursor()
if err != nil {
t.Fatalf("Error getting cursor from journal: %s", err)
}
err = j.SeekCursor(c)
if err != nil {
t.Fatalf("Error seeking cursor to journal: %s", err)
}
if err = waitAndNext(j); err != nil {
t.Fatalf(err.Error())
}
err = j.TestCursor(c)
if err != nil {
t.Fatalf("Error testing cursor to journal: %s", err)
}
}
func TestNewJournalFromDir(t *testing.T) {
// test for error handling
dir := "/ClearlyNonExistingPath/"
j, err := NewJournalFromDir(dir)
if err == nil {
defer j.Close()
t.Fatalf("Error expected when opening dummy path (%s)", dir)
}
// test for main code path
dir, err = ioutil.TempDir("", "go-systemd-test")
if err != nil {
t.Fatalf("Error creating tempdir: %s", err)
}
defer os.RemoveAll(dir)
j, err = NewJournalFromDir(dir)
if err != nil {
t.Fatalf("Error opening journal: %s", err)
}
if j == nil {
t.Fatal("Got a nil journal")
}
j.Close()
}

View File

@ -29,20 +29,14 @@ var (
// JournalReaderConfig represents options to drive the behavior of a JournalReader.
type JournalReaderConfig struct {
// The Since, NumFromTail and Cursor options are mutually exclusive and
// determine where the reading begins within the journal. The order in which
// options are written is exactly the order of precedence.
// The Since and NumFromTail options are mutually exclusive and determine
// where the reading begins within the journal.
Since time.Duration // start relative to a Duration from now
NumFromTail uint64 // start relative to the tail
Cursor string // start relative to the cursor
// Show only journal entries whose fields match the supplied values. If
// the array is empty, entries will not be filtered.
Matches []Match
// If not empty, the journal instance will point to a journal residing
// in this directory. The supplied path may be relative or absolute.
Path string
}
// JournalReader is an io.ReadCloser which provides a simple interface for iterating through the
@ -56,14 +50,9 @@ type JournalReader struct {
func NewJournalReader(config JournalReaderConfig) (*JournalReader, error) {
r := &JournalReader{}
// Open the journal
var err error
if config.Path != "" {
r.journal, err = NewJournalFromDir(config.Path)
} else {
r.journal, err = NewJournal()
}
if err != nil {
// Open the journal
if r.journal, err = NewJournal(); err != nil {
return nil, err
}
@ -91,11 +80,6 @@ func NewJournalReader(config JournalReaderConfig) (*JournalReader, error) {
if _, err := r.journal.PreviousSkip(config.NumFromTail + 1); err != nil {
return nil, err
}
} else if config.Cursor != "" {
// Start based on a custom cursor
if err := r.journal.SeekCursor(config.Cursor); err != nil {
return nil, err
}
}
return r, nil
@ -132,25 +116,20 @@ func (r *JournalReader) Read(b []byte) (int, error) {
return len(msg), nil
}
// Close closes the JournalReader's handle to the journal.
func (r *JournalReader) Close() error {
return r.journal.Close()
}
// Rewind attempts to rewind the JournalReader to the first entry.
func (r *JournalReader) Rewind() error {
return r.journal.SeekHead()
}
// Follow synchronously follows the JournalReader, writing each new journal entry to writer. The
// follow will continue until a single time.Time is received on the until channel.
func (r *JournalReader) Follow(until <-chan time.Time, writer io.Writer) (err error) {
// Process journal entries and events. Entries are flushed until the tail or
// timeout is reached, and then we wait for new events or the timeout.
var msg = make([]byte, 64*1<<(10))
process:
for {
var msg = make([]byte, 64*1<<(10))
c, err := r.Read(msg)
if err != nil && err != io.EOF {
break process
@ -161,7 +140,7 @@ process:
return ErrExpired
default:
if c > 0 {
writer.Write(msg[:c])
writer.Write(msg)
continue process
}
}

View File

@ -57,7 +57,7 @@ split=(${TEST// / })
TEST=${split[@]/#/${REPO_PATH}/}
echo "Running tests..."
go test -v ${COVER} $@ ${TEST}
go test ${COVER} $@ ${TEST}
echo "Checking gofmt..."
fmtRes=$(gofmt -l $FMT)

View File

@ -18,7 +18,9 @@
// than linking against them.
package util
// #cgo LDFLAGS: -ldl
// #include <stdlib.h>
// #include <dlfcn.h>
// #include <sys/types.h>
// #include <unistd.h>
//
@ -56,24 +58,56 @@ package util
// }
import "C"
import (
"errors"
"fmt"
"io/ioutil"
"os"
"strings"
"syscall"
"unsafe"
"github.com/coreos/pkg/dlopen"
)
var libsystemdNames = []string{
// systemd < 209
"libsystemd-login.so.0",
"libsystemd-login.so",
var ErrSoNotFound = errors.New("unable to open a handle to libsystemd")
// systemd >= 209 merged libsystemd-login into libsystemd proper
"libsystemd.so.0",
"libsystemd.so",
// libHandle represents an open handle to the systemd C library
type libHandle struct {
handle unsafe.Pointer
libname string
}
func (h *libHandle) Close() error {
if r := C.dlclose(h.handle); r != 0 {
return fmt.Errorf("error closing %v: %d", h.libname, r)
}
return nil
}
// getHandle tries to get a handle to a systemd library (.so), attempting to
// access it by several different names and returning the first that is
// successfully opened. Callers are responsible for closing the handler.
// If no library can be successfully opened, an error is returned.
func getHandle() (*libHandle, error) {
for _, name := range []string{
// systemd < 209
"libsystemd-login.so",
"libsystemd-login.so.0",
// systemd >= 209 merged libsystemd-login into libsystemd proper
"libsystemd.so",
"libsystemd.so.0",
} {
libname := C.CString(name)
defer C.free(unsafe.Pointer(libname))
handle := C.dlopen(libname, C.RTLD_LAZY)
if handle != nil {
h := &libHandle{
handle: handle,
libname: name,
}
return h, nil
}
}
return nil, ErrSoNotFound
}
// GetRunningSlice attempts to retrieve the name of the systemd slice in which
@ -81,8 +115,8 @@ var libsystemdNames = []string{
// This function is a wrapper around the libsystemd C library; if it cannot be
// opened, an error is returned.
func GetRunningSlice() (slice string, err error) {
var h *dlopen.LibHandle
h, err = dlopen.GetHandle(libsystemdNames)
var h *libHandle
h, err = getHandle()
if err != nil {
return
}
@ -92,8 +126,11 @@ func GetRunningSlice() (slice string, err error) {
}
}()
sd_pid_get_slice, err := h.GetSymbolPointer("sd_pid_get_slice")
if err != nil {
sym := C.CString("sd_pid_get_slice")
defer C.free(unsafe.Pointer(sym))
sd_pid_get_slice := C.dlsym(h.handle, sym)
if sd_pid_get_slice == nil {
err = fmt.Errorf("error resolving sd_pid_get_slice function")
return
}
@ -127,8 +164,8 @@ func GetRunningSlice() (slice string, err error) {
// unable to successfully open a handle to the library for any reason (e.g. it
// cannot be found), an errr will be returned
func RunningFromSystemService() (ret bool, err error) {
var h *dlopen.LibHandle
h, err = dlopen.GetHandle(libsystemdNames)
var h *libHandle
h, err = getHandle()
if err != nil {
return
}
@ -138,8 +175,11 @@ func RunningFromSystemService() (ret bool, err error) {
}
}()
sd_pid_get_owner_uid, err := h.GetSymbolPointer("sd_pid_get_owner_uid")
if err != nil {
sym := C.CString("sd_pid_get_owner_uid")
defer C.free(unsafe.Pointer(sym))
sd_pid_get_owner_uid := C.dlsym(h.handle, sym)
if sd_pid_get_owner_uid == nil {
err = fmt.Errorf("error resolving sd_pid_get_owner_uid function")
return
}
@ -172,8 +212,8 @@ func RunningFromSystemService() (ret bool, err error) {
// `sd_pid_get_unit` call, with the same caveat: for processes not part of a
// systemd system unit, this function will return an error.
func CurrentUnitName() (unit string, err error) {
var h *dlopen.LibHandle
h, err = dlopen.GetHandle(libsystemdNames)
var h *libHandle
h, err = getHandle()
if err != nil {
return
}
@ -183,8 +223,11 @@ func CurrentUnitName() (unit string, err error) {
}
}()
sd_pid_get_unit, err := h.GetSymbolPointer("sd_pid_get_unit")
if err != nil {
sym := C.CString("sd_pid_get_unit")
defer C.free(unsafe.Pointer(sym))
sd_pid_get_unit := C.dlsym(h.handle, sym)
if sd_pid_get_unit == nil {
err = fmt.Errorf("error resolving sd_pid_get_unit function")
return
}

View File

@ -1,8 +0,0 @@
language: go
go:
- 1.5.4
- 1.6.2
script:
- ./test

View File

@ -1,4 +1,3 @@
a collection of go utility packages
[![Build Status](https://travis-ci.org/coreos/pkg.png?branch=master)](https://travis-ci.org/coreos/pkg)
[![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/coreos/pkg)
[![Build Status](https://semaphoreci.com/api/v1/projects/14b3f261-22c2-4f56-b1ff-f23f4aa03f5c/411991/badge.svg)](https://semaphoreci.com/coreos/pkg) [![Godoc](http://img.shields.io/badge/godoc-reference-blue.svg?style=flat)](https://godoc.org/github.com/coreos/pkg)

View File

@ -18,7 +18,6 @@ import (
"bufio"
"fmt"
"io"
"log"
"runtime"
"strings"
"time"
@ -29,7 +28,7 @@ type Formatter interface {
Flush()
}
func NewStringFormatter(w io.Writer) Formatter {
func NewStringFormatter(w io.Writer) *StringFormatter {
return &StringFormatter{
w: bufio.NewWriter(w),
}
@ -105,53 +104,3 @@ func (c *PrettyFormatter) Format(pkg string, l LogLevel, depth int, entries ...i
func (c *PrettyFormatter) Flush() {
c.w.Flush()
}
// LogFormatter emulates the form of the traditional built-in logger.
type LogFormatter struct {
logger *log.Logger
prefix string
}
// NewLogFormatter is a helper to produce a new LogFormatter struct. It uses the
// golang log package to actually do the logging work so that logs look similar.
func NewLogFormatter(w io.Writer, prefix string, flag int) Formatter {
return &LogFormatter{
logger: log.New(w, "", flag), // don't use prefix here
prefix: prefix, // save it instead
}
}
// Format builds a log message for the LogFormatter. The LogLevel is ignored.
func (lf *LogFormatter) Format(pkg string, _ LogLevel, _ int, entries ...interface{}) {
str := fmt.Sprint(entries...)
prefix := lf.prefix
if pkg != "" {
prefix = fmt.Sprintf("%s%s: ", prefix, pkg)
}
lf.logger.Output(5, fmt.Sprintf("%s%v", prefix, str)) // call depth is 5
}
// Flush is included so that the interface is complete, but is a no-op.
func (lf *LogFormatter) Flush() {
// noop
}
// NilFormatter is a no-op log formatter that does nothing.
type NilFormatter struct {
}
// NewNilFormatter is a helper to produce a new LogFormatter struct. It logs no
// messages so that you can cause part of your logging to be silent.
func NewNilFormatter() Formatter {
return &NilFormatter{}
}
// Format does nothing.
func (_ *NilFormatter) Format(_ string, _ LogLevel, _ int, _ ...interface{}) {
// noop
}
// Flush is included so that the interface is complete, but is a no-op.
func (_ *NilFormatter) Flush() {
// noop
}

View File

@ -27,19 +27,17 @@ type PackageLogger struct {
const calldepth = 2
func (p *PackageLogger) internalLog(depth int, inLevel LogLevel, entries ...interface{}) {
logger.Lock()
defer logger.Unlock()
if inLevel != CRITICAL && p.level < inLevel {
return
}
logger.Lock()
defer logger.Unlock()
if logger.formatter != nil {
logger.formatter.Format(p.pkg, inLevel, depth+1, entries...)
}
}
func (p *PackageLogger) LevelAt(l LogLevel) bool {
logger.Lock()
defer logger.Unlock()
return p.level >= l
}
@ -60,7 +58,7 @@ func (p *PackageLogger) Println(args ...interface{}) {
}
func (p *PackageLogger) Printf(format string, args ...interface{}) {
p.Logf(INFO, format, args...)
p.internalLog(calldepth, INFO, fmt.Sprintf(format, args...))
}
func (p *PackageLogger) Print(args ...interface{}) {
@ -82,7 +80,8 @@ func (p *PackageLogger) Panic(args ...interface{}) {
}
func (p *PackageLogger) Fatalf(format string, args ...interface{}) {
p.Logf(CRITICAL, format, args...)
s := fmt.Sprintf(format, args...)
p.internalLog(calldepth, CRITICAL, s)
os.Exit(1)
}
@ -95,7 +94,7 @@ func (p *PackageLogger) Fatal(args ...interface{}) {
// Error Functions
func (p *PackageLogger) Errorf(format string, args ...interface{}) {
p.Logf(ERROR, format, args...)
p.internalLog(calldepth, ERROR, fmt.Sprintf(format, args...))
}
func (p *PackageLogger) Error(entries ...interface{}) {
@ -105,7 +104,7 @@ func (p *PackageLogger) Error(entries ...interface{}) {
// Warning Functions
func (p *PackageLogger) Warningf(format string, args ...interface{}) {
p.Logf(WARNING, format, args...)
p.internalLog(calldepth, WARNING, fmt.Sprintf(format, args...))
}
func (p *PackageLogger) Warning(entries ...interface{}) {
@ -115,7 +114,7 @@ func (p *PackageLogger) Warning(entries ...interface{}) {
// Notice Functions
func (p *PackageLogger) Noticef(format string, args ...interface{}) {
p.Logf(NOTICE, format, args...)
p.internalLog(calldepth, NOTICE, fmt.Sprintf(format, args...))
}
func (p *PackageLogger) Notice(entries ...interface{}) {
@ -125,7 +124,7 @@ func (p *PackageLogger) Notice(entries ...interface{}) {
// Info Functions
func (p *PackageLogger) Infof(format string, args ...interface{}) {
p.Logf(INFO, format, args...)
p.internalLog(calldepth, INFO, fmt.Sprintf(format, args...))
}
func (p *PackageLogger) Info(entries ...interface{}) {
@ -135,32 +134,20 @@ func (p *PackageLogger) Info(entries ...interface{}) {
// Debug Functions
func (p *PackageLogger) Debugf(format string, args ...interface{}) {
if p.level < DEBUG {
return
}
p.Logf(DEBUG, format, args...)
p.internalLog(calldepth, DEBUG, fmt.Sprintf(format, args...))
}
func (p *PackageLogger) Debug(entries ...interface{}) {
if p.level < DEBUG {
return
}
p.internalLog(calldepth, DEBUG, entries...)
}
// Trace Functions
func (p *PackageLogger) Tracef(format string, args ...interface{}) {
if p.level < TRACE {
return
}
p.Logf(TRACE, format, args...)
p.internalLog(calldepth, TRACE, fmt.Sprintf(format, args...))
}
func (p *PackageLogger) Trace(entries ...interface{}) {
if p.level < TRACE {
return
}
p.internalLog(calldepth, TRACE, entries...)
}

View File

@ -1,82 +0,0 @@
// Copyright 2016 CoreOS, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Package dlopen provides some convenience functions to dlopen a library and
// get its symbols.
package dlopen
// #cgo LDFLAGS: -ldl
// #include <stdlib.h>
// #include <dlfcn.h>
import "C"
import (
"errors"
"fmt"
"unsafe"
)
var ErrSoNotFound = errors.New("unable to open a handle to the library")
// LibHandle represents an open handle to a library (.so)
type LibHandle struct {
Handle unsafe.Pointer
Libname string
}
// GetHandle tries to get a handle to a library (.so), attempting to access it
// by the names specified in libs and returning the first that is successfully
// opened. Callers are responsible for closing the handler. If no library can
// be successfully opened, an error is returned.
func GetHandle(libs []string) (*LibHandle, error) {
for _, name := range libs {
libname := C.CString(name)
defer C.free(unsafe.Pointer(libname))
handle := C.dlopen(libname, C.RTLD_LAZY)
if handle != nil {
h := &LibHandle{
Handle: handle,
Libname: name,
}
return h, nil
}
}
return nil, ErrSoNotFound
}
// GetSymbolPointer takes a symbol name and returns a pointer to the symbol.
func (l *LibHandle) GetSymbolPointer(symbol string) (unsafe.Pointer, error) {
sym := C.CString(symbol)
defer C.free(unsafe.Pointer(sym))
C.dlerror()
p := C.dlsym(l.Handle, sym)
e := C.dlerror()
if e != nil {
return nil, fmt.Errorf("error resolving symbol %q: %v", symbol, errors.New(C.GoString(e)))
}
return p, nil
}
// Close closes a LibHandle.
func (l *LibHandle) Close() error {
C.dlerror()
C.dlclose(l.Handle)
e := C.dlerror()
if e != nil {
return fmt.Errorf("error closing %v: %v", l.Libname, errors.New(C.GoString(e)))
}
return nil
}

View File

@ -1,56 +0,0 @@
// Copyright 2015 CoreOS, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// +build linux
package dlopen
// #include <string.h>
// #include <stdlib.h>
//
// int
// my_strlen(void *f, const char *s)
// {
// size_t (*strlen)(const char *);
//
// strlen = (size_t (*)(const char *))f;
// return strlen(s);
// }
import "C"
import (
"fmt"
"unsafe"
)
func strlen(libs []string, s string) (int, error) {
h, err := GetHandle(libs)
if err != nil {
return -1, fmt.Errorf(`couldn't get a handle to the library: %v`, err)
}
defer h.Close()
f := "strlen"
cs := C.CString(s)
defer C.free(unsafe.Pointer(cs))
strlen, err := h.GetSymbolPointer(f)
if err != nil {
return -1, fmt.Errorf(`couldn't get symbol %q: %v`, f, err)
}
len := C.my_strlen(strlen, cs)
return int(len), nil
}

View File

@ -1,63 +0,0 @@
// Copyright 2015 CoreOS, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package dlopen
import (
"fmt"
"testing"
)
func checkFailure(shouldSucceed bool, err error) (rErr error) {
switch {
case err != nil && shouldSucceed:
rErr = fmt.Errorf("expected test to succeed, failed unexpectedly: %v", err)
case err == nil && !shouldSucceed:
rErr = fmt.Errorf("expected test to fail, succeeded unexpectedly")
}
return
}
func TestDlopen(t *testing.T) {
tests := []struct {
libs []string
shouldSucceed bool
}{
{
libs: []string{
"libc.so.6",
"libc.so",
},
shouldSucceed: true,
},
{
libs: []string{
"libstrange.so",
},
shouldSucceed: false,
},
}
for i, tt := range tests {
expLen := 4
len, err := strlen(tt.libs, "test")
if checkFailure(tt.shouldSucceed, err) != nil {
t.Errorf("case %d: %v", i, err)
}
if tt.shouldSucceed && len != expLen {
t.Errorf("case %d: expected length %d, got %d", i, expLen, len)
}
}
}

View File

@ -1,77 +0,0 @@
package flagutil
import (
"bufio"
"flag"
"fmt"
"os"
"strings"
)
// SetFlagsFromEnvFile iterates the given flagset and if any flags are not
// already set it attempts to set their values from the given env file. Env
// files may have KEY=VALUE lines where the environment variable names are
// in UPPERCASE, prefixed by the given PREFIX, and dashes are replaced by
// underscores. For example, if prefix=PREFIX, some-flag is named
// PREFIX_SOME_FLAG.
// Comment lines are skipped, but more complex env file parsing is not
// performed.
func SetFlagsFromEnvFile(fs *flag.FlagSet, prefix string, path string) (err error) {
alreadySet := make(map[string]bool)
fs.Visit(func(f *flag.Flag) {
alreadySet[f.Name] = true
})
envs, err := parseEnvFile(path)
if err != nil {
return err
}
fs.VisitAll(func(f *flag.Flag) {
if !alreadySet[f.Name] {
key := prefix + "_" + strings.ToUpper(strings.Replace(f.Name, "-", "_", -1))
val := envs[key]
if val != "" {
if serr := fs.Set(f.Name, val); serr != nil {
err = fmt.Errorf("invalid value %q for %s: %v", val, key, serr)
}
}
}
})
return err
}
func parseEnvFile(path string) (map[string]string, error) {
file, err := os.Open(path)
if err != nil {
return nil, err
}
defer file.Close()
envs := make(map[string]string)
scanner := bufio.NewScanner(file)
for scanner.Scan() {
token := scanner.Text()
if !skipLine(token) {
key, val, err := parseLine(token)
if err == nil {
envs[key] = val
}
}
}
return envs, nil
}
func skipLine(line string) bool {
return len(line) == 0 || strings.HasPrefix(line, "#")
}
func parseLine(line string) (key string, val string, err error) {
trimmed := strings.TrimSpace(line)
pair := strings.SplitN(trimmed, "=", 2)
if len(pair) != 2 {
err = fmt.Errorf("invalid KEY=value line: %q", line)
return
}
key = strings.TrimSpace(pair[0])
val = strings.TrimSpace(pair[1])
return
}

View File

@ -1,107 +0,0 @@
package flagutil
import (
"flag"
"io/ioutil"
"os"
"testing"
)
var envFile = `
# some secret env vars
MYPROJ_A=foo
MYPROJ_C=woof
`
func TestSetFlagsFromEnvFile(t *testing.T) {
fs := flag.NewFlagSet("testing", flag.ExitOnError)
fs.String("a", "", "")
fs.String("b", "", "")
fs.String("c", "", "")
fs.Parse([]string{})
// add command-line flags
if err := fs.Set("b", "bar"); err != nil {
t.Fatal(err)
}
if err := fs.Set("c", "quack"); err != nil {
t.Fatal(err)
}
// first verify that flags are as expected before reading the env
for f, want := range map[string]string{
"a": "",
"b": "bar",
"c": "quack",
} {
if got := fs.Lookup(f).Value.String(); got != want {
t.Fatalf("flag %q=%q, want %q", f, got, want)
}
}
file, err := ioutil.TempFile("", "env-file")
if err != nil {
t.Fatal(err)
}
defer os.Remove(file.Name())
file.Write([]byte(envFile))
// read env file and verify flags were updated as expected
err = SetFlagsFromEnvFile(fs, "MYPROJ", file.Name())
if err != nil {
t.Errorf("err=%v, want nil", err)
}
for f, want := range map[string]string{
"a": "foo",
"b": "bar",
"c": "quack",
} {
if got := fs.Lookup(f).Value.String(); got != want {
t.Errorf("flag %q=%q, want %q", f, got, want)
}
}
}
func TestSetFlagsFromEnvFile_FlagSetError(t *testing.T) {
// now verify that an error is propagated
fs := flag.NewFlagSet("testing", flag.ExitOnError)
fs.Int("x", 0, "")
file, err := ioutil.TempFile("", "env-file")
if err != nil {
t.Fatal(err)
}
defer os.Remove(file.Name())
file.Write([]byte("MYPROJ_X=not_a_number"))
if err := SetFlagsFromEnvFile(fs, "MYPROJ", file.Name()); err == nil {
t.Errorf("err=nil, want != nil")
}
}
func TestParseLine(t *testing.T) {
cases := []struct {
line string
expectedKey string
expectedVal string
nilErr bool
}{
{"key=value", "key", "value", true},
{" key = value ", "key", "value", true},
{"key='#gopher' #blah", "key", "'#gopher' #blah", true},
// invalid
{"key:value", "", "", false},
{"keyvalue", "", "", false},
}
for _, c := range cases {
key, val, err := parseLine(c.line)
if (err == nil) != c.nilErr {
if c.nilErr {
t.Errorf("got %s, want err=nil", err)
} else {
t.Errorf("got err=nil, want err!=nil")
}
}
if c.expectedKey != key || c.expectedVal != val {
t.Errorf("got %q=%q, want %q=%q", key, val, c.expectedKey, c.expectedVal)
}
}
}

View File

@ -1,189 +0,0 @@
// Copyright 2016 CoreOS Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package progressutil
import (
"errors"
"fmt"
"io"
"sync"
"time"
)
var (
ErrAlreadyStarted = errors.New("cannot add copies after PrintAndWait has been called")
)
type copyReader struct {
reader io.Reader
current int64
total int64
pb *ProgressBar
}
func (cr *copyReader) Read(p []byte) (int, error) {
n, err := cr.reader.Read(p)
cr.current += int64(n)
err1 := cr.updateProgressBar()
if err == nil {
err = err1
}
return n, err
}
func (cr *copyReader) updateProgressBar() error {
cr.pb.SetPrintAfter(cr.formattedProgress())
progress := float64(cr.current) / float64(cr.total)
if progress > 1 {
progress = 1
}
return cr.pb.SetCurrentProgress(progress)
}
// NewCopyProgressPrinter returns a new CopyProgressPrinter
func NewCopyProgressPrinter() *CopyProgressPrinter {
return &CopyProgressPrinter{results: make(chan error), cancel: make(chan struct{})}
}
// CopyProgressPrinter will perform an arbitrary number of io.Copy calls, while
// continually printing the progress of each copy.
type CopyProgressPrinter struct {
results chan error
cancel chan struct{}
// `lock` mutex protects all fields below it in CopyProgressPrinter struct
lock sync.Mutex
readers []*copyReader
started bool
pbp *ProgressBarPrinter
}
// AddCopy adds a copy for this CopyProgressPrinter to perform. An io.Copy call
// will be made to copy bytes from reader to dest, and name and size will be
// used to label the progress bar and display how much progress has been made.
// If size is 0, the total size of the reader is assumed to be unknown.
// AddCopy can only be called before PrintAndWait; otherwise, ErrAlreadyStarted
// will be returned.
func (cpp *CopyProgressPrinter) AddCopy(reader io.Reader, name string, size int64, dest io.Writer) error {
cpp.lock.Lock()
defer cpp.lock.Unlock()
if cpp.started {
return ErrAlreadyStarted
}
if cpp.pbp == nil {
cpp.pbp = &ProgressBarPrinter{}
cpp.pbp.PadToBeEven = true
}
cr := &copyReader{
reader: reader,
current: 0,
total: size,
pb: cpp.pbp.AddProgressBar(),
}
cr.pb.SetPrintBefore(name)
cr.pb.SetPrintAfter(cr.formattedProgress())
cpp.readers = append(cpp.readers, cr)
go func() {
_, err := io.Copy(dest, cr)
select {
case <-cpp.cancel:
return
case cpp.results <- err:
return
}
}()
return nil
}
// PrintAndWait will print the progress for each copy operation added with
// AddCopy to printTo every printInterval. This will continue until every added
// copy is finished, or until cancel is written to.
// PrintAndWait may only be called once; any subsequent calls will immediately
// return ErrAlreadyStarted. After PrintAndWait has been called, no more
// copies may be added to the CopyProgressPrinter.
func (cpp *CopyProgressPrinter) PrintAndWait(printTo io.Writer, printInterval time.Duration, cancel chan struct{}) error {
cpp.lock.Lock()
if cpp.started {
cpp.lock.Unlock()
return ErrAlreadyStarted
}
cpp.started = true
cpp.lock.Unlock()
n := len(cpp.readers)
if n == 0 {
// Nothing to do.
return nil
}
defer close(cpp.cancel)
t := time.NewTicker(printInterval)
allDone := false
for i := 0; i < n; {
select {
case <-cancel:
return nil
case <-t.C:
_, err := cpp.pbp.Print(printTo)
if err != nil {
return err
}
case err := <-cpp.results:
i++
// Once completion is signaled, further on this just drains
// (unlikely) errors from the channel.
if err == nil && !allDone {
allDone, err = cpp.pbp.Print(printTo)
}
if err != nil {
return err
}
}
}
return nil
}
func (cr *copyReader) formattedProgress() string {
var totalStr string
if cr.total == 0 {
totalStr = "?"
} else {
totalStr = ByteUnitStr(cr.total)
}
return fmt.Sprintf("%s / %s", ByteUnitStr(cr.current), totalStr)
}
var byteUnits = []string{"B", "KB", "MB", "GB", "TB", "PB"}
// ByteUnitStr pretty prints a number of bytes.
func ByteUnitStr(n int64) string {
var unit string
size := float64(n)
for i := 1; i < len(byteUnits); i++ {
if size < 1000 {
unit = byteUnits[i-1]
break
}
size = size / 1000
}
return fmt.Sprintf("%.3g %s", size, unit)
}

View File

@ -1,256 +0,0 @@
// Copyright 2016 CoreOS Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package progressutil
import (
"fmt"
"io"
"os"
"strings"
"sync"
"golang.org/x/crypto/ssh/terminal"
)
var (
// ErrorProgressOutOfBounds is returned if the progress is set to a value
// not between 0 and 1.
ErrorProgressOutOfBounds = fmt.Errorf("progress is out of bounds (0 to 1)")
// ErrorNoBarsAdded is returned when no progress bars have been added to a
// ProgressBarPrinter before PrintAndWait is called.
ErrorNoBarsAdded = fmt.Errorf("AddProgressBar hasn't been called yet")
)
// ProgressBar represents one progress bar in a ProgressBarPrinter. Should not
// be created directly, use the AddProgressBar on a ProgressBarPrinter to
// create these.
type ProgressBar struct {
lock sync.Mutex
currentProgress float64
printBefore string
printAfter string
done bool
}
func (pb *ProgressBar) clone() *ProgressBar {
pb.lock.Lock()
pbClone := &ProgressBar{
currentProgress: pb.currentProgress,
printBefore: pb.printBefore,
printAfter: pb.printAfter,
done: pb.done,
}
pb.lock.Unlock()
return pbClone
}
func (pb *ProgressBar) GetCurrentProgress() float64 {
pb.lock.Lock()
val := pb.currentProgress
pb.lock.Unlock()
return val
}
// SetCurrentProgress sets the progress of this ProgressBar. The progress must
// be between 0 and 1 inclusive.
func (pb *ProgressBar) SetCurrentProgress(progress float64) error {
if progress < 0 || progress > 1 {
return ErrorProgressOutOfBounds
}
pb.lock.Lock()
pb.currentProgress = progress
pb.lock.Unlock()
return nil
}
// GetDone returns whether or not this progress bar is done
func (pb *ProgressBar) GetDone() bool {
pb.lock.Lock()
val := pb.done
pb.lock.Unlock()
return val
}
// SetDone sets whether or not this progress bar is done
func (pb *ProgressBar) SetDone(val bool) {
pb.lock.Lock()
pb.done = val
pb.lock.Unlock()
}
// GetPrintBefore gets the text printed on the line before the progress bar.
func (pb *ProgressBar) GetPrintBefore() string {
pb.lock.Lock()
val := pb.printBefore
pb.lock.Unlock()
return val
}
// SetPrintBefore sets the text printed on the line before the progress bar.
func (pb *ProgressBar) SetPrintBefore(before string) {
pb.lock.Lock()
pb.printBefore = before
pb.lock.Unlock()
}
// GetPrintAfter gets the text printed on the line after the progress bar.
func (pb *ProgressBar) GetPrintAfter() string {
pb.lock.Lock()
val := pb.printAfter
pb.lock.Unlock()
return val
}
// SetPrintAfter sets the text printed on the line after the progress bar.
func (pb *ProgressBar) SetPrintAfter(after string) {
pb.lock.Lock()
pb.printAfter = after
pb.lock.Unlock()
}
// ProgressBarPrinter will print out the progress of some number of
// ProgressBars.
type ProgressBarPrinter struct {
lock sync.Mutex
// DisplayWidth can be set to influence how large the progress bars are.
// The bars will be scaled to attempt to produce lines of this number of
// characters, but lines of different lengths may still be printed. When
// this value is 0 (aka unset), 80 character columns are assumed.
DisplayWidth int
// PadToBeEven, when set to true, will make Print pad the printBefore text
// with trailing spaces and the printAfter text with leading spaces to make
// the progress bars the same length.
PadToBeEven bool
numLinesInLastPrint int
progressBars []*ProgressBar
maxBefore int
maxAfter int
}
// AddProgressBar will create a new ProgressBar, register it with this
// ProgressBarPrinter, and return it. This must be called at least once before
// PrintAndWait is called.
func (pbp *ProgressBarPrinter) AddProgressBar() *ProgressBar {
pb := &ProgressBar{}
pbp.lock.Lock()
pbp.progressBars = append(pbp.progressBars, pb)
pbp.lock.Unlock()
return pb
}
// Print will print out progress information for each ProgressBar that has been
// added to this ProgressBarPrinter. The progress will be written to printTo,
// and if printTo is a terminal it will draw progress bars. AddProgressBar
// must be called at least once before Print is called. If printing to a
// terminal, all draws after the first one will move the cursor up to draw over
// the previously printed bars.
func (pbp *ProgressBarPrinter) Print(printTo io.Writer) (bool, error) {
pbp.lock.Lock()
var bars []*ProgressBar
for _, bar := range pbp.progressBars {
bars = append(bars, bar.clone())
}
numColumns := pbp.DisplayWidth
pbp.lock.Unlock()
if len(bars) == 0 {
return false, ErrorNoBarsAdded
}
if numColumns == 0 {
numColumns = 80
}
if isTerminal(printTo) {
moveCursorUp(printTo, pbp.numLinesInLastPrint)
}
for _, bar := range bars {
beforeSize := len(bar.GetPrintBefore())
afterSize := len(bar.GetPrintAfter())
if beforeSize > pbp.maxBefore {
pbp.maxBefore = beforeSize
}
if afterSize > pbp.maxAfter {
pbp.maxAfter = afterSize
}
}
allDone := true
for _, bar := range bars {
if isTerminal(printTo) {
bar.printToTerminal(printTo, numColumns, pbp.PadToBeEven, pbp.maxBefore, pbp.maxAfter)
} else {
bar.printToNonTerminal(printTo)
}
allDone = allDone && bar.GetCurrentProgress() == 1
}
pbp.numLinesInLastPrint = len(bars)
return allDone, nil
}
// moveCursorUp moves the cursor up numLines in the terminal
func moveCursorUp(printTo io.Writer, numLines int) {
if numLines > 0 {
fmt.Fprintf(printTo, "\033[%dA", numLines)
}
}
func (pb *ProgressBar) printToTerminal(printTo io.Writer, numColumns int, padding bool, maxBefore, maxAfter int) {
before := pb.GetPrintBefore()
after := pb.GetPrintAfter()
if padding {
before = before + strings.Repeat(" ", maxBefore-len(before))
after = strings.Repeat(" ", maxAfter-len(after)) + after
}
progressBarSize := numColumns - (len(fmt.Sprintf("%s [] %s", before, after)))
progressBar := ""
if progressBarSize > 0 {
currentProgress := int(pb.GetCurrentProgress() * float64(progressBarSize))
progressBar = fmt.Sprintf("[%s%s] ",
strings.Repeat("=", currentProgress),
strings.Repeat(" ", progressBarSize-currentProgress))
} else {
// If we can't fit the progress bar, better to not pad the before/after.
before = pb.GetPrintBefore()
after = pb.GetPrintAfter()
}
fmt.Fprintf(printTo, "%s %s%s\n", before, progressBar, after)
}
func (pb *ProgressBar) printToNonTerminal(printTo io.Writer) {
if !pb.GetDone() {
fmt.Fprintf(printTo, "%s %s\n", pb.printBefore, pb.printAfter)
if pb.GetCurrentProgress() == 1 {
pb.SetDone(true)
}
}
}
// isTerminal returns True when w is going to a tty, and false otherwise.
func isTerminal(w io.Writer) bool {
if f, ok := w.(*os.File); ok {
return terminal.IsTerminal(int(f.Fd()))
}
return false
}

2
vendor/github.com/coreos/pkg/test generated vendored
View File

@ -14,7 +14,7 @@ COVER=${COVER:-"-cover"}
source ./build
TESTABLE="cryptoutil flagutil timeutil netutil yamlutil httputil health multierror dlopen"
TESTABLE="cryptoutil flagutil timeutil netutil yamlutil httputil health multierror"
FORMATTABLE="$TESTABLE capnslog"
# user has not provided PKG override

1
vendor/github.com/docker/distribution generated vendored Submodule

@ -0,0 +1 @@
Subproject commit c8dff1bb5764aa6ea48d261b813ef3b0d2e25989

1
vendor/github.com/docker/docker generated vendored Submodule

@ -0,0 +1 @@
Subproject commit 508a17baba3c39496008fc5b5e3fe890b8a1b31b

1
vendor/github.com/docker/engine-api generated vendored Submodule

@ -0,0 +1 @@
Subproject commit 739ad1671be872943fadf936120ccfe9ccdde23a

1
vendor/github.com/docker/go-connections generated vendored Submodule

@ -0,0 +1 @@
Subproject commit c7838b258fbfa3fe88eecfb2a0e08ea0dbd6a646

1
vendor/github.com/docker/go-units generated vendored Submodule

@ -0,0 +1 @@
Subproject commit 09dda9d4b0d748c57c14048906d3d094a58ec0c9

1
vendor/github.com/docker/libtrust generated vendored Submodule

@ -0,0 +1 @@
Subproject commit 9cbd2a1374f46905c68a4eb3694a130610adc62a

View File

@ -4,8 +4,6 @@ go:
- 1.2
- 1.3
- 1.4
- 1.5
- 1.6
- tip
before_script:

View File

@ -15,8 +15,6 @@ Aaron Hopkins <go-sql-driver at die.net>
Arne Hormann <arnehormann at gmail.com>
Carlos Nieto <jose.carlos at menteslibres.net>
Chris Moos <chris at tech9computers.com>
Daniel Nichter <nil at codenode.com>
Daniël van Eeden <git at myname.nl>
DisposaBoy <disposaboy at dby.me>
Frederick Mayle <frederickmayle at gmail.com>
Gustavo Kristic <gkristic at gmail.com>
@ -27,23 +25,19 @@ INADA Naoki <songofacandy at gmail.com>
James Harr <james.harr at gmail.com>
Jian Zhen <zhenjl at gmail.com>
Joshua Prunier <joshua.prunier at gmail.com>
Julien Lefevre <julien.lefevr at gmail.com>
Julien Schmidt <go-sql-driver at julienschmidt.com>
Kamil Dziedzic <kamil at klecza.pl>
Kevin Malachowski <kevin at chowski.com>
Leonardo YongUk Kim <dalinaum at gmail.com>
Luca Looz <luca.looz92 at gmail.com>
Lucas Liu <extrafliu at gmail.com>
Luke Scott <luke at webconnex.com>
Michael Woolnough <michael.woolnough at gmail.com>
Nicola Peduzzi <thenikso at gmail.com>
Paul Bonser <misterpib at gmail.com>
Runrioter Wung <runrioter at gmail.com>
Soroush Pour <me at soroushjp.com>
Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com>
Xiaobing Jiang <s7v7nislands at gmail.com>
Xiuming Chen <cc at cxm.cc>
Julien Lefevre <julien.lefevr at gmail.com>
# Organizations

View File

@ -12,21 +12,10 @@ Bugfixes:
- Enable microsecond resolution on TIME, DATETIME and TIMESTAMP (#249)
- Fixed handling of queries without columns and rows (#255)
- Fixed a panic when SetKeepAlive() failed (#298)
- Support receiving ERR packet while reading rows (#321)
- Fixed reading NULL length-encoded integers in MySQL 5.6+ (#349)
- Fixed absolute paths support in LOAD LOCAL DATA INFILE (#356)
- Actually zero out bytes in handshake response (#378)
- Fixed race condition in registering LOAD DATA INFILE handler (#383)
- Fixed tests with MySQL 5.7.9+ (#380)
- QueryUnescape TLS config names (#397)
- Fixed "broken pipe" error by writing to closed socket (#390)
New Features:
- Support for returning table alias on Columns() (#289, #359, #382)
- Support for returning table alias on Columns() (#289)
- Placeholder interpolation, can be actived with the DSN parameter `interpolateParams=true` (#309, #318)
- Support for uint64 parameters with high bit set (#332, #345)
- Cleartext authentication plugin support (#327)
## Version 1.2 (2014-06-03)

View File

@ -4,11 +4,28 @@
Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/go-sql-driver/mysql/issues?state=open) or was [recently closed](https://github.com/go-sql-driver/mysql/issues?direction=desc&page=1&sort=updated&state=closed).
Please provide the following minimum information:
* Your Go-MySQL-Driver version (or git SHA)
* Your Go version (run `go version` in your console)
* A detailed issue description
* Error Log if present
* If possible, a short example
## Contributing Code
By contributing to this project, you share your code under the Mozilla Public License 2, as specified in the LICENSE file.
Don't forget to add yourself to the AUTHORS file.
### Pull Requests Checklist
Please check the following points before submitting your pull request:
- [x] Code compiles correctly
- [x] Created tests, if possible
- [x] All tests pass
- [x] Extended the README / documentation, if necessary
- [x] Added yourself to the AUTHORS file
### Code Review
Everyone is invited to review and comment on pull requests.

View File

@ -1,21 +0,0 @@
### Issue description
Tell us what should happen and what happens instead
### Example code
```go
If possible, please enter some example code here to reproduce the issue.
```
### Error log
```
If you have an error log, please paste it here.
```
### Configuration
*Driver version (or git SHA):*
*Go version:* run `go version` in your console
*Server version:* E.g. MySQL 5.6, MariaDB 10.0.20
*Server OS:* E.g. Debian 8.1 (Jessie), Windows 10

View File

@ -1,9 +0,0 @@
### Description
Please explain the changes you made here.
### Checklist
- [ ] Code compiles correctly
- [ ] Created tests which fail without the change (if possible)
- [ ] All tests passing
- [ ] Extended the README / documentation, if necessary
- [ ] Added myself / the copyright holder to the AUTHORS file

View File

@ -93,8 +93,6 @@ This has the same effect as an empty DSN string:
```
Alternatively, [Config.FormatDSN](https://godoc.org/github.com/go-sql-driver/mysql#Config.FormatDSN) can be used to create a DSN string by filling a struct.
#### Password
Passwords can consist of any character. Escaping is **not** necessary.
@ -221,18 +219,6 @@ Note that this sets the location for time.Time values but does not change MySQL'
Please keep in mind, that param values must be [url.QueryEscape](http://golang.org/pkg/net/url/#QueryEscape)'ed. Alternatively you can manually replace the `/` with `%2F`. For example `US/Pacific` would be `loc=US%2FPacific`.
##### `multiStatements`
```
Type: bool
Valid Values: true, false
Default: false
```
Allow multiple statements in one query. While this allows batch queries, it also greatly increases the risk of SQL injections. Only the result of the first query is returned, all other results are silently discarded.
When `multiStatements` is used, `?` parameters must only be used in the first statement.
##### `parseTime`
@ -245,16 +231,6 @@ Default: false
`parseTime=true` changes the output type of `DATE` and `DATETIME` values to `time.Time` instead of `[]byte` / `string`
##### `readTimeout`
```
Type: decimal number
Default: 0
```
I/O read timeout. The value must be a decimal number with an unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*.
##### `strict`
```
@ -275,7 +251,7 @@ Type: decimal number
Default: OS default
```
*Driver* side connection timeout. The value must be a decimal number with an unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout).
*Driver* side connection timeout. The value must be a string of decimal numbers, each with optional fraction and a unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*. To set a server side timeout, use the parameter [`wait_timeout`](http://dev.mysql.com/doc/refman/5.6/en/server-system-variables.html#sysvar_wait_timeout).
##### `tls`
@ -289,16 +265,6 @@ Default: false
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](http://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
##### `writeTimeout`
```
Type: decimal number
Default: 0
```
I/O write timeout. The value must be a decimal number with an unit suffix ( *"ms"*, *"s"*, *"m"*, *"h"* ), such as *"30s"*, *"0.5m"* or *"1m30s"*.
##### System Variables
All other parameters are interpreted as system variables:

View File

@ -49,9 +49,9 @@ func initDB(b *testing.B, queries ...string) *sql.DB {
for _, query := range queries {
if _, err := db.Exec(query); err != nil {
if w, ok := err.(MySQLWarnings); ok {
b.Logf("warning on %q: %v", query, w)
b.Logf("Warning on %q: %v", query, w)
} else {
b.Fatalf("error on %q: %v", query, err)
b.Fatalf("Error on %q: %v", query, err)
}
}
}
@ -216,9 +216,9 @@ func BenchmarkRoundtripBin(b *testing.B) {
func BenchmarkInterpolation(b *testing.B) {
mc := &mysqlConn{
cfg: &Config{
InterpolateParams: true,
Loc: time.UTC,
cfg: &config{
interpolateParams: true,
loc: time.UTC,
},
maxPacketAllowed: maxPacketSize,
maxWriteSize: maxPacketSize - 1,

View File

@ -8,11 +8,7 @@
package mysql
import (
"io"
"net"
"time"
)
import "io"
const defaultBufSize = 4096
@ -22,18 +18,17 @@ const defaultBufSize = 4096
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
type buffer struct {
buf []byte
nc net.Conn
idx int
length int
timeout time.Duration
buf []byte
rd io.Reader
idx int
length int
}
func newBuffer(nc net.Conn) buffer {
func newBuffer(rd io.Reader) buffer {
var b [defaultBufSize]byte
return buffer{
buf: b[:],
nc: nc,
rd: rd,
}
}
@ -59,13 +54,7 @@ func (b *buffer) fill(need int) error {
b.idx = 0
for {
if b.timeout > 0 {
if err := b.nc.SetReadDeadline(time.Now().Add(b.timeout)); err != nil {
return err
}
}
nn, err := b.nc.Read(b.buf[n:])
nn, err := b.rd.Read(b.buf[n:])
n += nn
switch err {

View File

@ -8,7 +8,7 @@
package mysql
const defaultCollation = "utf8_general_ci"
const defaultCollation byte = 33 // utf8_general_ci
// A list of available collations mapped to the internal ID.
// To update this map use the following MySQL query:
@ -237,14 +237,14 @@ var collations = map[string]byte{
// A blacklist of collations which is unsafe to interpolate parameters.
// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes.
var unsafeCollations = map[string]bool{
"big5_chinese_ci": true,
"sjis_japanese_ci": true,
"gbk_chinese_ci": true,
"big5_bin": true,
"gb2312_bin": true,
"gbk_bin": true,
"sjis_bin": true,
"cp932_japanese_ci": true,
"cp932_bin": true,
var unsafeCollations = map[byte]bool{
1: true, // big5_chinese_ci
13: true, // sjis_japanese_ci
28: true, // gbk_chinese_ci
84: true, // big5_bin
86: true, // gb2312_bin
87: true, // gbk_bin
88: true, // sjis_bin
95: true, // cp932_japanese_ci
96: true, // cp932_bin
}

View File

@ -9,7 +9,9 @@
package mysql
import (
"crypto/tls"
"database/sql/driver"
"errors"
"net"
"strconv"
"strings"
@ -21,10 +23,9 @@ type mysqlConn struct {
netConn net.Conn
affectedRows uint64
insertId uint64
cfg *Config
cfg *config
maxPacketAllowed int
maxWriteSize int
writeTimeout time.Duration
flags clientFlag
status statusFlag
sequence uint8
@ -32,9 +33,28 @@ type mysqlConn struct {
strict bool
}
type config struct {
user string
passwd string
net string
addr string
dbname string
params map[string]string
loc *time.Location
tls *tls.Config
timeout time.Duration
collation uint8
allowAllFiles bool
allowOldPasswords bool
allowCleartextPasswords bool
clientFoundRows bool
columnsWithAlias bool
interpolateParams bool
}
// Handles parameters set in DSN after the connection is established
func (mc *mysqlConn) handleParams() (err error) {
for param, val := range mc.cfg.Params {
for param, val := range mc.cfg.params {
switch param {
// Charset
case "charset":
@ -50,6 +70,27 @@ func (mc *mysqlConn) handleParams() (err error) {
return
}
// time.Time parsing
case "parseTime":
var isBool bool
mc.parseTime, isBool = readBool(val)
if !isBool {
return errors.New("Invalid Bool value: " + val)
}
// Strict mode
case "strict":
var isBool bool
mc.strict, isBool = readBool(val)
if !isBool {
return errors.New("Invalid Bool value: " + val)
}
// Compression
case "compress":
err = errors.New("Compression not implemented yet")
return
// System Vars
default:
err = mc.exec("SET " + param + "=" + val + "")
@ -79,27 +120,18 @@ func (mc *mysqlConn) Close() (err error) {
// Makes Close idempotent
if mc.netConn != nil {
err = mc.writeCommandPacket(comQuit)
}
mc.cleanup()
return
}
// Closes the network connection and unsets internal variables. Do not call this
// function after successfully authentication, call Close instead. This function
// is called before auth or on auth failure because MySQL will have already
// closed the network connection.
func (mc *mysqlConn) cleanup() {
// Makes cleanup idempotent
if mc.netConn != nil {
if err := mc.netConn.Close(); err != nil {
errLog.Print(err)
if err == nil {
err = mc.netConn.Close()
} else {
mc.netConn.Close()
}
mc.netConn = nil
}
mc.cfg = nil
mc.buf.nc = nil
mc.buf.rd = nil
return
}
func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
@ -176,7 +208,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
if v.IsZero() {
buf = append(buf, "'0000-00-00'"...)
} else {
v := v.In(mc.cfg.Loc)
v := v.In(mc.cfg.loc)
v = v.Add(time.Nanosecond * 500) // To round under microsecond
year := v.Year()
year100 := year / 100
@ -257,7 +289,7 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !mc.cfg.InterpolateParams {
if !mc.cfg.interpolateParams {
return nil, driver.ErrSkip
}
// try to interpolate the parameters to save extra roundtrips for preparing and closing a statement
@ -308,7 +340,7 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
return nil, driver.ErrBadConn
}
if len(args) != 0 {
if !mc.cfg.InterpolateParams {
if !mc.cfg.interpolateParams {
return nil, driver.ErrSkip
}
// try client-side prepare to reduce roundtrip
@ -354,7 +386,6 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
if err == nil {
rows := new(textRows)
rows.mc = mc
rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}}
if resLen > 0 {
// Columns

View File

@ -107,8 +107,7 @@ const (
fieldTypeBit
)
const (
fieldTypeJSON byte = iota + 0xf5
fieldTypeNewDecimal
fieldTypeNewDecimal byte = iota + 0xf6
fieldTypeEnum
fieldTypeSet
fieldTypeTinyBLOB

View File

@ -4,7 +4,7 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// Package mysql provides a MySQL driver for Go's database/sql package
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// The driver should be used via the database/sql package:
//
@ -22,7 +22,7 @@ import (
"net"
)
// MySQLDriver is exported to make the driver directly accessible.
// This struct is exported to make the driver directly accessible.
// In general the driver is used via the database/sql package.
type MySQLDriver struct{}
@ -53,19 +53,17 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
maxPacketAllowed: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
}
mc.cfg, err = ParseDSN(dsn)
mc.cfg, err = parseDSN(dsn)
if err != nil {
return nil, err
}
mc.parseTime = mc.cfg.ParseTime
mc.strict = mc.cfg.Strict
// Connect to Server
if dial, ok := dials[mc.cfg.Net]; ok {
mc.netConn, err = dial(mc.cfg.Addr)
if dial, ok := dials[mc.cfg.net]; ok {
mc.netConn, err = dial(mc.cfg.addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
nd := net.Dialer{Timeout: mc.cfg.timeout}
mc.netConn, err = nd.Dial(mc.cfg.net, mc.cfg.addr)
}
if err != nil {
return nil, err
@ -83,30 +81,46 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
mc.buf = newBuffer(mc.netConn)
// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout
// Reading Handshake Initialization Packet
cipher, err := mc.readInitPacket()
if err != nil {
mc.cleanup()
mc.Close()
return nil, err
}
// Send Client Authentication Packet
if err = mc.writeAuthPacket(cipher); err != nil {
mc.cleanup()
mc.Close()
return nil, err
}
// Handle response to auth packet, switch methods if possible
if err = handleAuthResult(mc, cipher); err != nil {
// Authentication failed and MySQL has already closed the connection
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
// Do not send COM_QUIT, just cleanup and return the error.
mc.cleanup()
return nil, err
// Read Result Packet
err = mc.readResultOK()
if err != nil {
// Retry with old authentication method, if allowed
if mc.cfg != nil && mc.cfg.allowOldPasswords && err == ErrOldPassword {
if err = mc.writeOldAuthPacket(cipher); err != nil {
mc.Close()
return nil, err
}
if err = mc.readResultOK(); err != nil {
mc.Close()
return nil, err
}
} else if mc.cfg != nil && mc.cfg.allowCleartextPasswords && err == ErrCleartextPassword {
if err = mc.writeClearAuthPacket(); err != nil {
mc.Close()
return nil, err
}
if err = mc.readResultOK(); err != nil {
mc.Close()
return nil, err
}
} else {
mc.Close()
return nil, err
}
}
// Get max allowed packet size
@ -130,38 +144,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
return mc, nil
}
func handleAuthResult(mc *mysqlConn, cipher []byte) error {
// Read Result Packet
err := mc.readResultOK()
if err == nil {
return nil // auth successful
}
if mc.cfg == nil {
return err // auth failed and retry not possible
}
// Retry auth if configured to do so.
if mc.cfg.AllowOldPasswords && err == ErrOldPassword {
// Retry with old authentication method. Note: there are edge cases
// where this should work but doesn't; this is currently "wontfix":
// https://github.com/go-sql-driver/mysql/issues/184
if err = mc.writeOldAuthPacket(cipher); err != nil {
return err
}
err = mc.readResultOK()
} else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword {
// Retry with clear text password for
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
if err = mc.writeClearAuthPacket(); err != nil {
return err
}
err = mc.readResultOK()
}
return err
}
func init() {
sql.Register("mysql", &MySQLDriver{})
}

View File

@ -9,14 +9,12 @@
package mysql
import (
"bytes"
"crypto/tls"
"database/sql"
"database/sql/driver"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/url"
"os"
@ -76,36 +74,14 @@ type DBTest struct {
db *sql.DB
}
func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}
dsn += "&multiStatements=true"
var db *sql.DB
if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation {
db, err = sql.Open("mysql", dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db.Close()
}
dbt := &DBTest{t, db}
for _, test := range tests {
test(dbt)
dbt.db.Exec("DROP TABLE IF EXISTS test")
}
}
func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
t.Skipf("MySQL-Server not running on %s", netAddr)
}
db, err := sql.Open("mysql", dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
t.Fatalf("Error connecting: %s", err.Error())
}
defer db.Close()
@ -113,27 +89,16 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
dsn2 := dsn + "&interpolateParams=true"
var db2 *sql.DB
if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
if _, err := parseDSN(dsn2); err != errInvalidDSNUnsafeCollation {
db2, err = sql.Open("mysql", dsn2)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
t.Fatalf("Error connecting: %s", err.Error())
}
defer db2.Close()
}
dsn3 := dsn + "&multiStatements=true"
var db3 *sql.DB
if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation {
db3, err = sql.Open("mysql", dsn3)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db3.Close()
}
dbt := &DBTest{t, db}
dbt2 := &DBTest{t, db2}
dbt3 := &DBTest{t, db3}
for _, test := range tests {
test(dbt)
dbt.db.Exec("DROP TABLE IF EXISTS test")
@ -141,10 +106,6 @@ func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) {
test(dbt2)
dbt2.db.Exec("DROP TABLE IF EXISTS test")
}
if db3 != nil {
test(dbt3)
dbt3.db.Exec("DROP TABLE IF EXISTS test")
}
}
}
@ -152,13 +113,13 @@ func (dbt *DBTest) fail(method, query string, err error) {
if len(query) > 300 {
query = "[query too large to print]"
}
dbt.Fatalf("error on %s %s: %s", method, query, err.Error())
dbt.Fatalf("Error on %s %s: %s", method, query, err.Error())
}
func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) {
res, err := dbt.db.Exec(query, args...)
if err != nil {
dbt.fail("exec", query, err)
dbt.fail("Exec", query, err)
}
return res
}
@ -166,7 +127,7 @@ func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result)
func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) {
rows, err := dbt.db.Query(query, args...)
if err != nil {
dbt.fail("query", query, err)
dbt.fail("Query", query, err)
}
return rows
}
@ -177,7 +138,7 @@ func TestEmptyQuery(t *testing.T) {
rows := dbt.mustQuery("--")
// will hang before #255
if rows.Next() {
dbt.Errorf("next on rows must be false")
dbt.Errorf("Next on rows must be false")
}
})
}
@ -201,7 +162,7 @@ func TestCRUD(t *testing.T) {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 1 {
dbt.Fatalf("expected 1 affected row, got %d", count)
dbt.Fatalf("Expected 1 affected row, got %d", count)
}
id, err := res.LastInsertId()
@ -209,7 +170,7 @@ func TestCRUD(t *testing.T) {
dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error())
}
if id != 0 {
dbt.Fatalf("expected InsertId 0, got %d", id)
dbt.Fatalf("Expected InsertID 0, got %d", id)
}
// Read
@ -234,7 +195,7 @@ func TestCRUD(t *testing.T) {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 1 {
dbt.Fatalf("expected 1 affected row, got %d", count)
dbt.Fatalf("Expected 1 affected row, got %d", count)
}
// Check Update
@ -259,7 +220,7 @@ func TestCRUD(t *testing.T) {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 1 {
dbt.Fatalf("expected 1 affected row, got %d", count)
dbt.Fatalf("Expected 1 affected row, got %d", count)
}
// Check for unexpected rows
@ -269,55 +230,11 @@ func TestCRUD(t *testing.T) {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 0 {
dbt.Fatalf("expected 0 affected row, got %d", count)
dbt.Fatalf("Expected 0 affected row, got %d", count)
}
})
}
func TestMultiQuery(t *testing.T) {
runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) {
// Create Table
dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ")
// Create Data
res := dbt.mustExec("INSERT INTO test VALUES (1, 1)")
count, err := res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 1 {
dbt.Fatalf("expected 1 affected row, got %d", count)
}
// Update
res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;")
count, err = res.RowsAffected()
if err != nil {
dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error())
}
if count != 1 {
dbt.Fatalf("expected 1 affected row, got %d", count)
}
// Read
var out int
rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;")
if rows.Next() {
rows.Scan(&out)
if 5 != out {
dbt.Errorf("5 != %d", out)
}
if rows.Next() {
dbt.Error("unexpected data")
}
} else {
dbt.Error("no data")
}
})
}
func TestInt(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"}
@ -365,7 +282,7 @@ func TestInt(t *testing.T) {
})
}
func TestFloat32(t *testing.T) {
func TestFloat(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
types := [2]string{"FLOAT", "DOUBLE"}
in := float32(42.23)
@ -388,52 +305,6 @@ func TestFloat32(t *testing.T) {
})
}
func TestFloat64(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
types := [2]string{"FLOAT", "DOUBLE"}
var expected float64 = 42.23
var out float64
var rows *sql.Rows
for _, v := range types {
dbt.mustExec("CREATE TABLE test (value " + v + ")")
dbt.mustExec("INSERT INTO test VALUES (42.23)")
rows = dbt.mustQuery("SELECT value FROM test")
if rows.Next() {
rows.Scan(&out)
if expected != out {
dbt.Errorf("%s: %g != %g", v, expected, out)
}
} else {
dbt.Errorf("%s: no data", v)
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}
})
}
func TestFloat64Placeholder(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
types := [2]string{"FLOAT", "DOUBLE"}
var expected float64 = 42.23
var out float64
var rows *sql.Rows
for _, v := range types {
dbt.mustExec("CREATE TABLE test (id int, value " + v + ")")
dbt.mustExec("INSERT INTO test VALUES (1, 42.23)")
rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1)
if rows.Next() {
rows.Scan(&out)
if expected != out {
dbt.Errorf("%s: %g != %g", v, expected, out)
}
} else {
dbt.Errorf("%s: no data", v)
}
dbt.mustExec("DROP TABLE IF EXISTS test")
}
})
}
func TestString(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"}
@ -780,14 +651,14 @@ func TestNULL(t *testing.T) {
dbt.Fatal(err)
}
if nb.Valid {
dbt.Error("valid NullBool which should be invalid")
dbt.Error("Valid NullBool which should be invalid")
}
// Valid
if err = nonNullStmt.QueryRow().Scan(&nb); err != nil {
dbt.Fatal(err)
}
if !nb.Valid {
dbt.Error("invalid NullBool which should be valid")
dbt.Error("Invalid NullBool which should be valid")
} else if nb.Bool != true {
dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool)
}
@ -799,16 +670,16 @@ func TestNULL(t *testing.T) {
dbt.Fatal(err)
}
if nf.Valid {
dbt.Error("valid NullFloat64 which should be invalid")
dbt.Error("Valid NullFloat64 which should be invalid")
}
// Valid
if err = nonNullStmt.QueryRow().Scan(&nf); err != nil {
dbt.Fatal(err)
}
if !nf.Valid {
dbt.Error("invalid NullFloat64 which should be valid")
dbt.Error("Invalid NullFloat64 which should be valid")
} else if nf.Float64 != float64(1) {
dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64)
dbt.Errorf("Unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64)
}
// NullInt64
@ -818,16 +689,16 @@ func TestNULL(t *testing.T) {
dbt.Fatal(err)
}
if ni.Valid {
dbt.Error("valid NullInt64 which should be invalid")
dbt.Error("Valid NullInt64 which should be invalid")
}
// Valid
if err = nonNullStmt.QueryRow().Scan(&ni); err != nil {
dbt.Fatal(err)
}
if !ni.Valid {
dbt.Error("invalid NullInt64 which should be valid")
dbt.Error("Invalid NullInt64 which should be valid")
} else if ni.Int64 != int64(1) {
dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64)
dbt.Errorf("Unexpected NullInt64 value: %d (should be 1)", ni.Int64)
}
// NullString
@ -837,16 +708,16 @@ func TestNULL(t *testing.T) {
dbt.Fatal(err)
}
if ns.Valid {
dbt.Error("valid NullString which should be invalid")
dbt.Error("Valid NullString which should be invalid")
}
// Valid
if err = nonNullStmt.QueryRow().Scan(&ns); err != nil {
dbt.Fatal(err)
}
if !ns.Valid {
dbt.Error("invalid NullString which should be valid")
dbt.Error("Invalid NullString which should be valid")
} else if ns.String != `1` {
dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)")
dbt.Error("Unexpected NullString value:" + ns.String + " (should be `1`)")
}
// nil-bytes
@ -856,14 +727,14 @@ func TestNULL(t *testing.T) {
dbt.Fatal(err)
}
if b != nil {
dbt.Error("non-nil []byte wich should be nil")
dbt.Error("Non-nil []byte wich should be nil")
}
// Read non-nil
if err = nonNullStmt.QueryRow().Scan(&b); err != nil {
dbt.Fatal(err)
}
if b == nil {
dbt.Error("nil []byte wich should be non-nil")
dbt.Error("Nil []byte wich should be non-nil")
}
// Insert nil
b = nil
@ -872,7 +743,7 @@ func TestNULL(t *testing.T) {
dbt.Fatal(err)
}
if !success {
dbt.Error("inserting []byte(nil) as NULL failed")
dbt.Error("Inserting []byte(nil) as NULL failed")
}
// Check input==output with input==nil
b = nil
@ -880,7 +751,7 @@ func TestNULL(t *testing.T) {
dbt.Fatal(err)
}
if b != nil {
dbt.Error("non-nil echo from nil input")
dbt.Error("Non-nil echo from nil input")
}
// Check input==output with input!=nil
b = []byte("")
@ -947,7 +818,7 @@ func TestUint64(t *testing.T) {
sb != shigh,
sc != stop,
sd != sall:
dbt.Fatal("unexpected result value")
dbt.Fatal("Unexpected result value")
}
})
}
@ -1051,7 +922,7 @@ func TestLoadData(t *testing.T) {
}
if i != 4 {
dbt.Fatalf("rows count mismatch. Got %d, want 4", i)
dbt.Fatalf("Rows count mismatch. Got %d, want 4", i)
}
}
file, err := ioutil.TempFile("", "gotest")
@ -1072,8 +943,8 @@ func TestLoadData(t *testing.T) {
// negative test
_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test")
if err == nil {
dbt.Fatal("load non-existent file didn't fail")
} else if err.Error() != "local file 'doesnotexist' is not registered" {
dbt.Fatal("Load non-existent file didn't fail")
} else if err.Error() != "Local File 'doesnotexist' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files" {
dbt.Fatal(err.Error())
}
@ -1093,7 +964,7 @@ func TestLoadData(t *testing.T) {
// negative test
_, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test")
if err == nil {
dbt.Fatal("load non-existent Reader didn't fail")
dbt.Fatal("Load non-existent Reader didn't fail")
} else if err.Error() != "Reader 'doesnotexist' is not registered" {
dbt.Fatal(err.Error())
}
@ -1147,7 +1018,7 @@ func TestFoundRows(t *testing.T) {
func TestStrict(t *testing.T) {
// ALLOW_INVALID_DATES to get rid of stricter modes - we want to test for warnings, not errors
relaxedDsn := dsn + "&sql_mode='ALLOW_INVALID_DATES,NO_AUTO_CREATE_USER'"
relaxedDsn := dsn + "&sql_mode=ALLOW_INVALID_DATES"
// make sure the MySQL version is recent enough with a separate connection
// before running the test
conn, err := MySQLDriver{}.Open(relaxedDsn)
@ -1173,7 +1044,7 @@ func TestStrict(t *testing.T) {
var checkWarnings = func(err error, mode string, idx int) {
if err == nil {
dbt.Errorf("expected STRICT error on query [%s] %s", mode, queries[idx].in)
dbt.Errorf("Expected STRICT error on query [%s] %s", mode, queries[idx].in)
}
if warnings, ok := err.(MySQLWarnings); ok {
@ -1182,18 +1053,18 @@ func TestStrict(t *testing.T) {
codes[i] = warnings[i].Code
}
if len(codes) != len(queries[idx].codes) {
dbt.Errorf("unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
dbt.Errorf("Unexpected STRICT error count on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
}
for i := range warnings {
if codes[i] != queries[idx].codes[i] {
dbt.Errorf("unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
dbt.Errorf("Unexpected STRICT error codes on query [%s] %s: Wanted %v, Got %v", mode, queries[idx].in, queries[idx].codes, codes)
return
}
}
} else {
dbt.Errorf("unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error())
dbt.Errorf("Unexpected error on query [%s] %s: %s", mode, queries[idx].in, err.Error())
}
}
@ -1209,7 +1080,7 @@ func TestStrict(t *testing.T) {
for i := range queries {
stmt, err = dbt.db.Prepare(queries[i].in)
if err != nil {
dbt.Errorf("error on preparing query %s: %s", queries[i].in, err.Error())
dbt.Errorf("Error on preparing query %s: %s", queries[i].in, err.Error())
}
_, err = stmt.Exec()
@ -1217,7 +1088,7 @@ func TestStrict(t *testing.T) {
err = stmt.Close()
if err != nil {
dbt.Errorf("error on closing stmt for query %s: %s", queries[i].in, err.Error())
dbt.Errorf("Error on closing stmt for query %s: %s", queries[i].in, err.Error())
}
}
})
@ -1227,9 +1098,9 @@ func TestTLS(t *testing.T) {
tlsTest := func(dbt *DBTest) {
if err := dbt.db.Ping(); err != nil {
if err == ErrNoTLS {
dbt.Skip("server does not support TLS")
dbt.Skip("Server does not support TLS")
} else {
dbt.Fatalf("error on Ping: %s", err.Error())
dbt.Fatalf("Error on Ping: %s", err.Error())
}
}
@ -1242,7 +1113,7 @@ func TestTLS(t *testing.T) {
}
if value == nil {
dbt.Fatal("no Cipher")
dbt.Fatal("No Cipher")
}
}
}
@ -1259,42 +1130,42 @@ func TestTLS(t *testing.T) {
func TestReuseClosedConnection(t *testing.T) {
// this test does not use sql.database, it uses the driver directly
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
t.Skipf("MySQL-Server not running on %s", netAddr)
}
md := &MySQLDriver{}
conn, err := md.Open(dsn)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
t.Fatalf("Error connecting: %s", err.Error())
}
stmt, err := conn.Prepare("DO 1")
if err != nil {
t.Fatalf("error preparing statement: %s", err.Error())
t.Fatalf("Error preparing statement: %s", err.Error())
}
_, err = stmt.Exec(nil)
if err != nil {
t.Fatalf("error executing statement: %s", err.Error())
t.Fatalf("Error executing statement: %s", err.Error())
}
err = conn.Close()
if err != nil {
t.Fatalf("error closing connection: %s", err.Error())
t.Fatalf("Error closing connection: %s", err.Error())
}
defer func() {
if err := recover(); err != nil {
t.Errorf("panic after reusing a closed connection: %v", err)
t.Errorf("Panic after reusing a closed connection: %v", err)
}
}()
_, err = stmt.Exec(nil)
if err != nil && err != driver.ErrBadConn {
t.Errorf("unexpected error '%s', expected '%s'",
t.Errorf("Unexpected error '%s', expected '%s'",
err.Error(), driver.ErrBadConn.Error())
}
}
func TestCharset(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
t.Skipf("MySQL-Server not running on %s", netAddr)
}
mustSetCharset := func(charsetParam, expected string) {
@ -1303,14 +1174,14 @@ func TestCharset(t *testing.T) {
defer rows.Close()
if !rows.Next() {
dbt.Fatalf("error getting connection charset: %s", rows.Err())
dbt.Fatalf("Error getting connection charset: %s", rows.Err())
}
var got string
rows.Scan(&got)
if got != expected {
dbt.Fatalf("expected connection charset %s but got %s", expected, got)
dbt.Fatalf("Expected connection charset %s but got %s", expected, got)
}
})
}
@ -1332,14 +1203,14 @@ func TestFailingCharset(t *testing.T) {
_, err := dbt.db.Exec("SELECT 1")
if err == nil {
dbt.db.Close()
t.Fatalf("connection must not succeed without a valid charset")
t.Fatalf("Connection must not succeed without a valid charset")
}
})
}
func TestCollation(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
t.Skipf("MySQL-Server not running on %s", netAddr)
}
defaultCollation := "utf8_general_ci"
@ -1369,7 +1240,7 @@ func TestCollation(t *testing.T) {
}
if got != expected {
dbt.Fatalf("expected connection collation %s but got %s", expected, got)
dbt.Fatalf("Expected connection collation %s but got %s", expected, got)
}
})
}
@ -1434,7 +1305,7 @@ func TestTimezoneConversion(t *testing.T) {
// Retrieve time from DB
rows := dbt.mustQuery("SELECT ts FROM test")
if !rows.Next() {
dbt.Fatal("did not get any rows out")
dbt.Fatal("Didn't get any rows out")
}
var dbTime time.Time
@ -1445,7 +1316,7 @@ func TestTimezoneConversion(t *testing.T) {
// Check that dates match
if reftime.Unix() != dbTime.Unix() {
dbt.Errorf("times do not match.\n")
dbt.Errorf("Times don't match.\n")
dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime)
dbt.Errorf(" Now(UTC)=%v\n", dbTime)
}
@ -1471,7 +1342,7 @@ func TestRowsClose(t *testing.T) {
}
if rows.Next() {
dbt.Fatal("unexpected row after rows.Close()")
dbt.Fatal("Unexpected row after rows.Close()")
}
err = rows.Err()
@ -1503,7 +1374,7 @@ func TestCloseStmtBeforeRows(t *testing.T) {
}
if !rows.Next() {
dbt.Fatal("getting row failed")
dbt.Fatal("Getting row failed")
} else {
err = rows.Err()
if err != nil {
@ -1513,7 +1384,7 @@ func TestCloseStmtBeforeRows(t *testing.T) {
var out bool
err = rows.Scan(&out)
if err != nil {
dbt.Fatalf("error on rows.Scan(): %s", err.Error())
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != true {
dbt.Errorf("true != %t", out)
@ -1549,7 +1420,7 @@ func TestStmtMultiRows(t *testing.T) {
// 1
if !rows1.Next() {
dbt.Fatal("first rows1.Next failed")
dbt.Fatal("1st rows1.Next failed")
} else {
err = rows1.Err()
if err != nil {
@ -1558,7 +1429,7 @@ func TestStmtMultiRows(t *testing.T) {
err = rows1.Scan(&out)
if err != nil {
dbt.Fatalf("error on rows.Scan(): %s", err.Error())
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != true {
dbt.Errorf("true != %t", out)
@ -1566,7 +1437,7 @@ func TestStmtMultiRows(t *testing.T) {
}
if !rows2.Next() {
dbt.Fatal("first rows2.Next failed")
dbt.Fatal("1st rows2.Next failed")
} else {
err = rows2.Err()
if err != nil {
@ -1575,7 +1446,7 @@ func TestStmtMultiRows(t *testing.T) {
err = rows2.Scan(&out)
if err != nil {
dbt.Fatalf("error on rows.Scan(): %s", err.Error())
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != true {
dbt.Errorf("true != %t", out)
@ -1584,7 +1455,7 @@ func TestStmtMultiRows(t *testing.T) {
// 2
if !rows1.Next() {
dbt.Fatal("second rows1.Next failed")
dbt.Fatal("2nd rows1.Next failed")
} else {
err = rows1.Err()
if err != nil {
@ -1593,14 +1464,14 @@ func TestStmtMultiRows(t *testing.T) {
err = rows1.Scan(&out)
if err != nil {
dbt.Fatalf("error on rows.Scan(): %s", err.Error())
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != false {
dbt.Errorf("false != %t", out)
}
if rows1.Next() {
dbt.Fatal("unexpected row on rows1")
dbt.Fatal("Unexpected row on rows1")
}
err = rows1.Close()
if err != nil {
@ -1609,7 +1480,7 @@ func TestStmtMultiRows(t *testing.T) {
}
if !rows2.Next() {
dbt.Fatal("second rows2.Next failed")
dbt.Fatal("2nd rows2.Next failed")
} else {
err = rows2.Err()
if err != nil {
@ -1618,14 +1489,14 @@ func TestStmtMultiRows(t *testing.T) {
err = rows2.Scan(&out)
if err != nil {
dbt.Fatalf("error on rows.Scan(): %s", err.Error())
dbt.Fatalf("Error on rows.Scan(): %s", err.Error())
}
if out != false {
dbt.Errorf("false != %t", out)
}
if rows2.Next() {
dbt.Fatal("unexpected row on rows2")
dbt.Fatal("Unexpected row on rows2")
}
err = rows2.Close()
if err != nil {
@ -1670,7 +1541,7 @@ func TestConcurrent(t *testing.T) {
if err != nil {
dbt.Fatalf("%s", err.Error())
}
dbt.Logf("testing up to %d concurrent connections \r\n", max)
dbt.Logf("Testing up to %d concurrent connections \r\n", max)
var remaining, succeeded int32 = int32(max), 0
@ -1694,7 +1565,7 @@ func TestConcurrent(t *testing.T) {
if err != nil {
if err.Error() != "Error 1040: Too many connections" {
fatalf("error on conn %d: %s", id, err.Error())
fatalf("Error on Conn %d: %s", id, err.Error())
}
return
}
@ -1702,13 +1573,13 @@ func TestConcurrent(t *testing.T) {
// keep the connection busy until all connections are open
for remaining > 0 {
if _, err = tx.Exec("DO 1"); err != nil {
fatalf("error on conn %d: %s", id, err.Error())
fatalf("Error on Conn %d: %s", id, err.Error())
return
}
}
if err = tx.Commit(); err != nil {
fatalf("error on conn %d: %s", id, err.Error())
fatalf("Error on Conn %d: %s", id, err.Error())
return
}
@ -1724,14 +1595,14 @@ func TestConcurrent(t *testing.T) {
dbt.Fatal(fatalError)
}
dbt.Logf("reached %d concurrent connections\r\n", succeeded)
dbt.Logf("Reached %d concurrent connections\r\n", succeeded)
})
}
// Tests custom dial functions
func TestCustomDial(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
t.Skipf("MySQL-Server not running on %s", netAddr)
}
// our custom dial function which justs wraps net.Dial here
@ -1741,16 +1612,16 @@ func TestCustomDial(t *testing.T) {
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s&strict=true", user, pass, addr, dbname))
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
t.Fatalf("Error connecting: %s", err.Error())
}
defer db.Close()
if _, err = db.Exec("DO 1"); err != nil {
t.Fatalf("connection failed: %s", err.Error())
t.Fatalf("Connection failed: %s", err.Error())
}
}
func TestSQLInjection(t *testing.T) {
func TestSqlInjection(t *testing.T) {
createTest := func(arg string) func(dbt *DBTest) {
return func(dbt *DBTest) {
dbt.mustExec("CREATE TABLE test (v INTEGER)")
@ -1763,16 +1634,16 @@ func TestSQLInjection(t *testing.T) {
if err == sql.ErrNoRows {
return // success, sql injection failed
} else if err == nil {
dbt.Errorf("sql injection successful with arg: %s", arg)
dbt.Errorf("Sql injection successful with arg: %s", arg)
} else {
dbt.Errorf("error running query with arg: %s; err: %s", arg, err.Error())
dbt.Errorf("Error running query with arg: %s; err: %s", arg, err.Error())
}
}
}
dsns := []string{
dsn,
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
dsn + "&sql_mode=NO_BACKSLASH_ESCAPES",
}
for _, testdsn := range dsns {
runTests(t, testdsn, createTest("1 OR 1=1"))
@ -1802,56 +1673,9 @@ func TestInsertRetrieveEscapedData(t *testing.T) {
dsns := []string{
dsn,
dsn + "&sql_mode='NO_BACKSLASH_ESCAPES,NO_AUTO_CREATE_USER'",
dsn + "&sql_mode=NO_BACKSLASH_ESCAPES",
}
for _, testdsn := range dsns {
runTests(t, testdsn, testData)
}
}
func TestUnixSocketAuthFail(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
// Save the current logger so we can restore it.
oldLogger := errLog
// Set a new logger so we can capture its output.
buffer := bytes.NewBuffer(make([]byte, 0, 64))
newLogger := log.New(buffer, "prefix: ", 0)
SetLogger(newLogger)
// Restore the logger.
defer SetLogger(oldLogger)
// Make a new DSN that uses the MySQL socket file and a bad password, which
// we can make by simply appending any character to the real password.
badPass := pass + "x"
socket := ""
if prot == "unix" {
socket = addr
} else {
// Get socket file from MySQL.
err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket)
if err != nil {
t.Fatalf("error on SELECT @@socket: %s", err.Error())
}
}
t.Logf("socket: %s", socket)
badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s&strict=true", user, badPass, socket, dbname)
db, err := sql.Open("mysql", badDSN)
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db.Close()
// Connect to MySQL for real. This will cause an auth failure.
err = db.Ping()
if err == nil {
t.Error("expected Ping() to return an error")
}
// The driver should not log anything.
if actual := buffer.String(); actual != "" {
t.Errorf("expected no output, got %q", actual)
}
})
}

View File

@ -1,513 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"net"
"net/url"
"strings"
"time"
)
var (
errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")
errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")
errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")
errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
)
// Config is a configuration parsed from a DSN string
type Config struct {
User string // Username
Passwd string // Password (requires User)
Net string // Network type
Addr string // Network address (requires Net)
DBName string // Database name
Params map[string]string // Connection parameters
Collation string // Connection collation
Loc *time.Location // Location for time.Time values
TLSConfig string // TLS configuration name
tls *tls.Config // TLS configuration
Timeout time.Duration // Dial timeout
ReadTimeout time.Duration // I/O read timeout
WriteTimeout time.Duration // I/O write timeout
AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
AllowCleartextPasswords bool // Allows the cleartext client side plugin
AllowOldPasswords bool // Allows the old insecure password method
ClientFoundRows bool // Return number of matching rows instead of rows changed
ColumnsWithAlias bool // Prepend table alias to column names
InterpolateParams bool // Interpolate placeholders into query string
MultiStatements bool // Allow multiple statements in one query
ParseTime bool // Parse time values to time.Time
Strict bool // Return warnings as errors
}
// FormatDSN formats the given Config into a DSN string which can be passed to
// the driver.
func (cfg *Config) FormatDSN() string {
var buf bytes.Buffer
// [username[:password]@]
if len(cfg.User) > 0 {
buf.WriteString(cfg.User)
if len(cfg.Passwd) > 0 {
buf.WriteByte(':')
buf.WriteString(cfg.Passwd)
}
buf.WriteByte('@')
}
// [protocol[(address)]]
if len(cfg.Net) > 0 {
buf.WriteString(cfg.Net)
if len(cfg.Addr) > 0 {
buf.WriteByte('(')
buf.WriteString(cfg.Addr)
buf.WriteByte(')')
}
}
// /dbname
buf.WriteByte('/')
buf.WriteString(cfg.DBName)
// [?param1=value1&...&paramN=valueN]
hasParam := false
if cfg.AllowAllFiles {
hasParam = true
buf.WriteString("?allowAllFiles=true")
}
if cfg.AllowCleartextPasswords {
if hasParam {
buf.WriteString("&allowCleartextPasswords=true")
} else {
hasParam = true
buf.WriteString("?allowCleartextPasswords=true")
}
}
if cfg.AllowOldPasswords {
if hasParam {
buf.WriteString("&allowOldPasswords=true")
} else {
hasParam = true
buf.WriteString("?allowOldPasswords=true")
}
}
if cfg.ClientFoundRows {
if hasParam {
buf.WriteString("&clientFoundRows=true")
} else {
hasParam = true
buf.WriteString("?clientFoundRows=true")
}
}
if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
if hasParam {
buf.WriteString("&collation=")
} else {
hasParam = true
buf.WriteString("?collation=")
}
buf.WriteString(col)
}
if cfg.ColumnsWithAlias {
if hasParam {
buf.WriteString("&columnsWithAlias=true")
} else {
hasParam = true
buf.WriteString("?columnsWithAlias=true")
}
}
if cfg.InterpolateParams {
if hasParam {
buf.WriteString("&interpolateParams=true")
} else {
hasParam = true
buf.WriteString("?interpolateParams=true")
}
}
if cfg.Loc != time.UTC && cfg.Loc != nil {
if hasParam {
buf.WriteString("&loc=")
} else {
hasParam = true
buf.WriteString("?loc=")
}
buf.WriteString(url.QueryEscape(cfg.Loc.String()))
}
if cfg.MultiStatements {
if hasParam {
buf.WriteString("&multiStatements=true")
} else {
hasParam = true
buf.WriteString("?multiStatements=true")
}
}
if cfg.ParseTime {
if hasParam {
buf.WriteString("&parseTime=true")
} else {
hasParam = true
buf.WriteString("?parseTime=true")
}
}
if cfg.ReadTimeout > 0 {
if hasParam {
buf.WriteString("&readTimeout=")
} else {
hasParam = true
buf.WriteString("?readTimeout=")
}
buf.WriteString(cfg.ReadTimeout.String())
}
if cfg.Strict {
if hasParam {
buf.WriteString("&strict=true")
} else {
hasParam = true
buf.WriteString("?strict=true")
}
}
if cfg.Timeout > 0 {
if hasParam {
buf.WriteString("&timeout=")
} else {
hasParam = true
buf.WriteString("?timeout=")
}
buf.WriteString(cfg.Timeout.String())
}
if len(cfg.TLSConfig) > 0 {
if hasParam {
buf.WriteString("&tls=")
} else {
hasParam = true
buf.WriteString("?tls=")
}
buf.WriteString(url.QueryEscape(cfg.TLSConfig))
}
if cfg.WriteTimeout > 0 {
if hasParam {
buf.WriteString("&writeTimeout=")
} else {
hasParam = true
buf.WriteString("?writeTimeout=")
}
buf.WriteString(cfg.WriteTimeout.String())
}
// other params
if cfg.Params != nil {
for param, value := range cfg.Params {
if hasParam {
buf.WriteByte('&')
} else {
hasParam = true
buf.WriteByte('?')
}
buf.WriteString(param)
buf.WriteByte('=')
buf.WriteString(url.QueryEscape(value))
}
}
return buf.String()
}
// ParseDSN parses the DSN string to a Config
func ParseDSN(dsn string) (cfg *Config, err error) {
// New config with some default values
cfg = &Config{
Loc: time.UTC,
Collation: defaultCollation,
}
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
foundSlash := false
for i := len(dsn) - 1; i >= 0; i-- {
if dsn[i] == '/' {
foundSlash = true
var j, k int
// left part is empty if i <= 0
if i > 0 {
// [username[:password]@][protocol[(address)]]
// Find the last '@' in dsn[:i]
for j = i; j >= 0; j-- {
if dsn[j] == '@' {
// username[:password]
// Find the first ':' in dsn[:j]
for k = 0; k < j; k++ {
if dsn[k] == ':' {
cfg.Passwd = dsn[k+1 : j]
break
}
}
cfg.User = dsn[:k]
break
}
}
// [protocol[(address)]]
// Find the first '(' in dsn[j+1:i]
for k = j + 1; k < i; k++ {
if dsn[k] == '(' {
// dsn[i-1] must be == ')' if an address is specified
if dsn[i-1] != ')' {
if strings.ContainsRune(dsn[k+1:i], ')') {
return nil, errInvalidDSNUnescaped
}
return nil, errInvalidDSNAddr
}
cfg.Addr = dsn[k+1 : i-1]
break
}
}
cfg.Net = dsn[j+1 : k]
}
// dbname[?param1=value1&...&paramN=valueN]
// Find the first '?' in dsn[i+1:]
for j = i + 1; j < len(dsn); j++ {
if dsn[j] == '?' {
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
return
}
break
}
}
cfg.DBName = dsn[i+1 : j]
break
}
}
if !foundSlash && len(dsn) > 0 {
return nil, errInvalidDSNNoSlash
}
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
return nil, errInvalidDSNUnsafeCollation
}
// Set default network if empty
if cfg.Net == "" {
cfg.Net = "tcp"
}
// Set default address if empty
if cfg.Addr == "" {
switch cfg.Net {
case "tcp":
cfg.Addr = "127.0.0.1:3306"
case "unix":
cfg.Addr = "/tmp/mysql.sock"
default:
return nil, errors.New("default addr for network '" + cfg.Net + "' unknown")
}
}
return
}
// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
func parseDSNParams(cfg *Config, params string) (err error) {
for _, v := range strings.Split(params, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
continue
}
// cfg params
switch value := param[1]; param[0] {
// Disable INFILE whitelist / enable all files
case "allowAllFiles":
var isBool bool
cfg.AllowAllFiles, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Use cleartext authentication mode (MySQL 5.5.10+)
case "allowCleartextPasswords":
var isBool bool
cfg.AllowCleartextPasswords, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Use old authentication mode (pre MySQL 4.1)
case "allowOldPasswords":
var isBool bool
cfg.AllowOldPasswords, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Switch "rowsAffected" mode
case "clientFoundRows":
var isBool bool
cfg.ClientFoundRows, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Collation
case "collation":
cfg.Collation = value
break
case "columnsWithAlias":
var isBool bool
cfg.ColumnsWithAlias, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Compression
case "compress":
return errors.New("compression not implemented yet")
// Enable client side placeholder substitution
case "interpolateParams":
var isBool bool
cfg.InterpolateParams, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Time Location
case "loc":
if value, err = url.QueryUnescape(value); err != nil {
return
}
cfg.Loc, err = time.LoadLocation(value)
if err != nil {
return
}
// multiple statements in one query
case "multiStatements":
var isBool bool
cfg.MultiStatements, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// time.Time parsing
case "parseTime":
var isBool bool
cfg.ParseTime, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// I/O read Timeout
case "readTimeout":
cfg.ReadTimeout, err = time.ParseDuration(value)
if err != nil {
return
}
// Strict mode
case "strict":
var isBool bool
cfg.Strict, isBool = readBool(value)
if !isBool {
return errors.New("invalid bool value: " + value)
}
// Dial Timeout
case "timeout":
cfg.Timeout, err = time.ParseDuration(value)
if err != nil {
return
}
// TLS-Encryption
case "tls":
boolValue, isBool := readBool(value)
if isBool {
if boolValue {
cfg.TLSConfig = "true"
cfg.tls = &tls.Config{}
} else {
cfg.TLSConfig = "false"
}
} else if vl := strings.ToLower(value); vl == "skip-verify" {
cfg.TLSConfig = vl
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else {
name, err := url.QueryUnescape(value)
if err != nil {
return fmt.Errorf("invalid value for TLS config name: %v", err)
}
if tlsConfig, ok := tlsConfigRegister[name]; ok {
if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
host, _, err := net.SplitHostPort(cfg.Addr)
if err == nil {
tlsConfig.ServerName = host
}
}
cfg.TLSConfig = name
cfg.tls = tlsConfig
} else {
return errors.New("invalid value / unknown config name: " + name)
}
}
// I/O write Timeout
case "writeTimeout":
cfg.WriteTimeout, err = time.ParseDuration(value)
if err != nil {
return
}
default:
// lazy init
if cfg.Params == nil {
cfg.Params = make(map[string]string)
}
if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
return
}
}
}
return
}

View File

@ -1,231 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/tls"
"fmt"
"net/url"
"reflect"
"testing"
"time"
)
var testDSNs = []struct {
in string
out *Config
}{{
"username:password@protocol(address)/dbname?param=value",
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC},
}, {
"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true",
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, ColumnsWithAlias: true},
}, {
"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true",
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8_general_ci", Loc: time.UTC, ColumnsWithAlias: true, MultiStatements: true},
}, {
"user@unix(/path/to/socket)/dbname?charset=utf8",
&Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC},
}, {
"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, TLSConfig: "true"},
}, {
"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8_general_ci", Loc: time.UTC, TLSConfig: "skip-verify"},
}, {
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true},
}, {
"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local",
&Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.Local},
}, {
"/dbname",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8_general_ci", Loc: time.UTC},
}, {
"@/",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC},
}, {
"/",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC},
}, {
"",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC},
}, {
"user:p@/ssword@/",
&Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8_general_ci", Loc: time.UTC},
}, {
"unix/?arg=%2Fsome%2Fpath.ext",
&Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8_general_ci", Loc: time.UTC},
}}
func TestDSNParser(t *testing.T) {
for i, tst := range testDSNs {
cfg, err := ParseDSN(tst.in)
if err != nil {
t.Error(err.Error())
}
// pointer not static
cfg.tls = nil
if !reflect.DeepEqual(cfg, tst.out) {
t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out)
}
}
}
func TestDSNParserInvalid(t *testing.T) {
var invalidDSNs = []string{
"@net(addr/", // no closing brace
"@tcp(/", // no closing brace
"tcp(/", // no closing brace
"(/", // no closing brace
"net(addr)//", // unescaped
"User:pass@tcp(1.2.3.4:3306)", // no trailing slash
//"/dbname?arg=/some/unescaped/path",
}
for i, tst := range invalidDSNs {
if _, err := ParseDSN(tst); err == nil {
t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst)
}
}
}
func TestDSNReformat(t *testing.T) {
for i, tst := range testDSNs {
dsn1 := tst.in
cfg1, err := ParseDSN(dsn1)
if err != nil {
t.Error(err.Error())
continue
}
cfg1.tls = nil // pointer not static
res1 := fmt.Sprintf("%+v", cfg1)
dsn2 := cfg1.FormatDSN()
cfg2, err := ParseDSN(dsn2)
if err != nil {
t.Error(err.Error())
continue
}
cfg2.tls = nil // pointer not static
res2 := fmt.Sprintf("%+v", cfg2)
if res1 != res2 {
t.Errorf("%d. %q does not match %q", i, res2, res1)
}
}
}
func TestDSNWithCustomTLS(t *testing.T) {
baseDSN := "User:password@tcp(localhost:5555)/dbname?tls="
tlsCfg := tls.Config{}
RegisterTLSConfig("utils_test", &tlsCfg)
// Custom TLS is missing
tst := baseDSN + "invalid_tls"
cfg, err := ParseDSN(tst)
if err == nil {
t.Errorf("invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg)
}
tst = baseDSN + "utils_test"
// Custom TLS with a server name
name := "foohost"
tlsCfg.ServerName = name
cfg, err = ParseDSN(tst)
if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
}
// Custom TLS without a server name
name = "localhost"
tlsCfg.ServerName = ""
cfg, err = ParseDSN(tst)
if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
}
DeregisterTLSConfig("utils_test")
}
func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
const configKey = "&%!:"
dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey)
name := "foohost"
tlsCfg := tls.Config{ServerName: name}
RegisterTLSConfig(configKey, &tlsCfg)
cfg, err := ParseDSN(dsn)
if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn)
}
}
func TestDSNUnsafeCollation(t *testing.T) {
_, err := ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true")
if err != errInvalidDSNUnsafeCollation {
t.Errorf("expected %v, got %v", errInvalidDSNUnsafeCollation, err)
}
_, err = ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=gbk_chinese_ci")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=ascii_bin&interpolateParams=true")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
}
func BenchmarkParseDSN(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
for _, tst := range testDSNs {
if _, err := ParseDSN(tst.in); err != nil {
b.Error(err.Error())
}
}
}
}

View File

@ -19,20 +19,20 @@ import (
// Various errors the driver might return. Can change between driver versions.
var (
ErrInvalidConn = errors.New("invalid connection")
ErrMalformPkt = errors.New("malformed packet")
ErrNoTLS = errors.New("TLS requested but server does not support TLS")
ErrOldPassword = errors.New("this user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrCleartextPassword = errors.New("this user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN")
ErrUnknownPlugin = errors.New("this authentication plugin is not supported")
ErrOldProtocol = errors.New("MySQL server does not support required protocol 41+")
ErrPktSync = errors.New("commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server")
ErrBusyBuffer = errors.New("busy buffer")
ErrInvalidConn = errors.New("Invalid Connection")
ErrMalformPkt = errors.New("Malformed Packet")
ErrNoTLS = errors.New("TLS encryption requested but server does not support TLS")
ErrOldPassword = errors.New("This user requires old password authentication. If you still want to use it, please add 'allowOldPasswords=1' to your DSN. See also https://github.com/go-sql-driver/mysql/wiki/old_passwords")
ErrCleartextPassword = errors.New("This user requires clear text authentication. If you still want to use it, please add 'allowCleartextPasswords=1' to your DSN.")
ErrUnknownPlugin = errors.New("The authentication plugin is not supported.")
ErrOldProtocol = errors.New("MySQL-Server does not support required Protocol 41+")
ErrPktSync = errors.New("Commands out of sync. You can't run this command now")
ErrPktSyncMul = errors.New("Commands out of sync. Did you run multiple statements at once?")
ErrPktTooLarge = errors.New("Packet for query is too large. You can change this value on the server by adjusting the 'max_allowed_packet' variable.")
ErrBusyBuffer = errors.New("Busy buffer")
)
var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile))
var errLog Logger = log.New(os.Stderr, "[MySQL] ", log.Ldate|log.Ltime|log.Lshortfile)
// Logger is used to log critical error messages.
type Logger interface {

View File

@ -96,10 +96,6 @@ func deferredClose(err *error, closer io.Closer) {
func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
var rdr io.Reader
var data []byte
packetSize := 16 * 1024 // 16KB is small enough for disk readahead and large enough for TCP
if mc.maxWriteSize < packetSize {
packetSize = mc.maxWriteSize
}
if idx := strings.Index(name, "Reader::"); idx == 0 || (idx > 0 && name[idx-1] == '/') { // io.Reader
// The server might return an an absolute path. See issue #355.
@ -112,6 +108,8 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
if inMap {
rdr = handler()
if rdr != nil {
data = make([]byte, 4+mc.maxWriteSize)
if cl, ok := rdr.(io.Closer); ok {
defer deferredClose(&err, cl)
}
@ -126,7 +124,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
fileRegisterLock.RLock()
fr := fileRegister[name]
fileRegisterLock.RUnlock()
if mc.cfg.AllowAllFiles || fr {
if mc.cfg.allowAllFiles || fr {
var file *os.File
var fi os.FileInfo
@ -136,19 +134,22 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
// get file size
if fi, err = file.Stat(); err == nil {
rdr = file
if fileSize := int(fi.Size()); fileSize < packetSize {
packetSize = fileSize
if fileSize := int(fi.Size()); fileSize <= mc.maxWriteSize {
data = make([]byte, 4+fileSize)
} else if fileSize <= mc.maxPacketAllowed {
data = make([]byte, 4+mc.maxWriteSize)
} else {
err = fmt.Errorf("Local File '%s' too large: Size: %d, Max: %d", name, fileSize, mc.maxPacketAllowed)
}
}
}
} else {
err = fmt.Errorf("local file '%s' is not registered", name)
err = fmt.Errorf("Local File '%s' is not registered. Use the DSN parameter 'allowAllFiles=true' to allow all files", name)
}
}
// send content packets
if err == nil {
data := make([]byte, 4+packetSize)
var n int
for err == nil {
n, err = rdr.Read(data[4:])
@ -174,8 +175,8 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) {
// read OK packet
if err == nil {
return mc.readResultOK()
} else {
mc.readPacket()
}
mc.readPacket()
return err
}

View File

@ -13,7 +13,6 @@ import (
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"math"
@ -48,8 +47,9 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
if data[3] != mc.sequence {
if data[3] > mc.sequence {
return nil, ErrPktSyncMul
} else {
return nil, ErrPktSync
}
return nil, ErrPktSync
}
mc.sequence++
@ -100,12 +100,6 @@ func (mc *mysqlConn) writePacket(data []byte) error {
data[3] = mc.sequence
// Write packet
if mc.writeTimeout > 0 {
if err := mc.netConn.SetWriteDeadline(time.Now().Add(mc.writeTimeout)); err != nil {
return err
}
}
n, err := mc.netConn.Write(data[:4+size])
if err == nil && n == 4+size {
mc.sequence++
@ -146,7 +140,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
// protocol version [1 byte]
if data[0] < minProtocolVersion {
return nil, fmt.Errorf(
"unsupported protocol version %d. Version %d or higher is required",
"Unsupported MySQL Protocol Version %d. Protocol Version %d or higher is required",
data[0],
minProtocolVersion,
)
@ -225,10 +219,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
clientTransactions |
clientLocalFiles |
clientPluginAuth |
clientMultiResults |
mc.flags&clientLongFlag
if mc.cfg.ClientFoundRows {
if mc.cfg.clientFoundRows {
clientFlags |= clientFoundRows
}
@ -237,17 +230,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
clientFlags |= clientSSL
}
if mc.cfg.MultiStatements {
clientFlags |= clientMultiStatements
}
// User Password
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd))
scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.passwd))
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.user) + 1 + 1 + len(scrambleBuff) + 21 + 1
// To specify a db name
if n := len(mc.cfg.DBName); n > 0 {
if n := len(mc.cfg.dbname); n > 0 {
clientFlags |= clientConnectWithDB
pktLen += n + 1
}
@ -273,14 +262,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
data[11] = 0x00
// Charset [1 byte]
var found bool
data[12], found = collations[mc.cfg.Collation]
if !found {
// Note possibility for false negatives:
// could be triggered although the collation is valid if the
// collations map does not contain entries the server supports.
return errors.New("unknown collation")
}
data[12] = mc.cfg.collation
// SSL Connection Request Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
@ -296,7 +278,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
return err
}
mc.netConn = tlsConn
mc.buf.nc = tlsConn
mc.buf.rd = tlsConn
}
// Filler [23 bytes] (all 0x00)
@ -306,8 +288,8 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
}
// User [null terminated string]
if len(mc.cfg.User) > 0 {
pos += copy(data[pos:], mc.cfg.User)
if len(mc.cfg.user) > 0 {
pos += copy(data[pos:], mc.cfg.user)
}
data[pos] = 0x00
pos++
@ -317,8 +299,8 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
pos += 1 + copy(data[pos+1:], scrambleBuff)
// Databasename [null terminated string]
if len(mc.cfg.DBName) > 0 {
pos += copy(data[pos:], mc.cfg.DBName)
if len(mc.cfg.dbname) > 0 {
pos += copy(data[pos:], mc.cfg.dbname)
data[pos] = 0x00
pos++
}
@ -335,7 +317,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
// User password
scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd))
scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd))
// Calculate the packet length and add a tailing 0
pktLen := len(scrambleBuff) + 1
@ -357,7 +339,7 @@ func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeClearAuthPacket() error {
// Calculate the packet length and add a tailing 0
pktLen := len(mc.cfg.Passwd) + 1
pktLen := len(mc.cfg.passwd) + 1
data := mc.buf.takeSmallBuffer(4 + pktLen)
if data == nil {
// can not take the buffer. Something must be wrong with the connection
@ -366,7 +348,7 @@ func (mc *mysqlConn) writeClearAuthPacket() error {
}
// Add the clear password [null terminated string]
copy(data[4:], mc.cfg.Passwd)
copy(data[4:], mc.cfg.passwd)
data[4+pktLen-1] = 0x00
return mc.writePacket(data)
@ -532,10 +514,6 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
}
}
func readStatus(b []byte) statusFlag {
return statusFlag(b[0]) | statusFlag(b[1])<<8
}
// Ok Packet
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
func (mc *mysqlConn) handleOkPacket(data []byte) error {
@ -550,21 +528,18 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
mc.insertId, _, m = readLengthEncodedInteger(data[1+n:])
// server_status [2 bytes]
mc.status = readStatus(data[1+n+m : 1+n+m+2])
if err := mc.discardResults(); err != nil {
return err
}
mc.status = statusFlag(data[1+n+m]) | statusFlag(data[1+n+m+1])<<8
// warning count [2 bytes]
if !mc.strict {
return nil
} else {
pos := 1 + n + m + 2
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
return mc.getWarnings()
}
return nil
}
pos := 1 + n + m + 2
if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 {
return mc.getWarnings()
}
return nil
}
// Read Packets as Field Packets until EOF-Packet or an Error appears
@ -583,7 +558,7 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
if i == count {
return columns, nil
}
return nil, fmt.Errorf("column count mismatch n:%d len:%d", count, len(columns))
return nil, fmt.Errorf("ColumnsCount mismatch n:%d len:%d", count, len(columns))
}
// Catalog
@ -600,7 +575,7 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
pos += n
// Table [len coded string]
if mc.cfg.ColumnsWithAlias {
if mc.cfg.columnsWithAlias {
tableName, _, n, err := readLengthEncodedString(data[pos:])
if err != nil {
return nil, err
@ -672,11 +647,6 @@ func (rows *textRows) readRow(dest []driver.Value) error {
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
// server_status [2 bytes]
rows.mc.status = readStatus(data[3:])
if err := rows.mc.discardResults(); err != nil {
return err
}
rows.mc = nil
return io.EOF
}
@ -704,7 +674,7 @@ func (rows *textRows) readRow(dest []driver.Value) error {
fieldTypeDate, fieldTypeNewDate:
dest[i], err = parseDateTime(
string(dest[i].([]byte)),
mc.cfg.Loc,
mc.cfg.loc,
)
if err == nil {
continue
@ -734,10 +704,6 @@ func (mc *mysqlConn) readUntilEOF() error {
if err == nil && data[0] != iEOF {
continue
}
if err == nil && data[0] == iEOF && len(data) == 5 {
mc.status = readStatus(data[3:])
}
return err // Err or EOF
}
}
@ -770,13 +736,13 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
// Warning count [16 bit uint]
if !stmt.mc.strict {
return columnCount, nil
} else {
// Check for warnings count > 0, only available in MySQL > 4.1
if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
return columnCount, stmt.mc.getWarnings()
}
return columnCount, nil
}
// Check for warnings count > 0, only available in MySQL > 4.1
if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 {
return columnCount, stmt.mc.getWarnings()
}
return columnCount, nil
}
return 0, err
}
@ -838,7 +804,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if len(args) != stmt.paramCount {
return fmt.Errorf(
"argument count mismatch (got: %d; has: %d)",
"Arguments count mismatch (Got: %d Has: %d)",
len(args),
stmt.paramCount,
)
@ -1015,7 +981,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
if v.IsZero() {
val = []byte("0000-00-00")
} else {
val = []byte(v.In(mc.cfg.Loc).Format(timeFormat))
val = []byte(v.In(mc.cfg.loc).Format(timeFormat))
}
paramValues = appendLengthEncodedInteger(paramValues,
@ -1024,7 +990,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
paramValues = append(paramValues, val...)
default:
return fmt.Errorf("can not convert type: %T", arg)
return fmt.Errorf("Can't convert type: %T", arg)
}
}
@ -1042,28 +1008,6 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
return mc.writePacket(data)
}
func (mc *mysqlConn) discardResults() error {
for mc.status&statusMoreResultsExists != 0 {
resLen, err := mc.readResultSetHeaderPacket()
if err != nil {
return err
}
if resLen > 0 {
// columns
if err := mc.readUntilEOF(); err != nil {
return err
}
// rows
if err := mc.readUntilEOF(); err != nil {
return err
}
} else {
mc.status &^= statusMoreResultsExists
}
}
return nil
}
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
func (rows *binaryRows) readRow(dest []driver.Value) error {
data, err := rows.mc.readPacket()
@ -1073,16 +1017,11 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
// packet indicator [1 byte]
if data[0] != iOK {
rows.mc = nil
// EOF Packet
if data[0] == iEOF && len(data) == 5 {
rows.mc.status = readStatus(data[3:])
if err := rows.mc.discardResults(); err != nil {
return err
}
rows.mc = nil
return io.EOF
}
rows.mc = nil
// Error otherwise
return rows.mc.handleErrorPacket(data)
@ -1149,7 +1088,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
continue
case fieldTypeFloat:
dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
dest[i] = float64(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])))
pos += 4
continue
@ -1162,7 +1101,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar,
fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB,
fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB,
fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON:
fieldTypeVarString, fieldTypeString, fieldTypeGeometry:
var isNull bool
var n int
dest[i], isNull, n, err = readLengthEncodedString(data[pos:])
@ -1199,13 +1138,13 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
dstlen = 8 + 1 + decimals
default:
return fmt.Errorf(
"protocol error, illegal decimals value %d",
"MySQL protocol error, illegal decimals value %d",
rows.columns[i].decimals,
)
}
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
case rows.mc.parseTime:
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.loc)
default:
var dstlen uint8
if rows.columns[i].fieldType == fieldTypeDate {
@ -1218,7 +1157,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
dstlen = 19 + 1 + decimals
default:
return fmt.Errorf(
"protocol error, illegal decimals value %d",
"MySQL protocol error, illegal decimals value %d",
rows.columns[i].decimals,
)
}
@ -1235,7 +1174,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
// Please report if this happens!
default:
return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType)
return fmt.Errorf("Unknown FieldType %d", rows.columns[i].fieldType)
}
}

View File

@ -38,7 +38,7 @@ type emptyRows struct{}
func (rows *mysqlRows) Columns() []string {
columns := make([]string, len(rows.columns))
if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias {
if rows.mc.cfg.columnsWithAlias {
for i := range columns {
if tableName := rows.columns[i].tableName; len(tableName) > 0 {
columns[i] = tableName + "." + rows.columns[i].name
@ -65,12 +65,6 @@ func (rows *mysqlRows) Close() error {
// Remove unread packets from stream
err := mc.readUntilEOF()
if err == nil {
if err = mc.discardResults(); err != nil {
return err
}
}
rows.mc = nil
return err
}

View File

@ -101,9 +101,9 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) {
}
rows := new(binaryRows)
rows.mc = mc
if resLen > 0 {
rows.mc = mc
// Columns
// If not cached, read them and cache them
if stmt.columns == nil {

View File

@ -13,16 +13,28 @@ import (
"crypto/tls"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"net"
"net/url"
"strings"
"time"
)
var (
tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs
errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name")
errInvalidDSNUnsafeCollation = errors.New("Invalid DSN: interpolateParams can be used with ascii, latin1, utf8 and utf8mb4 charset")
)
func init() {
tlsConfigRegister = make(map[string]*tls.Config)
}
// RegisterTLSConfig registers a custom tls.Config to be used with sql.Open.
// Use the key as a value in the DSN where tls=value.
//
@ -48,11 +60,7 @@ var (
//
func RegisterTLSConfig(key string, config *tls.Config) error {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
return fmt.Errorf("key '%s' is reserved", key)
}
if tlsConfigRegister == nil {
tlsConfigRegister = make(map[string]*tls.Config)
return fmt.Errorf("Key '%s' is reserved", key)
}
tlsConfigRegister[key] = config
@ -61,9 +69,234 @@ func RegisterTLSConfig(key string, config *tls.Config) error {
// DeregisterTLSConfig removes the tls.Config associated with key.
func DeregisterTLSConfig(key string) {
if tlsConfigRegister != nil {
delete(tlsConfigRegister, key)
delete(tlsConfigRegister, key)
}
// parseDSN parses the DSN string to a config
func parseDSN(dsn string) (cfg *config, err error) {
// New config with some default values
cfg = &config{
loc: time.UTC,
collation: defaultCollation,
}
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
foundSlash := false
for i := len(dsn) - 1; i >= 0; i-- {
if dsn[i] == '/' {
foundSlash = true
var j, k int
// left part is empty if i <= 0
if i > 0 {
// [username[:password]@][protocol[(address)]]
// Find the last '@' in dsn[:i]
for j = i; j >= 0; j-- {
if dsn[j] == '@' {
// username[:password]
// Find the first ':' in dsn[:j]
for k = 0; k < j; k++ {
if dsn[k] == ':' {
cfg.passwd = dsn[k+1 : j]
break
}
}
cfg.user = dsn[:k]
break
}
}
// [protocol[(address)]]
// Find the first '(' in dsn[j+1:i]
for k = j + 1; k < i; k++ {
if dsn[k] == '(' {
// dsn[i-1] must be == ')' if an address is specified
if dsn[i-1] != ')' {
if strings.ContainsRune(dsn[k+1:i], ')') {
return nil, errInvalidDSNUnescaped
}
return nil, errInvalidDSNAddr
}
cfg.addr = dsn[k+1 : i-1]
break
}
}
cfg.net = dsn[j+1 : k]
}
// dbname[?param1=value1&...&paramN=valueN]
// Find the first '?' in dsn[i+1:]
for j = i + 1; j < len(dsn); j++ {
if dsn[j] == '?' {
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
return
}
break
}
}
cfg.dbname = dsn[i+1 : j]
break
}
}
if !foundSlash && len(dsn) > 0 {
return nil, errInvalidDSNNoSlash
}
if cfg.interpolateParams && unsafeCollations[cfg.collation] {
return nil, errInvalidDSNUnsafeCollation
}
// Set default network if empty
if cfg.net == "" {
cfg.net = "tcp"
}
// Set default address if empty
if cfg.addr == "" {
switch cfg.net {
case "tcp":
cfg.addr = "127.0.0.1:3306"
case "unix":
cfg.addr = "/tmp/mysql.sock"
default:
return nil, errors.New("Default addr for network '" + cfg.net + "' unknown")
}
}
return
}
// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
func parseDSNParams(cfg *config, params string) (err error) {
for _, v := range strings.Split(params, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
continue
}
// cfg params
switch value := param[1]; param[0] {
// Enable client side placeholder substitution
case "interpolateParams":
var isBool bool
cfg.interpolateParams, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Disable INFILE whitelist / enable all files
case "allowAllFiles":
var isBool bool
cfg.allowAllFiles, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Use cleartext authentication mode (MySQL 5.5.10+)
case "allowCleartextPasswords":
var isBool bool
cfg.allowCleartextPasswords, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Use old authentication mode (pre MySQL 4.1)
case "allowOldPasswords":
var isBool bool
cfg.allowOldPasswords, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Switch "rowsAffected" mode
case "clientFoundRows":
var isBool bool
cfg.clientFoundRows, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Collation
case "collation":
collation, ok := collations[value]
if !ok {
// Note possibility for false negatives:
// could be triggered although the collation is valid if the
// collations map does not contain entries the server supports.
err = errors.New("unknown collation")
return
}
cfg.collation = collation
break
case "columnsWithAlias":
var isBool bool
cfg.columnsWithAlias, isBool = readBool(value)
if !isBool {
return fmt.Errorf("Invalid Bool value: %s", value)
}
// Time Location
case "loc":
if value, err = url.QueryUnescape(value); err != nil {
return
}
cfg.loc, err = time.LoadLocation(value)
if err != nil {
return
}
// Dial Timeout
case "timeout":
cfg.timeout, err = time.ParseDuration(value)
if err != nil {
return
}
// TLS-Encryption
case "tls":
boolValue, isBool := readBool(value)
if isBool {
if boolValue {
cfg.tls = &tls.Config{}
}
} else {
if strings.ToLower(value) == "skip-verify" {
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else if tlsConfig, ok := tlsConfigRegister[value]; ok {
if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify {
host, _, err := net.SplitHostPort(cfg.addr)
if err == nil {
tlsConfig.ServerName = host
}
}
cfg.tls = tlsConfig
} else {
return fmt.Errorf("Invalid value / unknown config name: %s", value)
}
}
default:
// lazy init
if cfg.params == nil {
cfg.params = make(map[string]string)
}
if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil {
return
}
}
}
return
}
// Returns the bool value of the input.
@ -260,7 +493,7 @@ func parseDateTime(str string, loc *time.Location) (t time.Time, err error) {
}
t, err = time.Parse(timeFormat[:len(str)], str)
default:
err = fmt.Errorf("invalid time string: %s", str)
err = fmt.Errorf("Invalid Time-String: %s", str)
return
}
@ -309,7 +542,7 @@ func parseBinaryDateTime(num uint64, data []byte, loc *time.Location) (driver.Va
loc,
), nil
}
return nil, fmt.Errorf("invalid DATETIME packet length %d", num)
return nil, fmt.Errorf("Invalid DATETIME-packet length %d", num)
}
// zeroDateTime is used in formatBinaryDateTime to avoid an allocation
@ -344,7 +577,7 @@ func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value
switch len(src) {
case 8, 12:
default:
return nil, fmt.Errorf("invalid TIME packet length %d", len(src))
return nil, fmt.Errorf("Invalid TIME-packet length %d", len(src))
}
// +2 to enable negative time and 100+ hours
dst = make([]byte, 0, length+2)
@ -378,7 +611,7 @@ func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value
if length > 10 {
t += "TIME"
}
return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
return nil, fmt.Errorf("illegal %s-packet length %d", t, len(src))
}
dst = make([]byte, 0, length)
// start with the date
@ -644,7 +877,7 @@ func escapeBytesBackslash(buf, v []byte) []byte {
pos += 2
default:
buf[pos] = c
pos++
pos += 1
}
}
@ -689,7 +922,7 @@ func escapeStringBackslash(buf []byte, v string) []byte {
pos += 2
default:
buf[pos] = c
pos++
pos += 1
}
}

View File

@ -10,12 +10,161 @@ package mysql
import (
"bytes"
"crypto/tls"
"encoding/binary"
"fmt"
"testing"
"time"
)
var testDSNs = []struct {
in string
out string
loc *time.Location
}{
{"username:password@protocol(address)/dbname?param=value", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", "&{user:username passwd:password net:protocol addr:address dbname:dbname params:map[param:value] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:true interpolateParams:false}", time.UTC},
{"user@unix(/path/to/socket)/dbname?charset=utf8", "&{user:user passwd: net:unix addr:/path/to/socket dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", "&{user:user passwd:password net:tcp addr:localhost:5555 dbname:dbname params:map[charset:utf8mb4,utf8] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:password@/dbname?loc=UTC&timeout=30s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci", "&{user:user passwd:password net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:30000000000 collation:224 allowAllFiles:true allowOldPasswords:true allowCleartextPasswords:false clientFoundRows:true columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", "&{user:user passwd:p@ss(word) net:tcp addr:[de:ad:be:ef::ca:fe]:80 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.Local},
{"/dbname", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname:dbname params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"@/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"/", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"", "&{user: passwd: net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"user:p@/ssword@/", "&{user:user passwd:p@/ssword net:tcp addr:127.0.0.1:3306 dbname: params:map[] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
{"unix/?arg=%2Fsome%2Fpath.ext", "&{user: passwd: net:unix addr:/tmp/mysql.sock dbname: params:map[arg:/some/path.ext] loc:%p tls:<nil> timeout:0 collation:33 allowAllFiles:false allowOldPasswords:false allowCleartextPasswords:false clientFoundRows:false columnsWithAlias:false interpolateParams:false}", time.UTC},
}
func TestDSNParser(t *testing.T) {
var cfg *config
var err error
var res string
for i, tst := range testDSNs {
cfg, err = parseDSN(tst.in)
if err != nil {
t.Error(err.Error())
}
// pointer not static
cfg.tls = nil
res = fmt.Sprintf("%+v", cfg)
if res != fmt.Sprintf(tst.out, tst.loc) {
t.Errorf("%d. parseDSN(%q) => %q, want %q", i, tst.in, res, fmt.Sprintf(tst.out, tst.loc))
}
}
}
func TestDSNParserInvalid(t *testing.T) {
var invalidDSNs = []string{
"@net(addr/", // no closing brace
"@tcp(/", // no closing brace
"tcp(/", // no closing brace
"(/", // no closing brace
"net(addr)//", // unescaped
"user:pass@tcp(1.2.3.4:3306)", // no trailing slash
//"/dbname?arg=/some/unescaped/path",
}
for i, tst := range invalidDSNs {
if _, err := parseDSN(tst); err == nil {
t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst)
}
}
}
func TestDSNWithCustomTLS(t *testing.T) {
baseDSN := "user:password@tcp(localhost:5555)/dbname?tls="
tlsCfg := tls.Config{}
RegisterTLSConfig("utils_test", &tlsCfg)
// Custom TLS is missing
tst := baseDSN + "invalid_tls"
cfg, err := parseDSN(tst)
if err == nil {
t.Errorf("Invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg)
}
tst = baseDSN + "utils_test"
// Custom TLS with a server name
name := "foohost"
tlsCfg.ServerName = name
cfg, err = parseDSN(tst)
if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
t.Errorf("Did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
}
// Custom TLS without a server name
name = "localhost"
tlsCfg.ServerName = ""
cfg, err = parseDSN(tst)
if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
t.Errorf("Did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
}
DeregisterTLSConfig("utils_test")
}
func TestDSNUnsafeCollation(t *testing.T) {
_, err := parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true")
if err != errInvalidDSNUnsafeCollation {
t.Error("Expected %v, Got %v", errInvalidDSNUnsafeCollation, err)
}
_, err = parseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false")
if err != nil {
t.Error("Expected %v, Got %v", nil, err)
}
_, err = parseDSN("/dbname?collation=gbk_chinese_ci")
if err != nil {
t.Error("Expected %v, Got %v", nil, err)
}
_, err = parseDSN("/dbname?collation=ascii_bin&interpolateParams=true")
if err != nil {
t.Error("Expected %v, Got %v", nil, err)
}
_, err = parseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true")
if err != nil {
t.Error("Expected %v, Got %v", nil, err)
}
_, err = parseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true")
if err != nil {
t.Error("Expected %v, Got %v", nil, err)
}
_, err = parseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true")
if err != nil {
t.Error("Expected %v, Got %v", nil, err)
}
}
func BenchmarkParseDSN(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
for _, tst := range testDSNs {
if _, err := parseDSN(tst.in); err != nil {
b.Error(err.Error())
}
}
}
}
func TestScanNullTime(t *testing.T) {
var scanTests = []struct {
in interface{}

View File

@ -1,5 +1,6 @@
.DS_Store
*.[568ao]
*.pb.go
*.ao
*.so
*.pyc
@ -12,4 +13,5 @@ core
_obj
_test
_testmain.go
protoc-gen-go/testdata/multi/*.pb.go
compiler/protoc-gen-go
compiler/testdata/extension_test

View File

@ -33,11 +33,13 @@
all: install
install:
go install ./proto ./jsonpb ./ptypes
go install ./proto
go install ./jsonpb
go install ./protoc-gen-go
test:
go test ./proto ./jsonpb ./ptypes
go test ./proto
go test ./jsonpb
make -C protoc-gen-go/testdata test
clean:

View File

@ -140,7 +140,6 @@ To create and play with a Test object from the example package,
test := &example.Test {
Label: proto.String("hello"),
Type: proto.Int32(17),
Reps: []int64{1, 2, 3},
Optionalgroup: &example.Test_OptionalGroup {
RequiredField: proto.String("good bye"),
},
@ -191,9 +190,3 @@ the `plugins` parameter to protoc-gen-go; the usual way is to insert it into
the --go_out argument to protoc:
protoc --go_out=plugins=grpc:. *.proto
## Plugins ##
The `protoc-gen-go/generator` package exposes a plugin interface,
which is used by the gRPC code generation. This interface is not
supported and is subject to incompatible changes without notice.

View File

@ -41,41 +41,37 @@ package jsonpb
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"reflect"
"sort"
"strconv"
"strings"
"time"
"github.com/golang/protobuf/proto"
)
var (
byteArrayType = reflect.TypeOf([]byte{})
)
// Marshaler is a configurable object for converting between
// protocol buffer objects and a JSON representation for them.
// protocol buffer objects and a JSON representation for them
type Marshaler struct {
// Whether to render enum values as integers, as opposed to string values.
EnumsAsInts bool
// Whether to render fields with zero values.
EmitDefaults bool
// A string to indent each level by. The presence of this field will
// also cause a space to appear between the field separator and
// value, and for newlines to be appear between fields and array
// elements.
Indent string
// Whether to use the original (.proto) name for fields.
OrigName bool
}
// Marshal marshals a protocol buffer into JSON.
func (m *Marshaler) Marshal(out io.Writer, pb proto.Message) error {
writer := &errWriter{writer: out}
return m.marshalObject(writer, pb, "", "")
return m.marshalObject(writer, pb, "")
}
// MarshalToString converts a protocol buffer object to JSON string.
@ -94,83 +90,15 @@ func (s int32Slice) Len() int { return len(s) }
func (s int32Slice) Less(i, j int) bool { return s[i] < s[j] }
func (s int32Slice) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
type wkt interface {
XXX_WellKnownType() string
}
// marshalObject writes a struct to the Writer.
func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeURL string) error {
s := reflect.ValueOf(v).Elem()
// Handle well-known types.
if wkt, ok := v.(wkt); ok {
switch wkt.XXX_WellKnownType() {
case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value",
"Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue":
// "Wrappers use the same representation in JSON
// as the wrapped primitive type, ..."
sprop := proto.GetProperties(s.Type())
return m.marshalValue(out, sprop.Prop[0], s.Field(0), indent)
case "Any":
// Any is a bit more involved.
return m.marshalAny(out, v, indent)
case "Duration":
// "Generated output always contains 3, 6, or 9 fractional digits,
// depending on required precision."
s, ns := s.Field(0).Int(), s.Field(1).Int()
d := time.Duration(s)*time.Second + time.Duration(ns)*time.Nanosecond
x := fmt.Sprintf("%.9f", d.Seconds())
x = strings.TrimSuffix(x, "000")
x = strings.TrimSuffix(x, "000")
out.write(`"`)
out.write(x)
out.write(`s"`)
return out.err
case "Struct":
// Let marshalValue handle the `fields` map.
// TODO: pass the correct Properties if needed.
return m.marshalValue(out, &proto.Properties{}, s.Field(0), indent)
case "Timestamp":
// "RFC 3339, where generated output will always be Z-normalized
// and uses 3, 6 or 9 fractional digits."
s, ns := s.Field(0).Int(), s.Field(1).Int()
t := time.Unix(s, ns).UTC()
// time.RFC3339Nano isn't exactly right (we need to get 3/6/9 fractional digits).
x := t.Format("2006-01-02T15:04:05.000000000")
x = strings.TrimSuffix(x, "000")
x = strings.TrimSuffix(x, "000")
out.write(`"`)
out.write(x)
out.write(`Z"`)
return out.err
case "Value":
// Value has a single oneof.
kind := s.Field(0)
if kind.IsNil() {
// "absence of any variant indicates an error"
return errors.New("nil Value")
}
// oneof -> *T -> T -> T.F
x := kind.Elem().Elem().Field(0)
// TODO: pass the correct Properties if needed.
return m.marshalValue(out, &proto.Properties{}, x, indent)
}
}
func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent string) error {
out.write("{")
if m.Indent != "" {
out.write("\n")
}
s := reflect.ValueOf(v).Elem()
firstField := true
if typeURL != "" {
if err := m.marshalTypeURL(out, indent, typeURL); err != nil {
return err
}
firstField = false
}
for i := 0; i < s.NumField(); i++ {
value := s.Field(i)
valueField := s.Type().Field(i)
@ -178,6 +106,8 @@ func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeU
continue
}
// TODO: proto3 objects should have default values omitted.
// IsNil will panic on most value kinds.
switch value.Kind() {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
@ -186,31 +116,6 @@ func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeU
}
}
if !m.EmitDefaults {
switch value.Kind() {
case reflect.Bool:
if !value.Bool() {
continue
}
case reflect.Int32, reflect.Int64:
if value.Int() == 0 {
continue
}
case reflect.Uint32, reflect.Uint64:
if value.Uint() == 0 {
continue
}
case reflect.Float32, reflect.Float64:
if value.Float() == 0 {
continue
}
case reflect.String:
if value.Len() == 0 {
continue
}
}
}
// Oneof fields need special handling.
if valueField.Tag.Get("protobuf_oneof") != "" {
// value is an interface containing &T{real_value}.
@ -218,7 +123,7 @@ func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeU
value = sv.Field(0)
valueField = sv.Type().Field(0)
}
prop := jsonProperties(valueField, m.OrigName)
prop := jsonProperties(valueField)
if !firstField {
m.writeSep(out)
}
@ -229,14 +134,12 @@ func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeU
}
// Handle proto2 extensions.
if ep, ok := v.(proto.Message); ok {
if ep, ok := v.(extendableProto); ok {
extensions := proto.RegisteredExtensions(v)
extensionMap := ep.ExtensionMap()
// Sort extensions for stable output.
ids := make([]int32, 0, len(extensions))
for id, desc := range extensions {
if !proto.HasExtension(ep, desc) {
continue
}
ids := make([]int32, 0, len(extensionMap))
for id := range extensionMap {
ids = append(ids, id)
}
sort.Sort(int32Slice(ids))
@ -253,7 +156,7 @@ func (m *Marshaler) marshalObject(out *errWriter, v proto.Message, indent, typeU
value := reflect.ValueOf(ext)
var prop proto.Properties
prop.Parse(desc.Tag)
prop.JSONName = fmt.Sprintf("[%s]", desc.Name)
prop.OrigName = fmt.Sprintf("[%s]", desc.Name)
if !firstField {
m.writeSep(out)
}
@ -281,70 +184,6 @@ func (m *Marshaler) writeSep(out *errWriter) {
}
}
func (m *Marshaler) marshalAny(out *errWriter, any proto.Message, indent string) error {
// "If the Any contains a value that has a special JSON mapping,
// it will be converted as follows: {"@type": xxx, "value": yyy}.
// Otherwise, the value will be converted into a JSON object,
// and the "@type" field will be inserted to indicate the actual data type."
v := reflect.ValueOf(any).Elem()
turl := v.Field(0).String()
val := v.Field(1).Bytes()
// Only the part of type_url after the last slash is relevant.
mname := turl
if slash := strings.LastIndex(mname, "/"); slash >= 0 {
mname = mname[slash+1:]
}
mt := proto.MessageType(mname)
if mt == nil {
return fmt.Errorf("unknown message type %q", mname)
}
msg := reflect.New(mt.Elem()).Interface().(proto.Message)
if err := proto.Unmarshal(val, msg); err != nil {
return err
}
if _, ok := msg.(wkt); ok {
out.write("{")
if m.Indent != "" {
out.write("\n")
}
if err := m.marshalTypeURL(out, indent, turl); err != nil {
return err
}
m.writeSep(out)
out.write(`"value":`)
if err := m.marshalObject(out, msg, indent, ""); err != nil {
return err
}
if m.Indent != "" {
out.write("\n")
out.write(indent)
}
out.write("}")
return out.err
}
return m.marshalObject(out, msg, indent, turl)
}
func (m *Marshaler) marshalTypeURL(out *errWriter, indent, typeURL string) error {
if m.Indent != "" {
out.write(indent)
out.write(m.Indent)
}
out.write(`"@type":`)
if m.Indent != "" {
out.write(" ")
}
b, err := json.Marshal(typeURL)
if err != nil {
return err
}
out.write(string(b))
return out.err
}
// marshalField writes field description and value to the Writer.
func (m *Marshaler) marshalField(out *errWriter, prop *proto.Properties, v reflect.Value, indent string) error {
if m.Indent != "" {
@ -352,7 +191,7 @@ func (m *Marshaler) marshalField(out *errWriter, prop *proto.Properties, v refle
out.write(m.Indent)
}
out.write(`"`)
out.write(prop.JSONName)
out.write(prop.OrigName)
out.write(`":`)
if m.Indent != "" {
out.write(" ")
@ -370,7 +209,7 @@ func (m *Marshaler) marshalValue(out *errWriter, prop *proto.Properties, v refle
v = reflect.Indirect(v)
// Handle repeated elements.
if v.Kind() == reflect.Slice && v.Type().Elem().Kind() != reflect.Uint8 {
if v.Type() != byteArrayType && v.Kind() == reflect.Slice {
out.write("[")
comma := ""
for i := 0; i < v.Len(); i++ {
@ -394,19 +233,6 @@ func (m *Marshaler) marshalValue(out *errWriter, prop *proto.Properties, v refle
return out.err
}
// Handle well-known types.
// Most are handled up in marshalObject (because 99% are messages).
type wkt interface {
XXX_WellKnownType() string
}
if wkt, ok := v.Interface().(wkt); ok {
switch wkt.XXX_WellKnownType() {
case "NullValue":
out.write("null")
return out.err
}
}
// Handle enumerations.
if !m.EnumsAsInts && prop.Enum != "" {
// Unknown enum values will are stringified by the proto library as their
@ -432,7 +258,7 @@ func (m *Marshaler) marshalValue(out *errWriter, prop *proto.Properties, v refle
// Handle nested messages.
if v.Kind() == reflect.Struct {
return m.marshalObject(out, v.Addr().Interface().(proto.Message), indent+m.Indent, "")
return m.marshalObject(out, v.Addr().Interface().(proto.Message), indent+m.Indent)
}
// Handle maps.
@ -502,23 +328,15 @@ func (m *Marshaler) marshalValue(out *errWriter, prop *proto.Properties, v refle
return out.err
}
// UnmarshalNext unmarshals the next protocol buffer from a JSON object stream.
// This function is lenient and will decode any options permutations of the
// related Marshaler.
func UnmarshalNext(dec *json.Decoder, pb proto.Message) error {
inputValue := json.RawMessage{}
if err := dec.Decode(&inputValue); err != nil {
return err
}
return unmarshalValue(reflect.ValueOf(pb).Elem(), inputValue, nil)
}
// Unmarshal unmarshals a JSON object stream into a protocol
// buffer. This function is lenient and will decode any options
// permutations of the related Marshaler.
func Unmarshal(r io.Reader, pb proto.Message) error {
dec := json.NewDecoder(r)
return UnmarshalNext(dec, pb)
inputValue := json.RawMessage{}
if err := json.NewDecoder(r).Decode(&inputValue); err != nil {
return err
}
return unmarshalValue(reflect.ValueOf(pb).Elem(), inputValue)
}
// UnmarshalString will populate the fields of a protocol buffer based
@ -529,83 +347,13 @@ func UnmarshalString(str string, pb proto.Message) error {
}
// unmarshalValue converts/copies a value into the target.
// prop may be nil.
func unmarshalValue(target reflect.Value, inputValue json.RawMessage, prop *proto.Properties) error {
func unmarshalValue(target reflect.Value, inputValue json.RawMessage) error {
targetType := target.Type()
// Allocate memory for pointer fields.
if targetType.Kind() == reflect.Ptr {
target.Set(reflect.New(targetType.Elem()))
return unmarshalValue(target.Elem(), inputValue, prop)
}
// Handle well-known types.
type wkt interface {
XXX_WellKnownType() string
}
if wkt, ok := target.Addr().Interface().(wkt); ok {
switch wkt.XXX_WellKnownType() {
case "DoubleValue", "FloatValue", "Int64Value", "UInt64Value",
"Int32Value", "UInt32Value", "BoolValue", "StringValue", "BytesValue":
// "Wrappers use the same representation in JSON
// as the wrapped primitive type, except that null is allowed."
// encoding/json will turn JSON `null` into Go `nil`,
// so we don't have to do any extra work.
return unmarshalValue(target.Field(0), inputValue, prop)
case "Any":
return fmt.Errorf("unmarshaling Any not supported yet")
case "Duration":
unq, err := strconv.Unquote(string(inputValue))
if err != nil {
return err
}
d, err := time.ParseDuration(unq)
if err != nil {
return fmt.Errorf("bad Duration: %v", err)
}
ns := d.Nanoseconds()
s := ns / 1e9
ns %= 1e9
target.Field(0).SetInt(s)
target.Field(1).SetInt(ns)
return nil
case "Timestamp":
unq, err := strconv.Unquote(string(inputValue))
if err != nil {
return err
}
t, err := time.Parse(time.RFC3339Nano, unq)
if err != nil {
return fmt.Errorf("bad Timestamp: %v", err)
}
ns := t.UnixNano()
s := ns / 1e9
ns %= 1e9
target.Field(0).SetInt(s)
target.Field(1).SetInt(ns)
return nil
}
}
// Handle enums, which have an underlying type of int32,
// and may appear as strings.
// The case of an enum appearing as a number is handled
// at the bottom of this function.
if inputValue[0] == '"' && prop != nil && prop.Enum != "" {
vmap := proto.EnumValueMap(prop.Enum)
// Don't need to do unquoting; valid enum names
// are from a limited character set.
s := inputValue[1 : len(inputValue)-1]
n, ok := vmap[string(s)]
if !ok {
return fmt.Errorf("unknown value %q for enum %s", s, prop.Enum)
}
if target.Kind() == reflect.Ptr { // proto2
target.Set(reflect.New(targetType.Elem()))
target = target.Elem()
}
target.SetInt(int64(n))
return nil
return unmarshalValue(target.Elem(), inputValue)
}
// Handle nested messages.
@ -615,56 +363,56 @@ func unmarshalValue(target reflect.Value, inputValue json.RawMessage, prop *prot
return err
}
consumeField := func(prop *proto.Properties) (json.RawMessage, bool) {
// Be liberal in what names we accept; both orig_name and camelName are okay.
fieldNames := acceptedJSONFieldNames(prop)
vOrig, okOrig := jsonFields[fieldNames.orig]
vCamel, okCamel := jsonFields[fieldNames.camel]
if !okOrig && !okCamel {
return nil, false
}
// If, for some reason, both are present in the data, favour the camelName.
var raw json.RawMessage
if okOrig {
raw = vOrig
delete(jsonFields, fieldNames.orig)
}
if okCamel {
raw = vCamel
delete(jsonFields, fieldNames.camel)
}
return raw, true
}
sprops := proto.GetProperties(targetType)
for i := 0; i < target.NumField(); i++ {
ft := target.Type().Field(i)
if strings.HasPrefix(ft.Name, "XXX_") {
continue
}
fieldName := jsonProperties(ft).OrigName
valueForField, ok := consumeField(sprops.Prop[i])
valueForField, ok := jsonFields[fieldName]
if !ok {
continue
}
delete(jsonFields, fieldName)
if err := unmarshalValue(target.Field(i), valueForField, sprops.Prop[i]); err != nil {
// Handle enums, which have an underlying type of int32,
// and may appear as strings. We do this while handling
// the struct so we have access to the enum info.
// The case of an enum appearing as a number is handled
// by the recursive call to unmarshalValue.
if enum := sprops.Prop[i].Enum; valueForField[0] == '"' && enum != "" {
vmap := proto.EnumValueMap(enum)
// Don't need to do unquoting; valid enum names
// are from a limited character set.
s := valueForField[1 : len(valueForField)-1]
n, ok := vmap[string(s)]
if !ok {
return fmt.Errorf("unknown value %q for enum %s", s, enum)
}
f := target.Field(i)
if f.Kind() == reflect.Ptr { // proto2
f.Set(reflect.New(f.Type().Elem()))
f = f.Elem()
}
f.SetInt(int64(n))
continue
}
if err := unmarshalValue(target.Field(i), valueForField); err != nil {
return err
}
}
// Check for any oneof fields.
if len(jsonFields) > 0 {
for _, oop := range sprops.OneofTypes {
raw, ok := consumeField(oop.Prop)
if !ok {
continue
}
for fname, raw := range jsonFields {
if oop, ok := sprops.OneofTypes[fname]; ok {
nv := reflect.New(oop.Type.Elem())
target.Field(oop.Field).Set(nv)
if err := unmarshalValue(nv.Elem().Field(0), raw, oop.Prop); err != nil {
if err := unmarshalValue(nv.Elem().Field(0), raw); err != nil {
return err
}
delete(jsonFields, fname)
}
}
if len(jsonFields) > 0 {
@ -680,7 +428,7 @@ func unmarshalValue(target reflect.Value, inputValue json.RawMessage, prop *prot
}
// Handle arrays (which aren't encoded bytes)
if targetType.Kind() == reflect.Slice && targetType.Elem().Kind() != reflect.Uint8 {
if targetType != byteArrayType && targetType.Kind() == reflect.Slice {
var slc []json.RawMessage
if err := json.Unmarshal(inputValue, &slc); err != nil {
return err
@ -688,7 +436,7 @@ func unmarshalValue(target reflect.Value, inputValue json.RawMessage, prop *prot
len := len(slc)
target.Set(reflect.MakeSlice(targetType, len, len))
for i := 0; i < len; i++ {
if err := unmarshalValue(target.Index(i), slc[i], prop); err != nil {
if err := unmarshalValue(target.Index(i), slc[i]); err != nil {
return err
}
}
@ -702,13 +450,6 @@ func unmarshalValue(target reflect.Value, inputValue json.RawMessage, prop *prot
return err
}
target.Set(reflect.MakeMap(targetType))
var keyprop, valprop *proto.Properties
if prop != nil {
// These could still be nil if the protobuf metadata is broken somehow.
// TODO: This won't work because the fields are unexported.
// We should probably just reparse them.
//keyprop, valprop = prop.mkeyprop, prop.mvalprop
}
for ks, raw := range mp {
// Unmarshal map key. The core json library already decoded the key into a
// string, so we handle that specially. Other types were quoted post-serialization.
@ -717,14 +458,14 @@ func unmarshalValue(target reflect.Value, inputValue json.RawMessage, prop *prot
k = reflect.ValueOf(ks)
} else {
k = reflect.New(targetType.Key()).Elem()
if err := unmarshalValue(k, json.RawMessage(ks), keyprop); err != nil {
if err := unmarshalValue(k, json.RawMessage(ks)); err != nil {
return err
}
}
// Unmarshal map value.
v := reflect.New(targetType.Elem()).Elem()
if err := unmarshalValue(v, raw, valprop); err != nil {
if err := unmarshalValue(v, raw); err != nil {
return err
}
target.SetMapIndex(k, v)
@ -743,26 +484,18 @@ func unmarshalValue(target reflect.Value, inputValue json.RawMessage, prop *prot
return json.Unmarshal(inputValue, target.Addr().Interface())
}
// jsonProperties returns parsed proto.Properties for the field and corrects JSONName attribute.
func jsonProperties(f reflect.StructField, origName bool) *proto.Properties {
// jsonProperties returns parsed proto.Properties for the field.
func jsonProperties(f reflect.StructField) *proto.Properties {
var prop proto.Properties
prop.Init(f.Type, f.Name, f.Tag.Get("protobuf"), &f)
if origName || prop.JSONName == "" {
prop.JSONName = prop.OrigName
}
return &prop
}
type fieldNames struct {
orig, camel string
}
func acceptedJSONFieldNames(prop *proto.Properties) fieldNames {
opts := fieldNames{orig: prop.OrigName, camel: prop.OrigName}
if prop.JSONName != "" {
opts.camel = prop.JSONName
}
return opts
// extendableProto is an interface implemented by any protocol buffer that may be extended.
type extendableProto interface {
proto.Message
ExtensionRangeArray() []proto.ExtensionRange
ExtensionMap() map[int32]proto.Extension
}
// Writer wrapper inspired by https://blog.golang.org/errors-are-values

View File

@ -32,21 +32,12 @@
package jsonpb
import (
"bytes"
"encoding/json"
"io"
"reflect"
"testing"
"github.com/golang/protobuf/proto"
pb "github.com/golang/protobuf/jsonpb/jsonpb_test_proto"
"github.com/golang/protobuf/proto"
proto3pb "github.com/golang/protobuf/proto/proto3_proto"
anypb "github.com/golang/protobuf/ptypes/any"
durpb "github.com/golang/protobuf/ptypes/duration"
stpb "github.com/golang/protobuf/ptypes/struct"
tspb "github.com/golang/protobuf/ptypes/timestamp"
wpb "github.com/golang/protobuf/ptypes/wrappers"
)
var (
@ -71,31 +62,31 @@ var (
}
simpleObjectJSON = `{` +
`"oBool":true,` +
`"oInt32":-32,` +
`"oInt64":"-6400000000",` +
`"oUint32":32,` +
`"oUint64":"6400000000",` +
`"oSint32":-13,` +
`"oSint64":"-2600000000",` +
`"oFloat":3.14,` +
`"oDouble":6.02214179e+23,` +
`"oString":"hello \"there\"",` +
`"oBytes":"YmVlcCBib29w"` +
`"o_bool":true,` +
`"o_int32":-32,` +
`"o_int64":"-6400000000",` +
`"o_uint32":32,` +
`"o_uint64":"6400000000",` +
`"o_sint32":-13,` +
`"o_sint64":"-2600000000",` +
`"o_float":3.14,` +
`"o_double":6.02214179e+23,` +
`"o_string":"hello \"there\"",` +
`"o_bytes":"YmVlcCBib29w"` +
`}`
simpleObjectPrettyJSON = `{
"oBool": true,
"oInt32": -32,
"oInt64": "-6400000000",
"oUint32": 32,
"oUint64": "6400000000",
"oSint32": -13,
"oSint64": "-2600000000",
"oFloat": 3.14,
"oDouble": 6.02214179e+23,
"oString": "hello \"there\"",
"oBytes": "YmVlcCBib29w"
"o_bool": true,
"o_int32": -32,
"o_int64": "-6400000000",
"o_uint32": 32,
"o_uint64": "6400000000",
"o_sint32": -13,
"o_sint64": "-2600000000",
"o_float": 3.14,
"o_double": 6.02214179e+23,
"o_string": "hello \"there\"",
"o_bytes": "YmVlcCBib29w"
}`
repeatsObject = &pb.Repeats{
@ -113,65 +104,65 @@ var (
}
repeatsObjectJSON = `{` +
`"rBool":[true,false,true],` +
`"rInt32":[-3,-4,-5],` +
`"rInt64":["-123456789","-987654321"],` +
`"rUint32":[1,2,3],` +
`"rUint64":["6789012345","3456789012"],` +
`"rSint32":[-1,-2,-3],` +
`"rSint64":["-6789012345","-3456789012"],` +
`"rFloat":[3.14,6.28],` +
`"rDouble":[2.99792458e+08,6.62606957e-34],` +
`"rString":["happy","days"],` +
`"rBytes":["c2tpdHRsZXM=","bSZtJ3M="]` +
`"r_bool":[true,false,true],` +
`"r_int32":[-3,-4,-5],` +
`"r_int64":["-123456789","-987654321"],` +
`"r_uint32":[1,2,3],` +
`"r_uint64":["6789012345","3456789012"],` +
`"r_sint32":[-1,-2,-3],` +
`"r_sint64":["-6789012345","-3456789012"],` +
`"r_float":[3.14,6.28],` +
`"r_double":[2.99792458e+08,6.62606957e-34],` +
`"r_string":["happy","days"],` +
`"r_bytes":["c2tpdHRsZXM=","bSZtJ3M="]` +
`}`
repeatsObjectPrettyJSON = `{
"rBool": [
"r_bool": [
true,
false,
true
],
"rInt32": [
"r_int32": [
-3,
-4,
-5
],
"rInt64": [
"r_int64": [
"-123456789",
"-987654321"
],
"rUint32": [
"r_uint32": [
1,
2,
3
],
"rUint64": [
"r_uint64": [
"6789012345",
"3456789012"
],
"rSint32": [
"r_sint32": [
-1,
-2,
-3
],
"rSint64": [
"r_sint64": [
"-6789012345",
"-3456789012"
],
"rFloat": [
"r_float": [
3.14,
6.28
],
"rDouble": [
"r_double": [
2.99792458e+08,
6.62606957e-34
],
"rString": [
"r_string": [
"happy",
"days"
],
"rBytes": [
"r_bytes": [
"c2tpdHRsZXM=",
"bSZtJ3M="
]
@ -191,46 +182,46 @@ var (
}
complexObjectJSON = `{"color":"GREEN",` +
`"rColor":["RED","GREEN","BLUE"],` +
`"simple":{"oInt32":-32},` +
`"rSimple":[{"oInt32":-32},{"oInt64":"25"}],` +
`"repeats":{"rString":["roses","red"]},` +
`"rRepeats":[{"rString":["roses","red"]},{"rString":["violets","blue"]}]` +
`"r_color":["RED","GREEN","BLUE"],` +
`"simple":{"o_int32":-32},` +
`"r_simple":[{"o_int32":-32},{"o_int64":"25"}],` +
`"repeats":{"r_string":["roses","red"]},` +
`"r_repeats":[{"r_string":["roses","red"]},{"r_string":["violets","blue"]}]` +
`}`
complexObjectPrettyJSON = `{
"color": "GREEN",
"rColor": [
"r_color": [
"RED",
"GREEN",
"BLUE"
],
"simple": {
"oInt32": -32
"o_int32": -32
},
"rSimple": [
"r_simple": [
{
"oInt32": -32
"o_int32": -32
},
{
"oInt64": "25"
"o_int64": "25"
}
],
"repeats": {
"rString": [
"r_string": [
"roses",
"red"
]
},
"rRepeats": [
"r_repeats": [
{
"rString": [
"r_string": [
"roses",
"red"
]
},
{
"rString": [
"r_string": [
"violets",
"blue"
]
@ -244,7 +235,7 @@ var (
colorListPrettyJSON = `{
"color": 1000,
"rColor": [
"r_color": [
"RED"
]
}`
@ -300,20 +291,7 @@ var marshalingTests = []struct {
&pb.Widget{Color: pb.Widget_BLUE.Enum()}, colorPrettyJSON},
{"unknown enum value object", marshalerAllOptions,
&pb.Widget{Color: pb.Widget_Color(1000).Enum(), RColor: []pb.Widget_Color{pb.Widget_RED}}, colorListPrettyJSON},
{"repeated proto3 enum", Marshaler{},
&proto3pb.Message{RFunny: []proto3pb.Message_Humour{
proto3pb.Message_PUNS,
proto3pb.Message_SLAPSTICK,
}},
`{"rFunny":["PUNS","SLAPSTICK"]}`},
{"repeated proto3 enum as int", Marshaler{EnumsAsInts: true},
&proto3pb.Message{RFunny: []proto3pb.Message_Humour{
proto3pb.Message_PUNS,
proto3pb.Message_SLAPSTICK,
}},
`{"rFunny":[1,2]}`},
{"empty value", marshaler, &pb.Simple3{}, `{}`},
{"empty value emitted", Marshaler{EmitDefaults: true}, &pb.Simple3{}, `{"dub":0}`},
{"proto3 object with empty value", marshaler, &pb.Simple3{}, `{"dub":0}`},
{"map<int64, int32>", marshaler, &pb.Mappy{Nummy: map[int64]int32{1: 2, 3: 4}}, `{"nummy":{"1":2,"3":4}}`},
{"map<int64, int32>", marshalerAllOptions, &pb.Mappy{Nummy: map[int64]int32{1: 2, 3: 4}}, nummyPrettyJSON},
{"map<string, string>", marshaler,
@ -326,53 +304,14 @@ var marshalingTests = []struct {
{"map<int64, string>", marshaler, &pb.Mappy{Buggy: map[int64]string{1234: "yup"}},
`{"buggy":{"1234":"yup"}}`},
{"map<bool, bool>", marshaler, &pb.Mappy{Booly: map[bool]bool{false: true}}, `{"booly":{"false":true}}`},
// TODO: This is broken.
//{"map<string, enum>", marshaler, &pb.Mappy{Enumy: map[string]pb.Numeral{"XIV": pb.Numeral_ROMAN}}, `{"enumy":{"XIV":"ROMAN"}`},
{"map<string, enum as int>", Marshaler{EnumsAsInts: true}, &pb.Mappy{Enumy: map[string]pb.Numeral{"XIV": pb.Numeral_ROMAN}}, `{"enumy":{"XIV":2}}`},
{"proto2 map<int64, string>", marshaler, &pb.Maps{MInt64Str: map[int64]string{213: "cat"}},
`{"mInt64Str":{"213":"cat"}}`},
`{"m_int64_str":{"213":"cat"}}`},
{"proto2 map<bool, Object>", marshaler,
&pb.Maps{MBoolSimple: map[bool]*pb.Simple{true: &pb.Simple{OInt32: proto.Int32(1)}}},
`{"mBoolSimple":{"true":{"oInt32":1}}}`},
`{"m_bool_simple":{"true":{"o_int32":1}}}`},
{"oneof, not set", marshaler, &pb.MsgWithOneof{}, `{}`},
{"oneof, set", marshaler, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Title{"Grand Poobah"}}, `{"title":"Grand Poobah"}`},
{"force orig_name", Marshaler{OrigName: true}, &pb.Simple{OInt32: proto.Int32(4)},
`{"o_int32":4}`},
{"proto2 extension", marshaler, realNumber, realNumberJSON},
{"Any with message", marshaler, &pb.KnownTypes{An: &anypb.Any{
TypeUrl: "something.example.com/jsonpb.Simple",
Value: []byte{
// &pb.Simple{OBool:true}
1 << 3, 1,
},
}}, `{"an":{"@type":"something.example.com/jsonpb.Simple","oBool":true}}`},
{"Any with WKT", marshaler, &pb.KnownTypes{An: &anypb.Any{
TypeUrl: "type.googleapis.com/google.protobuf.Duration",
Value: []byte{
// &durpb.Duration{Seconds: 1, Nanos: 212000000 }
1 << 3, 1, // seconds
2 << 3, 0x80, 0xba, 0x8b, 0x65, // nanos
},
}}, `{"an":{"@type":"type.googleapis.com/google.protobuf.Duration","value":"1.212s"}}`},
{"Duration", marshaler, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 3}}, `{"dur":"3.000s"}`},
{"Struct", marshaler, &pb.KnownTypes{St: &stpb.Struct{
Fields: map[string]*stpb.Value{
"one": &stpb.Value{Kind: &stpb.Value_StringValue{"loneliest number"}},
"two": &stpb.Value{Kind: &stpb.Value_NullValue{stpb.NullValue_NULL_VALUE}},
},
}}, `{"st":{"one":"loneliest number","two":null}}`},
{"Timestamp", marshaler, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 14e8, Nanos: 21e6}}, `{"ts":"2014-05-13T16:53:20.021Z"}`},
{"DoubleValue", marshaler, &pb.KnownTypes{Dbl: &wpb.DoubleValue{Value: 1.2}}, `{"dbl":1.2}`},
{"FloatValue", marshaler, &pb.KnownTypes{Flt: &wpb.FloatValue{Value: 1.2}}, `{"flt":1.2}`},
{"Int64Value", marshaler, &pb.KnownTypes{I64: &wpb.Int64Value{Value: -3}}, `{"i64":"-3"}`},
{"UInt64Value", marshaler, &pb.KnownTypes{U64: &wpb.UInt64Value{Value: 3}}, `{"u64":"3"}`},
{"Int32Value", marshaler, &pb.KnownTypes{I32: &wpb.Int32Value{Value: -4}}, `{"i32":-4}`},
{"UInt32Value", marshaler, &pb.KnownTypes{U32: &wpb.UInt32Value{Value: 4}}, `{"u32":4}`},
{"BoolValue", marshaler, &pb.KnownTypes{Bool: &wpb.BoolValue{Value: true}}, `{"bool":true}`},
{"StringValue", marshaler, &pb.KnownTypes{Str: &wpb.StringValue{Value: "plush"}}, `{"str":"plush"}`},
{"BytesValue", marshaler, &pb.KnownTypes{Bytes: &wpb.BytesValue{Value: []byte("wow")}}, `{"bytes":"d293"}`},
}
func TestMarshaling(t *testing.T) {
@ -404,49 +343,12 @@ var unmarshalingTests = []struct {
{"unknown enum value object",
"{\n \"color\": 1000,\n \"r_color\": [\n \"RED\"\n ]\n}",
&pb.Widget{Color: pb.Widget_Color(1000).Enum(), RColor: []pb.Widget_Color{pb.Widget_RED}}},
{"repeated proto3 enum", `{"rFunny":["PUNS","SLAPSTICK"]}`,
&proto3pb.Message{RFunny: []proto3pb.Message_Humour{
proto3pb.Message_PUNS,
proto3pb.Message_SLAPSTICK,
}}},
{"repeated proto3 enum as int", `{"rFunny":[1,2]}`,
&proto3pb.Message{RFunny: []proto3pb.Message_Humour{
proto3pb.Message_PUNS,
proto3pb.Message_SLAPSTICK,
}}},
{"repeated proto3 enum as mix of strings and ints", `{"rFunny":["PUNS",2]}`,
&proto3pb.Message{RFunny: []proto3pb.Message_Humour{
proto3pb.Message_PUNS,
proto3pb.Message_SLAPSTICK,
}}},
{"unquoted int64 object", `{"oInt64":-314}`, &pb.Simple{OInt64: proto.Int64(-314)}},
{"unquoted uint64 object", `{"oUint64":123}`, &pb.Simple{OUint64: proto.Uint64(123)}},
{"unquoted int64 object", `{"o_int64":-314}`, &pb.Simple{OInt64: proto.Int64(-314)}},
{"unquoted uint64 object", `{"o_uint64":123}`, &pb.Simple{OUint64: proto.Uint64(123)}},
{"map<int64, int32>", `{"nummy":{"1":2,"3":4}}`, &pb.Mappy{Nummy: map[int64]int32{1: 2, 3: 4}}},
{"map<string, string>", `{"strry":{"\"one\"":"two","three":"four"}}`, &pb.Mappy{Strry: map[string]string{`"one"`: "two", "three": "four"}}},
{"map<int32, Object>", `{"objjy":{"1":{"dub":1}}}`, &pb.Mappy{Objjy: map[int32]*pb.Simple3{1: &pb.Simple3{Dub: 1}}}},
// TODO: This is broken.
//{"map<string, enum>", `{"enumy":{"XIV":"ROMAN"}`, &pb.Mappy{Enumy: map[string]pb.Numeral{"XIV": pb.Numeral_ROMAN}}},
{"map<string, enum as int>", `{"enumy":{"XIV":2}}`, &pb.Mappy{Enumy: map[string]pb.Numeral{"XIV": pb.Numeral_ROMAN}}},
{"oneof", `{"salary":31000}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Salary{31000}}},
{"oneof spec name", `{"country":"Australia"}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Country{"Australia"}}},
{"oneof orig_name", `{"Country":"Australia"}`, &pb.MsgWithOneof{Union: &pb.MsgWithOneof_Country{"Australia"}}},
{"orig_name input", `{"o_bool":true}`, &pb.Simple{OBool: proto.Bool(true)}},
{"camelName input", `{"oBool":true}`, &pb.Simple{OBool: proto.Bool(true)}},
{"Duration", `{"dur":"3.000s"}`, &pb.KnownTypes{Dur: &durpb.Duration{Seconds: 3}}},
{"Timestamp", `{"ts":"2014-05-13T16:53:20.021Z"}`, &pb.KnownTypes{Ts: &tspb.Timestamp{Seconds: 14e8, Nanos: 21e6}}},
{"DoubleValue", `{"dbl":1.2}`, &pb.KnownTypes{Dbl: &wpb.DoubleValue{Value: 1.2}}},
{"FloatValue", `{"flt":1.2}`, &pb.KnownTypes{Flt: &wpb.FloatValue{Value: 1.2}}},
{"Int64Value", `{"i64":"-3"}`, &pb.KnownTypes{I64: &wpb.Int64Value{Value: -3}}},
{"UInt64Value", `{"u64":"3"}`, &pb.KnownTypes{U64: &wpb.UInt64Value{Value: 3}}},
{"Int32Value", `{"i32":-4}`, &pb.KnownTypes{I32: &wpb.Int32Value{Value: -4}}},
{"UInt32Value", `{"u32":4}`, &pb.KnownTypes{U32: &wpb.UInt32Value{Value: 4}}},
{"BoolValue", `{"bool":true}`, &pb.KnownTypes{Bool: &wpb.BoolValue{Value: true}}},
{"StringValue", `{"str":"plush"}`, &pb.KnownTypes{Str: &wpb.StringValue{Value: "plush"}}},
{"BytesValue", `{"bytes":"d293"}`, &pb.KnownTypes{Bytes: &wpb.BytesValue{Value: []byte("wow")}}},
// `null` is also a permissible value. Let's just test one.
{"null DoubleValue", `{"dbl":null}`, &pb.KnownTypes{Dbl: &wpb.DoubleValue{}}},
}
func TestUnmarshaling(t *testing.T) {
@ -456,7 +358,7 @@ func TestUnmarshaling(t *testing.T) {
err := UnmarshalString(tt.json, p)
if err != nil {
t.Errorf("%s: %v", tt.desc, err)
t.Error(err)
continue
}
@ -469,42 +371,6 @@ func TestUnmarshaling(t *testing.T) {
}
}
func TestUnmarshalNext(t *testing.T) {
// We only need to check against a few, not all of them.
tests := unmarshalingTests[:5]
// Create a buffer with many concatenated JSON objects.
var b bytes.Buffer
for _, tt := range tests {
b.WriteString(tt.json)
}
dec := json.NewDecoder(&b)
for _, tt := range tests {
// Make a new instance of the type of our expected object.
p := reflect.New(reflect.TypeOf(tt.pb).Elem()).Interface().(proto.Message)
err := UnmarshalNext(dec, p)
if err != nil {
t.Errorf("%s: %v", tt.desc, err)
continue
}
// For easier diffs, compare text strings of the protos.
exp := proto.MarshalTextString(tt.pb)
act := proto.MarshalTextString(p)
if string(exp) != string(act) {
t.Errorf("%s: got [%s] want [%s]", tt.desc, act, exp)
}
}
p := &pb.Simple{}
err := UnmarshalNext(dec, p)
if err != io.EOF {
t.Errorf("eof: got %v, expected io.EOF", err)
}
}
var unmarshalingShouldError = []struct {
desc string
in string

View File

@ -30,4 +30,4 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
regenerate:
protoc --go_out=Mgoogle/protobuf/any.proto=github.com/golang/protobuf/ptypes/any,Mgoogle/protobuf/duration.proto=github.com/golang/protobuf/ptypes/duration,Mgoogle/protobuf/struct.proto=github.com/golang/protobuf/ptypes/struct,Mgoogle/protobuf/timestamp.proto=github.com/golang/protobuf/ptypes/timestamp,Mgoogle/protobuf/wrappers.proto=github.com/golang/protobuf/ptypes/wrappers:. *.proto
protoc --go_out=. *.proto

View File

@ -19,7 +19,6 @@ It has these top-level messages:
MsgWithOneof
Real
Complex
KnownTypes
*/
package jsonpb
@ -36,30 +35,6 @@ var _ = math.Inf
// is compatible with the proto package it is being compiled against.
const _ = proto.ProtoPackageIsVersion1
type Numeral int32
const (
Numeral_UNKNOWN Numeral = 0
Numeral_ARABIC Numeral = 1
Numeral_ROMAN Numeral = 2
)
var Numeral_name = map[int32]string{
0: "UNKNOWN",
1: "ARABIC",
2: "ROMAN",
}
var Numeral_value = map[string]int32{
"UNKNOWN": 0,
"ARABIC": 1,
"ROMAN": 2,
}
func (x Numeral) String() string {
return proto.EnumName(Numeral_name, int32(x))
}
func (Numeral) EnumDescriptor() ([]byte, []int) { return fileDescriptor0, []int{0} }
type Simple3 struct {
Dub float64 `protobuf:"fixed64,1,opt,name=dub" json:"dub,omitempty"`
}
@ -75,7 +50,6 @@ type Mappy struct {
Objjy map[int32]*Simple3 `protobuf:"bytes,3,rep,name=objjy" json:"objjy,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
Buggy map[int64]string `protobuf:"bytes,4,rep,name=buggy" json:"buggy,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
Booly map[bool]bool `protobuf:"bytes,5,rep,name=booly" json:"booly,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"varint,2,opt,name=value"`
Enumy map[string]Numeral `protobuf:"bytes,6,rep,name=enumy" json:"enumy,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"varint,2,opt,name=value,enum=jsonpb.Numeral"`
}
func (m *Mappy) Reset() { *m = Mappy{} }
@ -118,42 +92,28 @@ func (m *Mappy) GetBooly() map[bool]bool {
return nil
}
func (m *Mappy) GetEnumy() map[string]Numeral {
if m != nil {
return m.Enumy
}
return nil
}
func init() {
proto.RegisterType((*Simple3)(nil), "jsonpb.Simple3")
proto.RegisterType((*Mappy)(nil), "jsonpb.Mappy")
proto.RegisterEnum("jsonpb.Numeral", Numeral_name, Numeral_value)
}
var fileDescriptor0 = []byte{
// 357 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0x84, 0x93, 0x4d, 0x4b, 0xc3, 0x40,
0x10, 0x86, 0x4d, 0xe2, 0xa6, 0xcd, 0x14, 0x34, 0x2c, 0x82, 0x8b, 0x5e, 0x4a, 0x41, 0x28, 0x82,
0x39, 0xb4, 0x97, 0xe2, 0xad, 0x95, 0x1e, 0x8a, 0x34, 0x85, 0x14, 0xf1, 0x58, 0x1a, 0x5d, 0x8a,
0x31, 0xc9, 0x86, 0x7c, 0x08, 0xfb, 0x83, 0xfc, 0x9f, 0x32, 0x9b, 0xd4, 0xc4, 0xb2, 0xe0, 0x6d,
0x92, 0xf7, 0x79, 0xc2, 0xec, 0x1b, 0x16, 0xae, 0x13, 0x91, 0xf3, 0x5d, 0xc9, 0x8b, 0x72, 0x27,
0xc2, 0x88, 0xbf, 0x95, 0x85, 0x97, 0xe5, 0xa2, 0x14, 0xd4, 0x8e, 0x0a, 0x91, 0x66, 0xe1, 0xe8,
0x16, 0x7a, 0xdb, 0x8f, 0x24, 0x8b, 0xf9, 0x94, 0xba, 0x60, 0xbd, 0x57, 0x21, 0x33, 0x86, 0xc6,
0xd8, 0x08, 0x70, 0x1c, 0x7d, 0x13, 0x20, 0xeb, 0x7d, 0x96, 0x49, 0xea, 0x01, 0x49, 0xab, 0x24,
0x91, 0xcc, 0x18, 0x5a, 0xe3, 0xc1, 0x84, 0x79, 0xb5, 0xee, 0xa9, 0xd4, 0xf3, 0x31, 0x5a, 0xa6,
0x65, 0x2e, 0x83, 0x1a, 0x43, 0xbe, 0x28, 0xf3, 0x5c, 0x32, 0x53, 0xc7, 0x6f, 0x31, 0x6a, 0x78,
0x85, 0x21, 0x2f, 0xc2, 0x28, 0x92, 0xcc, 0xd2, 0xf1, 0x1b, 0x8c, 0x1a, 0x5e, 0x61, 0xc8, 0x87,
0xd5, 0xe1, 0x20, 0xd9, 0xb9, 0x8e, 0x5f, 0x60, 0xd4, 0xf0, 0x0a, 0x53, 0xbc, 0x10, 0xb1, 0x64,
0x44, 0xcb, 0x63, 0x74, 0xe4, 0x71, 0x46, 0x9e, 0xa7, 0x55, 0x22, 0x99, 0xad, 0xe3, 0x97, 0x18,
0x35, 0xbc, 0xc2, 0x6e, 0x66, 0x00, 0x6d, 0x09, 0xd8, 0xe4, 0x27, 0x97, 0xaa, 0x49, 0x2b, 0xc0,
0x91, 0x5e, 0x01, 0xf9, 0xda, 0xc7, 0x15, 0x67, 0xe6, 0xd0, 0x18, 0x93, 0xa0, 0x7e, 0x78, 0x34,
0x67, 0x06, 0x9a, 0x6d, 0x1d, 0x5d, 0xd3, 0xd1, 0x98, 0x4e, 0xd7, 0x5c, 0x01, 0xb4, 0xc5, 0x74,
0x4d, 0x52, 0x9b, 0x77, 0x5d, 0x73, 0x30, 0xb9, 0x3c, 0x9e, 0xa1, 0xf9, 0xdf, 0x27, 0x4b, 0xb4,
0x9d, 0xfd, 0xb7, 0xbe, 0x73, 0x6a, 0xfe, 0xb6, 0xd7, 0x35, 0xfb, 0x1a, 0xb3, 0x7f, 0xb2, 0x7e,
0xdb, 0xa3, 0xe6, 0xe0, 0x7f, 0xd6, 0xbf, 0x68, 0xd7, 0xf7, 0xab, 0x84, 0xe7, 0xfb, 0xb8, 0xf3,
0xa9, 0xfb, 0x07, 0xe8, 0x35, 0x6f, 0xe9, 0x00, 0x7a, 0x2f, 0xfe, 0xb3, 0xbf, 0x79, 0xf5, 0xdd,
0x33, 0x0a, 0x60, 0xcf, 0x83, 0xf9, 0x62, 0xf5, 0xe4, 0x1a, 0xd4, 0x01, 0x12, 0x6c, 0xd6, 0x73,
0xdf, 0x35, 0x43, 0x5b, 0x5d, 0x81, 0xe9, 0x4f, 0x00, 0x00, 0x00, 0xff, 0xff, 0x3d, 0x04, 0xff,
0x62, 0x1d, 0x03, 0x00, 0x00,
// 261 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0x74, 0xd1, 0xc1, 0x4a, 0xc3, 0x30,
0x1c, 0x06, 0x70, 0xb2, 0x98, 0xb9, 0xfe, 0x8b, 0x08, 0x39, 0x68, 0xf0, 0x20, 0xb2, 0x8b, 0x3d,
0x48, 0x0f, 0xdb, 0x45, 0x04, 0x2f, 0x82, 0x47, 0xf5, 0xb0, 0x07, 0x18, 0x8b, 0x86, 0x61, 0x6d,
0x9b, 0x90, 0xa6, 0x42, 0x9e, 0xd1, 0x97, 0xb2, 0x89, 0xae, 0xfe, 0x0b, 0xf1, 0xfc, 0xfd, 0x42,
0xbe, 0x2f, 0x81, 0xf3, 0x46, 0x5b, 0xb5, 0x75, 0xaa, 0x73, 0x5b, 0x2d, 0x2b, 0xf5, 0xea, 0xba,
0xd2, 0x58, 0xed, 0x34, 0x9f, 0x57, 0x9d, 0x6e, 0x8d, 0x5c, 0x9e, 0xc1, 0xf1, 0xe6, 0xbd, 0x31,
0xb5, 0x5a, 0xf3, 0x1c, 0xe8, 0x5b, 0x2f, 0x05, 0xb9, 0x22, 0x05, 0x59, 0x7e, 0x51, 0x60, 0x4f,
0x3b, 0x63, 0x3c, 0xbf, 0x06, 0xd6, 0xf6, 0x4d, 0xe3, 0x87, 0x80, 0x16, 0xf9, 0x4a, 0x94, 0x3f,
0x27, 0xcb, 0x98, 0x96, 0xcf, 0x21, 0x7a, 0x6c, 0x9d, 0x8d, 0xb0, 0x73, 0xd6, 0x7a, 0x31, 0x4b,
0xc1, 0x4d, 0x88, 0x46, 0x38, 0x94, 0xa9, 0xbc, 0xa0, 0x29, 0xf8, 0x12, 0xa2, 0x11, 0xca, 0x7e,
0xbf, 0xf7, 0xe2, 0x28, 0x05, 0x1f, 0x42, 0xf4, 0x07, 0xb5, 0xae, 0xbd, 0x60, 0x49, 0x18, 0xa2,
0x08, 0x2f, 0x6e, 0x00, 0x50, 0xe3, 0x61, 0xf1, 0x87, 0xf2, 0x71, 0x31, 0xe5, 0x27, 0xc0, 0x3e,
0x77, 0x75, 0xaf, 0x86, 0xfa, 0xa4, 0x60, 0x77, 0xb3, 0x5b, 0x12, 0x34, 0xaa, 0x8d, 0x74, 0x36,
0xd5, 0x59, 0xd4, 0xf7, 0x00, 0xa8, 0x3b, 0xd2, 0x8c, 0x5f, 0x62, 0x9d, 0xaf, 0x4e, 0x0f, 0xfd,
0x7e, 0x9f, 0xfe, 0x70, 0x19, 0x5a, 0xf4, 0x7f, 0xb5, 0x6c, 0xd4, 0xe3, 0x2c, 0xac, 0x17, 0x53,
0xbd, 0x08, 0x5a, 0xce, 0xe3, 0xa7, 0xaf, 0xbf, 0x03, 0x00, 0x00, 0xff, 0xff, 0xd5, 0x3b, 0xe8,
0xea, 0x0f, 0x02, 0x00, 0x00,
}

View File

@ -37,17 +37,10 @@ message Simple3 {
double dub = 1;
}
enum Numeral {
UNKNOWN = 0;
ARABIC = 1;
ROMAN = 2;
}
message Mappy {
map<int64, int32> nummy = 1;
map<string, string> strry = 2;
map<int32, Simple3> objjy = 3;
map<int64, string> buggy = 4;
map<bool, bool> booly = 5;
map<string, Numeral> enumy = 6;
}

View File

@ -7,11 +7,6 @@ package jsonpb
import proto "github.com/golang/protobuf/proto"
import fmt "fmt"
import math "math"
import google_protobuf "github.com/golang/protobuf/ptypes/any"
import google_protobuf1 "github.com/golang/protobuf/ptypes/duration"
import google_protobuf2 "github.com/golang/protobuf/ptypes/struct"
import google_protobuf3 "github.com/golang/protobuf/ptypes/timestamp"
import google_protobuf4 "github.com/golang/protobuf/ptypes/wrappers"
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
@ -57,17 +52,17 @@ func (Widget_Color) EnumDescriptor() ([]byte, []int) { return fileDescriptor1, [
// Test message for holding primitive types.
type Simple struct {
OBool *bool `protobuf:"varint,1,opt,name=o_bool,json=oBool" json:"o_bool,omitempty"`
OInt32 *int32 `protobuf:"varint,2,opt,name=o_int32,json=oInt32" json:"o_int32,omitempty"`
OInt64 *int64 `protobuf:"varint,3,opt,name=o_int64,json=oInt64" json:"o_int64,omitempty"`
OUint32 *uint32 `protobuf:"varint,4,opt,name=o_uint32,json=oUint32" json:"o_uint32,omitempty"`
OUint64 *uint64 `protobuf:"varint,5,opt,name=o_uint64,json=oUint64" json:"o_uint64,omitempty"`
OSint32 *int32 `protobuf:"zigzag32,6,opt,name=o_sint32,json=oSint32" json:"o_sint32,omitempty"`
OSint64 *int64 `protobuf:"zigzag64,7,opt,name=o_sint64,json=oSint64" json:"o_sint64,omitempty"`
OFloat *float32 `protobuf:"fixed32,8,opt,name=o_float,json=oFloat" json:"o_float,omitempty"`
ODouble *float64 `protobuf:"fixed64,9,opt,name=o_double,json=oDouble" json:"o_double,omitempty"`
OString *string `protobuf:"bytes,10,opt,name=o_string,json=oString" json:"o_string,omitempty"`
OBytes []byte `protobuf:"bytes,11,opt,name=o_bytes,json=oBytes" json:"o_bytes,omitempty"`
OBool *bool `protobuf:"varint,1,opt,name=o_bool" json:"o_bool,omitempty"`
OInt32 *int32 `protobuf:"varint,2,opt,name=o_int32" json:"o_int32,omitempty"`
OInt64 *int64 `protobuf:"varint,3,opt,name=o_int64" json:"o_int64,omitempty"`
OUint32 *uint32 `protobuf:"varint,4,opt,name=o_uint32" json:"o_uint32,omitempty"`
OUint64 *uint64 `protobuf:"varint,5,opt,name=o_uint64" json:"o_uint64,omitempty"`
OSint32 *int32 `protobuf:"zigzag32,6,opt,name=o_sint32" json:"o_sint32,omitempty"`
OSint64 *int64 `protobuf:"zigzag64,7,opt,name=o_sint64" json:"o_sint64,omitempty"`
OFloat *float32 `protobuf:"fixed32,8,opt,name=o_float" json:"o_float,omitempty"`
ODouble *float64 `protobuf:"fixed64,9,opt,name=o_double" json:"o_double,omitempty"`
OString *string `protobuf:"bytes,10,opt,name=o_string" json:"o_string,omitempty"`
OBytes []byte `protobuf:"bytes,11,opt,name=o_bytes" json:"o_bytes,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
@ -155,17 +150,17 @@ func (m *Simple) GetOBytes() []byte {
// Test message for holding repeated primitives.
type Repeats struct {
RBool []bool `protobuf:"varint,1,rep,name=r_bool,json=rBool" json:"r_bool,omitempty"`
RInt32 []int32 `protobuf:"varint,2,rep,name=r_int32,json=rInt32" json:"r_int32,omitempty"`
RInt64 []int64 `protobuf:"varint,3,rep,name=r_int64,json=rInt64" json:"r_int64,omitempty"`
RUint32 []uint32 `protobuf:"varint,4,rep,name=r_uint32,json=rUint32" json:"r_uint32,omitempty"`
RUint64 []uint64 `protobuf:"varint,5,rep,name=r_uint64,json=rUint64" json:"r_uint64,omitempty"`
RSint32 []int32 `protobuf:"zigzag32,6,rep,name=r_sint32,json=rSint32" json:"r_sint32,omitempty"`
RSint64 []int64 `protobuf:"zigzag64,7,rep,name=r_sint64,json=rSint64" json:"r_sint64,omitempty"`
RFloat []float32 `protobuf:"fixed32,8,rep,name=r_float,json=rFloat" json:"r_float,omitempty"`
RDouble []float64 `protobuf:"fixed64,9,rep,name=r_double,json=rDouble" json:"r_double,omitempty"`
RString []string `protobuf:"bytes,10,rep,name=r_string,json=rString" json:"r_string,omitempty"`
RBytes [][]byte `protobuf:"bytes,11,rep,name=r_bytes,json=rBytes" json:"r_bytes,omitempty"`
RBool []bool `protobuf:"varint,1,rep,name=r_bool" json:"r_bool,omitempty"`
RInt32 []int32 `protobuf:"varint,2,rep,name=r_int32" json:"r_int32,omitempty"`
RInt64 []int64 `protobuf:"varint,3,rep,name=r_int64" json:"r_int64,omitempty"`
RUint32 []uint32 `protobuf:"varint,4,rep,name=r_uint32" json:"r_uint32,omitempty"`
RUint64 []uint64 `protobuf:"varint,5,rep,name=r_uint64" json:"r_uint64,omitempty"`
RSint32 []int32 `protobuf:"zigzag32,6,rep,name=r_sint32" json:"r_sint32,omitempty"`
RSint64 []int64 `protobuf:"zigzag64,7,rep,name=r_sint64" json:"r_sint64,omitempty"`
RFloat []float32 `protobuf:"fixed32,8,rep,name=r_float" json:"r_float,omitempty"`
RDouble []float64 `protobuf:"fixed64,9,rep,name=r_double" json:"r_double,omitempty"`
RString []string `protobuf:"bytes,10,rep,name=r_string" json:"r_string,omitempty"`
RBytes [][]byte `protobuf:"bytes,11,rep,name=r_bytes" json:"r_bytes,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
@ -254,11 +249,11 @@ func (m *Repeats) GetRBytes() [][]byte {
// Test message for holding enums and nested messages.
type Widget struct {
Color *Widget_Color `protobuf:"varint,1,opt,name=color,enum=jsonpb.Widget_Color" json:"color,omitempty"`
RColor []Widget_Color `protobuf:"varint,2,rep,name=r_color,json=rColor,enum=jsonpb.Widget_Color" json:"r_color,omitempty"`
RColor []Widget_Color `protobuf:"varint,2,rep,name=r_color,enum=jsonpb.Widget_Color" json:"r_color,omitempty"`
Simple *Simple `protobuf:"bytes,10,opt,name=simple" json:"simple,omitempty"`
RSimple []*Simple `protobuf:"bytes,11,rep,name=r_simple,json=rSimple" json:"r_simple,omitempty"`
RSimple []*Simple `protobuf:"bytes,11,rep,name=r_simple" json:"r_simple,omitempty"`
Repeats *Repeats `protobuf:"bytes,20,opt,name=repeats" json:"repeats,omitempty"`
RRepeats []*Repeats `protobuf:"bytes,21,rep,name=r_repeats,json=rRepeats" json:"r_repeats,omitempty"`
RRepeats []*Repeats `protobuf:"bytes,21,rep,name=r_repeats" json:"r_repeats,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
@ -310,8 +305,8 @@ func (m *Widget) GetRRepeats() []*Repeats {
}
type Maps struct {
MInt64Str map[int64]string `protobuf:"bytes,1,rep,name=m_int64_str,json=mInt64Str" json:"m_int64_str,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
MBoolSimple map[bool]*Simple `protobuf:"bytes,2,rep,name=m_bool_simple,json=mBoolSimple" json:"m_bool_simple,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
MInt64Str map[int64]string `protobuf:"bytes,1,rep,name=m_int64_str" json:"m_int64_str,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
MBoolSimple map[bool]*Simple `protobuf:"bytes,2,rep,name=m_bool_simple" json:"m_bool_simple,omitempty" protobuf_key:"varint,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"`
XXX_unrecognized []byte `json:"-"`
}
@ -338,7 +333,6 @@ type MsgWithOneof struct {
// Types that are valid to be assigned to Union:
// *MsgWithOneof_Title
// *MsgWithOneof_Salary
// *MsgWithOneof_Country
Union isMsgWithOneof_Union `protobuf_oneof:"union"`
XXX_unrecognized []byte `json:"-"`
}
@ -358,13 +352,9 @@ type MsgWithOneof_Title struct {
type MsgWithOneof_Salary struct {
Salary int64 `protobuf:"varint,2,opt,name=salary,oneof"`
}
type MsgWithOneof_Country struct {
Country string `protobuf:"bytes,3,opt,name=Country,json=country,oneof"`
}
func (*MsgWithOneof_Title) isMsgWithOneof_Union() {}
func (*MsgWithOneof_Salary) isMsgWithOneof_Union() {}
func (*MsgWithOneof_Country) isMsgWithOneof_Union() {}
func (*MsgWithOneof_Title) isMsgWithOneof_Union() {}
func (*MsgWithOneof_Salary) isMsgWithOneof_Union() {}
func (m *MsgWithOneof) GetUnion() isMsgWithOneof_Union {
if m != nil {
@ -387,19 +377,11 @@ func (m *MsgWithOneof) GetSalary() int64 {
return 0
}
func (m *MsgWithOneof) GetCountry() string {
if x, ok := m.GetUnion().(*MsgWithOneof_Country); ok {
return x.Country
}
return ""
}
// XXX_OneofFuncs is for the internal use of the proto package.
func (*MsgWithOneof) XXX_OneofFuncs() (func(msg proto.Message, b *proto.Buffer) error, func(msg proto.Message, tag, wire int, b *proto.Buffer) (bool, error), func(msg proto.Message) (n int), []interface{}) {
return _MsgWithOneof_OneofMarshaler, _MsgWithOneof_OneofUnmarshaler, _MsgWithOneof_OneofSizer, []interface{}{
(*MsgWithOneof_Title)(nil),
(*MsgWithOneof_Salary)(nil),
(*MsgWithOneof_Country)(nil),
}
}
@ -413,9 +395,6 @@ func _MsgWithOneof_OneofMarshaler(msg proto.Message, b *proto.Buffer) error {
case *MsgWithOneof_Salary:
b.EncodeVarint(2<<3 | proto.WireVarint)
b.EncodeVarint(uint64(x.Salary))
case *MsgWithOneof_Country:
b.EncodeVarint(3<<3 | proto.WireBytes)
b.EncodeStringBytes(x.Country)
case nil:
default:
return fmt.Errorf("MsgWithOneof.Union has unexpected type %T", x)
@ -440,13 +419,6 @@ func _MsgWithOneof_OneofUnmarshaler(msg proto.Message, tag, wire int, b *proto.B
x, err := b.DecodeVarint()
m.Union = &MsgWithOneof_Salary{int64(x)}
return true, err
case 3: // union.Country
if wire != proto.WireBytes {
return true, proto.ErrInternalBadWireType
}
x, err := b.DecodeStringBytes()
m.Union = &MsgWithOneof_Country{x}
return true, err
default:
return false, nil
}
@ -463,10 +435,6 @@ func _MsgWithOneof_OneofSizer(msg proto.Message) (n int) {
case *MsgWithOneof_Salary:
n += proto.SizeVarint(2<<3 | proto.WireVarint)
n += proto.SizeVarint(uint64(x.Salary))
case *MsgWithOneof_Country:
n += proto.SizeVarint(3<<3 | proto.WireBytes)
n += proto.SizeVarint(uint64(len(x.Country)))
n += len(x.Country)
case nil:
default:
panic(fmt.Sprintf("proto: unexpected type %T in oneof", x))
@ -543,120 +511,7 @@ var E_Complex_RealExtension = &proto.ExtensionDesc{
ExtensionType: (*Complex)(nil),
Field: 123,
Name: "jsonpb.Complex.real_extension",
Tag: "bytes,123,opt,name=real_extension,json=realExtension",
}
type KnownTypes struct {
An *google_protobuf.Any `protobuf:"bytes,14,opt,name=an" json:"an,omitempty"`
Dur *google_protobuf1.Duration `protobuf:"bytes,1,opt,name=dur" json:"dur,omitempty"`
St *google_protobuf2.Struct `protobuf:"bytes,12,opt,name=st" json:"st,omitempty"`
Ts *google_protobuf3.Timestamp `protobuf:"bytes,2,opt,name=ts" json:"ts,omitempty"`
Dbl *google_protobuf4.DoubleValue `protobuf:"bytes,3,opt,name=dbl" json:"dbl,omitempty"`
Flt *google_protobuf4.FloatValue `protobuf:"bytes,4,opt,name=flt" json:"flt,omitempty"`
I64 *google_protobuf4.Int64Value `protobuf:"bytes,5,opt,name=i64" json:"i64,omitempty"`
U64 *google_protobuf4.UInt64Value `protobuf:"bytes,6,opt,name=u64" json:"u64,omitempty"`
I32 *google_protobuf4.Int32Value `protobuf:"bytes,7,opt,name=i32" json:"i32,omitempty"`
U32 *google_protobuf4.UInt32Value `protobuf:"bytes,8,opt,name=u32" json:"u32,omitempty"`
Bool *google_protobuf4.BoolValue `protobuf:"bytes,9,opt,name=bool" json:"bool,omitempty"`
Str *google_protobuf4.StringValue `protobuf:"bytes,10,opt,name=str" json:"str,omitempty"`
Bytes *google_protobuf4.BytesValue `protobuf:"bytes,11,opt,name=bytes" json:"bytes,omitempty"`
XXX_unrecognized []byte `json:"-"`
}
func (m *KnownTypes) Reset() { *m = KnownTypes{} }
func (m *KnownTypes) String() string { return proto.CompactTextString(m) }
func (*KnownTypes) ProtoMessage() {}
func (*KnownTypes) Descriptor() ([]byte, []int) { return fileDescriptor1, []int{7} }
func (m *KnownTypes) GetAn() *google_protobuf.Any {
if m != nil {
return m.An
}
return nil
}
func (m *KnownTypes) GetDur() *google_protobuf1.Duration {
if m != nil {
return m.Dur
}
return nil
}
func (m *KnownTypes) GetSt() *google_protobuf2.Struct {
if m != nil {
return m.St
}
return nil
}
func (m *KnownTypes) GetTs() *google_protobuf3.Timestamp {
if m != nil {
return m.Ts
}
return nil
}
func (m *KnownTypes) GetDbl() *google_protobuf4.DoubleValue {
if m != nil {
return m.Dbl
}
return nil
}
func (m *KnownTypes) GetFlt() *google_protobuf4.FloatValue {
if m != nil {
return m.Flt
}
return nil
}
func (m *KnownTypes) GetI64() *google_protobuf4.Int64Value {
if m != nil {
return m.I64
}
return nil
}
func (m *KnownTypes) GetU64() *google_protobuf4.UInt64Value {
if m != nil {
return m.U64
}
return nil
}
func (m *KnownTypes) GetI32() *google_protobuf4.Int32Value {
if m != nil {
return m.I32
}
return nil
}
func (m *KnownTypes) GetU32() *google_protobuf4.UInt32Value {
if m != nil {
return m.U32
}
return nil
}
func (m *KnownTypes) GetBool() *google_protobuf4.BoolValue {
if m != nil {
return m.Bool
}
return nil
}
func (m *KnownTypes) GetStr() *google_protobuf4.StringValue {
if m != nil {
return m.Str
}
return nil
}
func (m *KnownTypes) GetBytes() *google_protobuf4.BytesValue {
if m != nil {
return m.Bytes
}
return nil
Tag: "bytes,123,opt,name=real_extension",
}
var E_Name = &proto.ExtensionDesc{
@ -675,77 +530,49 @@ func init() {
proto.RegisterType((*MsgWithOneof)(nil), "jsonpb.MsgWithOneof")
proto.RegisterType((*Real)(nil), "jsonpb.Real")
proto.RegisterType((*Complex)(nil), "jsonpb.Complex")
proto.RegisterType((*KnownTypes)(nil), "jsonpb.KnownTypes")
proto.RegisterEnum("jsonpb.Widget_Color", Widget_Color_name, Widget_Color_value)
proto.RegisterExtension(E_Complex_RealExtension)
proto.RegisterExtension(E_Name)
}
var fileDescriptor1 = []byte{
// 1031 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0x74, 0x95, 0xdf, 0x72, 0xdb, 0x44,
0x14, 0xc6, 0x2b, 0xc9, 0x96, 0xec, 0x75, 0x12, 0xcc, 0x4e, 0x4a, 0x15, 0x13, 0x40, 0xe3, 0x29,
0x45, 0x14, 0xea, 0x0e, 0x8a, 0xc7, 0xc3, 0x14, 0x6e, 0x9a, 0xc6, 0x50, 0x06, 0x52, 0x66, 0x36,
0x0d, 0xbd, 0xf4, 0xc8, 0xf1, 0xc6, 0xa8, 0xc8, 0x5a, 0xcf, 0xee, 0x8a, 0xd4, 0x03, 0x17, 0x79,
0x08, 0x5e, 0x01, 0x1e, 0x81, 0x27, 0xe2, 0x41, 0x98, 0x73, 0x56, 0x7f, 0x12, 0x3b, 0xbe, 0x8a,
0x8f, 0xce, 0x77, 0xbe, 0xac, 0x7e, 0x7b, 0x74, 0x0e, 0xa1, 0x9a, 0x2b, 0x3d, 0x11, 0xd3, 0xb7,
0xfc, 0x42, 0xab, 0xc1, 0x52, 0x0a, 0x2d, 0xa8, 0xfb, 0x56, 0x89, 0x6c, 0x39, 0xed, 0x1d, 0xcc,
0x85, 0x98, 0xa7, 0xfc, 0x29, 0x3e, 0x9d, 0xe6, 0x97, 0x4f, 0xe3, 0x6c, 0x65, 0x24, 0xbd, 0x8f,
0xd7, 0x53, 0xb3, 0x5c, 0xc6, 0x3a, 0x11, 0x59, 0x91, 0x3f, 0x5c, 0xcf, 0x2b, 0x2d, 0xf3, 0x0b,
0x5d, 0x64, 0x3f, 0x59, 0xcf, 0xea, 0x64, 0xc1, 0x95, 0x8e, 0x17, 0xcb, 0x6d, 0xf6, 0x57, 0x32,
0x5e, 0x2e, 0xb9, 0x2c, 0x4e, 0xd8, 0xff, 0xdb, 0x26, 0xee, 0x59, 0xb2, 0x58, 0xa6, 0x9c, 0xde,
0x27, 0xae, 0x98, 0x4c, 0x85, 0x48, 0x7d, 0x2b, 0xb0, 0xc2, 0x16, 0x6b, 0x8a, 0x63, 0x21, 0x52,
0xfa, 0x80, 0x78, 0x62, 0x92, 0x64, 0xfa, 0x28, 0xf2, 0xed, 0xc0, 0x0a, 0x9b, 0xcc, 0x15, 0x3f,
0x40, 0x54, 0x25, 0x46, 0x43, 0xdf, 0x09, 0xac, 0xd0, 0x31, 0x89, 0xd1, 0x90, 0x1e, 0x90, 0x96,
0x98, 0xe4, 0xa6, 0xa4, 0x11, 0x58, 0xe1, 0x2e, 0xf3, 0xc4, 0x39, 0x86, 0x75, 0x6a, 0x34, 0xf4,
0x9b, 0x81, 0x15, 0x36, 0x8a, 0x54, 0x59, 0xa5, 0x4c, 0x95, 0x1b, 0x58, 0xe1, 0xfb, 0xcc, 0x13,
0x67, 0x37, 0xaa, 0x94, 0xa9, 0xf2, 0x02, 0x2b, 0xa4, 0x45, 0x6a, 0x34, 0x34, 0x87, 0xb8, 0x4c,
0x45, 0xac, 0xfd, 0x56, 0x60, 0x85, 0x36, 0x73, 0xc5, 0x77, 0x10, 0x99, 0x9a, 0x99, 0xc8, 0xa7,
0x29, 0xf7, 0xdb, 0x81, 0x15, 0x5a, 0xcc, 0x13, 0x27, 0x18, 0x16, 0x76, 0x5a, 0x26, 0xd9, 0xdc,
0x27, 0x81, 0x15, 0xb6, 0xc1, 0x0e, 0x43, 0x63, 0x37, 0x5d, 0x69, 0xae, 0xfc, 0x4e, 0x60, 0x85,
0x3b, 0xcc, 0x15, 0xc7, 0x10, 0xf5, 0xff, 0xb1, 0x89, 0xc7, 0xf8, 0x92, 0xc7, 0x5a, 0x01, 0x28,
0x59, 0x82, 0x72, 0x00, 0x94, 0x2c, 0x41, 0xc9, 0x0a, 0x94, 0x03, 0xa0, 0x64, 0x05, 0x4a, 0x56,
0xa0, 0x1c, 0x00, 0x25, 0x2b, 0x50, 0xb2, 0x06, 0xe5, 0x00, 0x28, 0x59, 0x83, 0x92, 0x35, 0x28,
0x07, 0x40, 0xc9, 0x1a, 0x94, 0xac, 0x41, 0x39, 0x00, 0x4a, 0x9e, 0xdd, 0xa8, 0xaa, 0x40, 0x39,
0x00, 0x4a, 0xd6, 0xa0, 0x64, 0x05, 0xca, 0x01, 0x50, 0xb2, 0x02, 0x25, 0x6b, 0x50, 0x0e, 0x80,
0x92, 0x35, 0x28, 0x59, 0x83, 0x72, 0x00, 0x94, 0xac, 0x41, 0xc9, 0x0a, 0x94, 0x03, 0xa0, 0xa4,
0x01, 0xf5, 0xaf, 0x4d, 0xdc, 0x37, 0xc9, 0x6c, 0xce, 0x35, 0x7d, 0x4c, 0x9a, 0x17, 0x22, 0x15,
0x12, 0xfb, 0x69, 0x2f, 0xda, 0x1f, 0x98, 0xaf, 0x61, 0x60, 0xd2, 0x83, 0x17, 0x90, 0x63, 0x46,
0x42, 0x9f, 0x80, 0x9f, 0x51, 0x03, 0xbc, 0x6d, 0x6a, 0x57, 0xe2, 0x5f, 0xfa, 0x88, 0xb8, 0x0a,
0xbb, 0x16, 0x2f, 0xb0, 0x13, 0xed, 0x95, 0x6a, 0xd3, 0xcb, 0xac, 0xc8, 0xd2, 0xcf, 0x0d, 0x10,
0x54, 0xc2, 0x39, 0x37, 0x95, 0x00, 0xa8, 0x90, 0x7a, 0xd2, 0x5c, 0xb0, 0xbf, 0x8f, 0x9e, 0xef,
0x95, 0xca, 0xe2, 0xde, 0x59, 0x99, 0xa7, 0x5f, 0x92, 0xb6, 0x9c, 0x94, 0xe2, 0xfb, 0x68, 0xbb,
0x21, 0x6e, 0xc9, 0xe2, 0x57, 0xff, 0x53, 0xd2, 0x34, 0x87, 0xf6, 0x88, 0xc3, 0xc6, 0x27, 0xdd,
0x7b, 0xb4, 0x4d, 0x9a, 0xdf, 0xb3, 0xf1, 0xf8, 0x55, 0xd7, 0xa2, 0x2d, 0xd2, 0x38, 0xfe, 0xe9,
0x7c, 0xdc, 0xb5, 0xfb, 0x7f, 0xd9, 0xa4, 0x71, 0x1a, 0x2f, 0x15, 0xfd, 0x86, 0x74, 0x16, 0xa6,
0x5d, 0x80, 0x3d, 0xf6, 0x58, 0x27, 0xfa, 0xb0, 0xf4, 0x07, 0xc9, 0xe0, 0x14, 0xfb, 0xe7, 0x4c,
0xcb, 0x71, 0xa6, 0xe5, 0x8a, 0xb5, 0x17, 0x65, 0x4c, 0x9f, 0x93, 0xdd, 0x05, 0xf6, 0x66, 0xf9,
0xd6, 0x36, 0x96, 0x7f, 0x74, 0xbb, 0x1c, 0xfa, 0xd5, 0xbc, 0xb6, 0x31, 0xe8, 0x2c, 0xea, 0x27,
0xbd, 0x6f, 0xc9, 0xde, 0x6d, 0x7f, 0xda, 0x25, 0xce, 0x6f, 0x7c, 0x85, 0xd7, 0xe8, 0x30, 0xf8,
0x49, 0xf7, 0x49, 0xf3, 0xf7, 0x38, 0xcd, 0x39, 0x8e, 0x84, 0x36, 0x33, 0xc1, 0x33, 0xfb, 0x6b,
0xab, 0xf7, 0x8a, 0x74, 0xd7, 0xed, 0x6f, 0xd6, 0xb7, 0x4c, 0xfd, 0xc3, 0x9b, 0xf5, 0x9b, 0x97,
0x52, 0xfb, 0xf5, 0x39, 0xd9, 0x39, 0x55, 0xf3, 0x37, 0x89, 0xfe, 0xf5, 0xe7, 0x8c, 0x8b, 0x4b,
0xfa, 0x01, 0x69, 0xea, 0x44, 0xa7, 0x1c, 0xdd, 0xda, 0x2f, 0xef, 0x31, 0x13, 0x52, 0x9f, 0xb8,
0x2a, 0x4e, 0x63, 0xb9, 0x42, 0x4b, 0xe7, 0xe5, 0x3d, 0x56, 0xc4, 0xb4, 0x47, 0xbc, 0x17, 0x22,
0x87, 0x83, 0xe0, 0x9c, 0x82, 0x1a, 0xef, 0xc2, 0x3c, 0x38, 0xf6, 0x48, 0x33, 0xcf, 0x12, 0x91,
0xf5, 0x1f, 0x91, 0x06, 0xe3, 0x71, 0x5a, 0xbf, 0x98, 0x85, 0x33, 0xc3, 0x04, 0x8f, 0x5b, 0xad,
0x59, 0xf7, 0xfa, 0xfa, 0xfa, 0xda, 0xee, 0x5f, 0x81, 0x19, 0x9c, 0xf1, 0x1d, 0x3d, 0x24, 0xed,
0x64, 0x11, 0xcf, 0x93, 0x0c, 0xfe, 0xa9, 0x91, 0xd7, 0x0f, 0xea, 0x92, 0xe8, 0x84, 0xec, 0x49,
0x1e, 0xa7, 0x13, 0xfe, 0x4e, 0xf3, 0x4c, 0x25, 0x22, 0xa3, 0x3b, 0x75, 0xb3, 0xc4, 0xa9, 0xff,
0xc7, 0xed, 0x6e, 0x2b, 0xec, 0xd9, 0x2e, 0x14, 0x8d, 0xcb, 0x9a, 0xfe, 0x7f, 0x0d, 0x42, 0x7e,
0xcc, 0xc4, 0x55, 0xf6, 0x7a, 0xb5, 0xe4, 0x8a, 0x3e, 0x24, 0x76, 0x9c, 0xf9, 0x7b, 0x58, 0xba,
0x3f, 0x30, 0x43, 0x7e, 0x50, 0x0e, 0xf9, 0xc1, 0xf3, 0x6c, 0xc5, 0xec, 0x38, 0xa3, 0x5f, 0x10,
0x67, 0x96, 0x9b, 0xef, 0xaf, 0x13, 0x1d, 0x6c, 0xc8, 0x4e, 0x8a, 0x55, 0xc3, 0x40, 0x45, 0x3f,
0x23, 0xb6, 0xd2, 0xfe, 0x0e, 0x6a, 0x1f, 0x6c, 0x68, 0xcf, 0x70, 0xed, 0x30, 0x5b, 0xc1, 0x77,
0x6d, 0x6b, 0x55, 0xdc, 0x5c, 0x6f, 0x43, 0xf8, 0xba, 0xdc, 0x40, 0xcc, 0xd6, 0x8a, 0x0e, 0x88,
0x33, 0x9b, 0xa6, 0x08, 0xbe, 0x13, 0x1d, 0x6e, 0x9e, 0x00, 0x07, 0xcd, 0x2f, 0x00, 0x99, 0x81,
0x90, 0x3e, 0x21, 0xce, 0x65, 0xaa, 0x71, 0x6d, 0x40, 0xd3, 0xaf, 0xeb, 0x71, 0x64, 0x15, 0xf2,
0xcb, 0x54, 0x83, 0x3c, 0x29, 0x56, 0xc9, 0x5d, 0x72, 0x6c, 0xe3, 0x42, 0x9e, 0x8c, 0x86, 0x70,
0x9a, 0x7c, 0x34, 0xc4, 0xf5, 0x72, 0xd7, 0x69, 0xce, 0x6f, 0xea, 0xf3, 0xd1, 0x10, 0xed, 0x8f,
0x22, 0xdc, 0x39, 0x5b, 0xec, 0x8f, 0xa2, 0xd2, 0xfe, 0x28, 0x42, 0xfb, 0xa3, 0x08, 0x17, 0xd1,
0x36, 0xfb, 0x4a, 0x9f, 0xa3, 0xbe, 0x81, 0x6b, 0xa4, 0xbd, 0x05, 0x25, 0x7c, 0x47, 0x46, 0x8e,
0x3a, 0xf0, 0x87, 0x89, 0x40, 0xb6, 0xf8, 0x9b, 0xd1, 0x5c, 0xf8, 0x2b, 0x2d, 0xe9, 0x57, 0xa4,
0x59, 0xef, 0xb2, 0xbb, 0x5e, 0x00, 0x47, 0xb6, 0x29, 0x30, 0xca, 0x67, 0x01, 0x69, 0x64, 0xf1,
0x82, 0xaf, 0xb5, 0xe8, 0x9f, 0xf8, 0x95, 0x63, 0xe6, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0xca,
0xa2, 0x76, 0x34, 0xe8, 0x08, 0x00, 0x00,
// 598 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x09, 0x6e, 0x88, 0x02, 0xff, 0x74, 0x93, 0x5f, 0x6b, 0x13, 0x4d,
0x14, 0xc6, 0xbb, 0x3b, 0xfb, 0xf7, 0xa4, 0x4d, 0xb7, 0x43, 0x5f, 0x58, 0xfa, 0x52, 0x2d, 0x2b,
0x05, 0xf1, 0x22, 0x94, 0x58, 0x45, 0x72, 0x99, 0x1a, 0xac, 0x60, 0x14, 0x52, 0xa4, 0x57, 0xb2,
0x6c, 0x9a, 0x69, 0xdc, 0xba, 0xd9, 0x09, 0xb3, 0x13, 0x69, 0xd0, 0x0b, 0xc1, 0x2f, 0xa7, 0xdf,
0xc3, 0x0f, 0xe2, 0xcc, 0x9c, 0x6c, 0xb6, 0x8d, 0x9a, 0xcb, 0x67, 0x9e, 0xf3, 0xe4, 0xfc, 0xce,
0x39, 0x0b, 0x54, 0xb2, 0x4a, 0xa6, 0x7c, 0x7c, 0xc3, 0xae, 0x64, 0xd5, 0x99, 0x0b, 0x2e, 0x39,
0xf5, 0x6e, 0x2a, 0x5e, 0xce, 0xc7, 0xc9, 0x0f, 0x0b, 0xbc, 0x8b, 0x7c, 0x36, 0x2f, 0x18, 0x6d,
0x83, 0xc7, 0xd3, 0x31, 0xe7, 0x45, 0x6c, 0x1d, 0x59, 0x8f, 0x03, 0xba, 0x0b, 0x3e, 0x4f, 0xf3,
0x52, 0x3e, 0xed, 0xc6, 0xb6, 0x12, 0xdc, 0xb5, 0xf0, 0xfc, 0x34, 0x26, 0x4a, 0x20, 0x34, 0x82,
0x80, 0xa7, 0x0b, 0xb4, 0x38, 0x4a, 0xd9, 0x69, 0x14, 0xe5, 0x71, 0x95, 0xe2, 0xa0, 0x52, 0xa1,
0xc7, 0x53, 0xca, 0x5e, 0xa3, 0x28, 0x8f, 0xaf, 0x14, 0x8a, 0xc1, 0xd7, 0x05, 0xcf, 0x64, 0x1c,
0x28, 0xc1, 0x46, 0xcb, 0x84, 0x2f, 0xc6, 0x05, 0x8b, 0x43, 0xa5, 0x58, 0xab, 0x22, 0x29, 0xf2,
0x72, 0x1a, 0x83, 0x52, 0x42, 0x2c, 0x1a, 0x2f, 0x15, 0x5b, 0xdc, 0x52, 0xc2, 0x76, 0xf2, 0xd3,
0x02, 0x7f, 0xc4, 0xe6, 0x2c, 0x93, 0x95, 0x66, 0x11, 0x35, 0x0b, 0x41, 0x16, 0xb1, 0x66, 0x21,
0xc8, 0x22, 0xd6, 0x2c, 0x04, 0x59, 0x44, 0xc3, 0x42, 0x90, 0x45, 0x34, 0x2c, 0x04, 0x59, 0x44,
0xc3, 0x42, 0x90, 0x45, 0x34, 0x2c, 0x04, 0x59, 0xc4, 0x9a, 0x85, 0x20, 0x8b, 0x68, 0x58, 0x08,
0xb2, 0x88, 0x86, 0x85, 0x20, 0x8b, 0x58, 0xb3, 0x10, 0xc5, 0xf2, 0xdd, 0x06, 0xef, 0x32, 0x9f,
0x4c, 0x99, 0xa4, 0x8f, 0xc0, 0xbd, 0xe2, 0x05, 0x17, 0x66, 0x2b, 0xed, 0xee, 0x7e, 0x07, 0x37,
0xd7, 0xc1, 0xe7, 0xce, 0x99, 0x7e, 0xa3, 0xc7, 0x3a, 0x00, 0x6d, 0x9a, 0xef, 0x5f, 0xb6, 0x07,
0xe0, 0x55, 0x66, 0xd9, 0x66, 0x86, 0xad, 0x6e, 0xbb, 0x76, 0xad, 0x4e, 0xe0, 0x08, 0x71, 0x8c,
0x43, 0x37, 0xf2, 0x37, 0x87, 0x2f, 0x70, 0xc6, 0xf1, 0xbe, 0x89, 0xd8, 0xad, 0x0d, 0xf5, 0xe8,
0x13, 0x08, 0x45, 0x5a, 0x7b, 0xfe, 0x33, 0x21, 0x9b, 0x9e, 0xe4, 0x18, 0x5c, 0x6c, 0xc8, 0x07,
0x32, 0x1a, 0xbc, 0x8c, 0xb6, 0x68, 0x08, 0xee, 0xab, 0xd1, 0x60, 0xf0, 0x36, 0xb2, 0x68, 0x00,
0x4e, 0xff, 0xcd, 0xfb, 0x41, 0x64, 0x27, 0xbf, 0x2c, 0x70, 0x86, 0xd9, 0xbc, 0xa2, 0x27, 0xd0,
0x9a, 0xe1, 0xb6, 0xf4, 0xdc, 0xcc, 0x4e, 0x5b, 0xdd, 0xff, 0xeb, 0x54, 0x6d, 0xe9, 0x0c, 0x5f,
0xeb, 0xe7, 0x0b, 0x29, 0x06, 0xa5, 0x14, 0x4b, 0x7a, 0x0a, 0x3b, 0x33, 0x73, 0x00, 0x35, 0x8e,
0x6d, 0x6a, 0x0e, 0xef, 0xd7, 0xf4, 0x95, 0x01, 0xc1, 0x4c, 0xd5, 0xc1, 0x09, 0xb4, 0x37, 0x72,
0x5a, 0x40, 0x3e, 0xb1, 0xa5, 0x99, 0x3d, 0xa1, 0x3b, 0xe0, 0x7e, 0xce, 0x8a, 0x05, 0x33, 0xdf,
0x43, 0xd8, 0xb3, 0x5f, 0x58, 0x07, 0x7d, 0x88, 0x36, 0x53, 0xee, 0xd6, 0x04, 0xf4, 0xf0, 0x6e,
0xcd, 0x1f, 0xf3, 0xd4, 0x19, 0x49, 0x0f, 0xb6, 0x87, 0xd5, 0xf4, 0x32, 0x97, 0x1f, 0xdf, 0x95,
0x8c, 0x5f, 0xab, 0x6b, 0x70, 0x65, 0x2e, 0x55, 0xcf, 0x3a, 0x21, 0x3c, 0xdf, 0x52, 0x07, 0xe3,
0x55, 0x59, 0x91, 0x89, 0xa5, 0x09, 0x21, 0xe7, 0x5b, 0x7d, 0x1f, 0xdc, 0x45, 0x99, 0xf3, 0x32,
0x79, 0x08, 0xce, 0x88, 0x65, 0x45, 0xd3, 0x9a, 0xae, 0xb1, 0x9e, 0x04, 0xc1, 0x24, 0xfa, 0xa6,
0x7e, 0x76, 0xf2, 0x01, 0xfc, 0x33, 0xae, 0xff, 0xea, 0x96, 0xee, 0x41, 0x98, 0xcf, 0xb2, 0x69,
0x5e, 0xea, 0xa4, 0x0d, 0x5f, 0xf7, 0x19, 0xb4, 0x85, 0x0a, 0x4a, 0xd9, 0xad, 0x64, 0x65, 0xa5,
0xa2, 0xe9, 0x76, 0xb3, 0xb5, 0xac, 0x88, 0xbf, 0xdc, 0xdf, 0xf6, 0x2a, 0xb3, 0x77, 0x00, 0x4e,
0x99, 0xcd, 0xd8, 0x86, 0xf9, 0xab, 0x6e, 0xfc, 0x77, 0x00, 0x00, 0x00, 0xff, 0xff, 0xfb, 0xc2,
0xb2, 0xf6, 0x78, 0x04, 0x00, 0x00,
}

View File

@ -31,12 +31,6 @@
syntax = "proto2";
import "google/protobuf/any.proto";
import "google/protobuf/duration.proto";
import "google/protobuf/struct.proto";
import "google/protobuf/timestamp.proto";
import "google/protobuf/wrappers.proto";
package jsonpb;
// Test message for holding primitive types.
@ -95,7 +89,6 @@ message MsgWithOneof {
oneof union {
string title = 1;
int64 salary = 2;
string Country = 3;
}
}
@ -115,20 +108,3 @@ message Complex {
optional double imaginary = 1;
extensions 100 to max;
}
message KnownTypes {
optional google.protobuf.Any an = 14;
optional google.protobuf.Duration dur = 1;
optional google.protobuf.Struct st = 12;
optional google.protobuf.Timestamp ts = 2;
optional google.protobuf.DoubleValue dbl = 3;
optional google.protobuf.FloatValue flt = 4;
optional google.protobuf.Int64Value i64 = 5;
optional google.protobuf.UInt64Value u64 = 6;
optional google.protobuf.Int32Value i32 = 7;
optional google.protobuf.UInt32Value u32 = 8;
optional google.protobuf.BoolValue bool = 9;
optional google.protobuf.StringValue str = 10;
optional google.protobuf.BytesValue bytes = 11;
}

View File

@ -39,5 +39,5 @@ test: install generate-test-pbs
generate-test-pbs:
make install
make -C testdata
protoc --go_out=Mtestdata/test.proto=github.com/golang/protobuf/proto/testdata,Mgoogle/protobuf/any.proto=github.com/golang/protobuf/ptypes/any:. proto3_proto/proto3.proto
protoc --go_out=Mtestdata/test.proto=github.com/golang/protobuf/proto/testdata:. proto3_proto/proto3.proto
make

View File

@ -1329,18 +1329,9 @@ func TestRequiredFieldEnforcement(t *testing.T) {
func TestTypedNilMarshal(t *testing.T) {
// A typed nil should return ErrNil and not crash.
{
var m *GoEnum
if _, err := Marshal(m); err != ErrNil {
t.Errorf("Marshal(%#v): got %v, want ErrNil", m, err)
}
}
{
m := &Communique{Union: &Communique_Msg{nil}}
if _, err := Marshal(m); err == nil || err == ErrNil {
t.Errorf("Marshal(%#v): got %v, want errOneofHasNil", m, err)
}
_, err := Marshal((*GoEnum)(nil))
if err != ErrNil {
t.Errorf("Marshal: got err %v, want ErrNil", err)
}
}
@ -1956,88 +1947,14 @@ func TestMapFieldRoundTrips(t *testing.T) {
}
func TestMapFieldWithNil(t *testing.T) {
m1 := &MessageWithMap{
m := &MessageWithMap{
MsgMapping: map[int64]*FloatingPoint{
1: nil,
},
}
b, err := Marshal(m1)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
m2 := new(MessageWithMap)
if err := Unmarshal(b, m2); err != nil {
t.Fatalf("Unmarshal: %v, got these bytes: %v", err, b)
}
if v, ok := m2.MsgMapping[1]; !ok {
t.Error("msg_mapping[1] not present")
} else if v != nil {
t.Errorf("msg_mapping[1] not nil: %v", v)
}
}
func TestMapFieldWithNilBytes(t *testing.T) {
m1 := &MessageWithMap{
ByteMapping: map[bool][]byte{
false: []byte{},
true: nil,
},
}
n := Size(m1)
b, err := Marshal(m1)
if err != nil {
t.Fatalf("Marshal: %v", err)
}
if n != len(b) {
t.Errorf("Size(m1) = %d; want len(Marshal(m1)) = %d", n, len(b))
}
m2 := new(MessageWithMap)
if err := Unmarshal(b, m2); err != nil {
t.Fatalf("Unmarshal: %v, got these bytes: %v", err, b)
}
if v, ok := m2.ByteMapping[false]; !ok {
t.Error("byte_mapping[false] not present")
} else if len(v) != 0 {
t.Errorf("byte_mapping[false] not empty: %#v", v)
}
if v, ok := m2.ByteMapping[true]; !ok {
t.Error("byte_mapping[true] not present")
} else if len(v) != 0 {
t.Errorf("byte_mapping[true] not empty: %#v", v)
}
}
func TestDecodeMapFieldMissingKey(t *testing.T) {
b := []byte{
0x0A, 0x03, // message, tag 1 (name_mapping), of length 3 bytes
// no key
0x12, 0x01, 0x6D, // string value of length 1 byte, value "m"
}
got := &MessageWithMap{}
err := Unmarshal(b, got)
if err != nil {
t.Fatalf("failed to marshal map with missing key: %v", err)
}
want := &MessageWithMap{NameMapping: map[int32]string{0: "m"}}
if !Equal(got, want) {
t.Errorf("Unmarshaled map with no key was not as expected. got: %v, want %v", got, want)
}
}
func TestDecodeMapFieldMissingValue(t *testing.T) {
b := []byte{
0x0A, 0x02, // message, tag 1 (name_mapping), of length 2 bytes
0x08, 0x01, // varint key, value 1
// no value
}
got := &MessageWithMap{}
err := Unmarshal(b, got)
if err != nil {
t.Fatalf("failed to marshal map with missing value: %v", err)
}
want := &MessageWithMap{NameMapping: map[int32]string{1: ""}}
if !Equal(got, want) {
t.Errorf("Unmarshaled map with no value was not as expected. got: %v, want %v", got, want)
b, err := Marshal(m)
if err == nil {
t.Fatalf("Marshal of bad map should have failed, got these bytes: %v", b)
}
}

View File

@ -1,272 +0,0 @@
// Go support for Protocol Buffers - Google's data interchange format
//
// Copyright 2016 The Go Authors. All rights reserved.
// https://github.com/golang/protobuf
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above
// copyright notice, this list of conditions and the following disclaimer
// in the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Google Inc. nor the names of its
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package proto_test
import (
"strings"
"testing"
"github.com/golang/protobuf/proto"
pb "github.com/golang/protobuf/proto/proto3_proto"
testpb "github.com/golang/protobuf/proto/testdata"
anypb "github.com/golang/protobuf/ptypes/any"
)
var (
expandedMarshaler = proto.TextMarshaler{ExpandAny: true}
expandedCompactMarshaler = proto.TextMarshaler{Compact: true, ExpandAny: true}
)
// anyEqual reports whether two messages which may be google.protobuf.Any or may
// contain google.protobuf.Any fields are equal. We can't use proto.Equal for
// comparison, because semantically equivalent messages may be marshaled to
// binary in different tag order. Instead, trust that TextMarshaler with
// ExpandAny option works and compare the text marshaling results.
func anyEqual(got, want proto.Message) bool {
// if messages are proto.Equal, no need to marshal.
if proto.Equal(got, want) {
return true
}
g := expandedMarshaler.Text(got)
w := expandedMarshaler.Text(want)
return g == w
}
type golden struct {
m proto.Message
t, c string
}
var goldenMessages = makeGolden()
func makeGolden() []golden {
nested := &pb.Nested{Bunny: "Monty"}
nb, err := proto.Marshal(nested)
if err != nil {
panic(err)
}
m1 := &pb.Message{
Name: "David",
ResultCount: 47,
Anything: &anypb.Any{TypeUrl: "type.googleapis.com/" + proto.MessageName(nested), Value: nb},
}
m2 := &pb.Message{
Name: "David",
ResultCount: 47,
Anything: &anypb.Any{TypeUrl: "http://[::1]/type.googleapis.com/" + proto.MessageName(nested), Value: nb},
}
m3 := &pb.Message{
Name: "David",
ResultCount: 47,
Anything: &anypb.Any{TypeUrl: `type.googleapis.com/"/` + proto.MessageName(nested), Value: nb},
}
m4 := &pb.Message{
Name: "David",
ResultCount: 47,
Anything: &anypb.Any{TypeUrl: "type.googleapis.com/a/path/" + proto.MessageName(nested), Value: nb},
}
m5 := &anypb.Any{TypeUrl: "type.googleapis.com/" + proto.MessageName(nested), Value: nb}
any1 := &testpb.MyMessage{Count: proto.Int32(47), Name: proto.String("David")}
proto.SetExtension(any1, testpb.E_Ext_More, &testpb.Ext{Data: proto.String("foo")})
proto.SetExtension(any1, testpb.E_Ext_Text, proto.String("bar"))
any1b, err := proto.Marshal(any1)
if err != nil {
panic(err)
}
any2 := &testpb.MyMessage{Count: proto.Int32(42), Bikeshed: testpb.MyMessage_GREEN.Enum(), RepBytes: [][]byte{[]byte("roboto")}}
proto.SetExtension(any2, testpb.E_Ext_More, &testpb.Ext{Data: proto.String("baz")})
any2b, err := proto.Marshal(any2)
if err != nil {
panic(err)
}
m6 := &pb.Message{
Name: "David",
ResultCount: 47,
Anything: &anypb.Any{TypeUrl: "type.googleapis.com/" + proto.MessageName(any1), Value: any1b},
ManyThings: []*anypb.Any{
&anypb.Any{TypeUrl: "type.googleapis.com/" + proto.MessageName(any2), Value: any2b},
&anypb.Any{TypeUrl: "type.googleapis.com/" + proto.MessageName(any1), Value: any1b},
},
}
const (
m1Golden = `
name: "David"
result_count: 47
anything: <
[type.googleapis.com/proto3_proto.Nested]: <
bunny: "Monty"
>
>
`
m2Golden = `
name: "David"
result_count: 47
anything: <
["http://[::1]/type.googleapis.com/proto3_proto.Nested"]: <
bunny: "Monty"
>
>
`
m3Golden = `
name: "David"
result_count: 47
anything: <
["type.googleapis.com/\"/proto3_proto.Nested"]: <
bunny: "Monty"
>
>
`
m4Golden = `
name: "David"
result_count: 47
anything: <
[type.googleapis.com/a/path/proto3_proto.Nested]: <
bunny: "Monty"
>
>
`
m5Golden = `
[type.googleapis.com/proto3_proto.Nested]: <
bunny: "Monty"
>
`
m6Golden = `
name: "David"
result_count: 47
anything: <
[type.googleapis.com/testdata.MyMessage]: <
count: 47
name: "David"
[testdata.Ext.more]: <
data: "foo"
>
[testdata.Ext.text]: "bar"
>
>
many_things: <
[type.googleapis.com/testdata.MyMessage]: <
count: 42
bikeshed: GREEN
rep_bytes: "roboto"
[testdata.Ext.more]: <
data: "baz"
>
>
>
many_things: <
[type.googleapis.com/testdata.MyMessage]: <
count: 47
name: "David"
[testdata.Ext.more]: <
data: "foo"
>
[testdata.Ext.text]: "bar"
>
>
`
)
return []golden{
{m1, strings.TrimSpace(m1Golden) + "\n", strings.TrimSpace(compact(m1Golden)) + " "},
{m2, strings.TrimSpace(m2Golden) + "\n", strings.TrimSpace(compact(m2Golden)) + " "},
{m3, strings.TrimSpace(m3Golden) + "\n", strings.TrimSpace(compact(m3Golden)) + " "},
{m4, strings.TrimSpace(m4Golden) + "\n", strings.TrimSpace(compact(m4Golden)) + " "},
{m5, strings.TrimSpace(m5Golden) + "\n", strings.TrimSpace(compact(m5Golden)) + " "},
{m6, strings.TrimSpace(m6Golden) + "\n", strings.TrimSpace(compact(m6Golden)) + " "},
}
}
func TestMarshalGolden(t *testing.T) {
for _, tt := range goldenMessages {
if got, want := expandedMarshaler.Text(tt.m), tt.t; got != want {
t.Errorf("message %v: got:\n%s\nwant:\n%s", tt.m, got, want)
}
if got, want := expandedCompactMarshaler.Text(tt.m), tt.c; got != want {
t.Errorf("message %v: got:\n`%s`\nwant:\n`%s`", tt.m, got, want)
}
}
}
func TestUnmarshalGolden(t *testing.T) {
for _, tt := range goldenMessages {
want := tt.m
got := proto.Clone(tt.m)
got.Reset()
if err := proto.UnmarshalText(tt.t, got); err != nil {
t.Errorf("failed to unmarshal\n%s\nerror: %v", tt.t, err)
}
if !anyEqual(got, want) {
t.Errorf("message:\n%s\ngot:\n%s\nwant:\n%s", tt.t, got, want)
}
got.Reset()
if err := proto.UnmarshalText(tt.c, got); err != nil {
t.Errorf("failed to unmarshal\n%s\nerror: %v", tt.c, err)
}
if !anyEqual(got, want) {
t.Errorf("message:\n%s\ngot:\n%s\nwant:\n%s", tt.c, got, want)
}
}
}
func TestMarsahlUnknownAny(t *testing.T) {
m := &pb.Message{
Anything: &anypb.Any{
TypeUrl: "foo",
Value: []byte("bar"),
},
}
want := `anything: <
type_url: "foo"
value: "bar"
>
`
got := expandedMarshaler.Text(m)
if got != want {
t.Errorf("got\n`%s`\nwant\n`%s`", got, want)
}
}
func TestAmbiguousAny(t *testing.T) {
pb := &anypb.Any{}
err := proto.UnmarshalText(`
[type.googleapis.com/proto3_proto.Nested]: <
bunny: "Monty"
>
type_url: "ttt/proto3_proto.Nested"
`, pb)
t.Logf("result: %v (error: %v)", expandedMarshaler.Text(pb), err)
if err != nil {
t.Errorf("failed to parse ambiguous Any message: %v", err)
}
}

View File

@ -84,15 +84,9 @@ func mergeStruct(out, in reflect.Value) {
mergeAny(out.Field(i), in.Field(i), false, sprop.Prop[i])
}
if emIn, ok := extendable(in.Addr().Interface()); ok {
emOut, _ := extendable(out.Addr().Interface())
mIn, muIn := emIn.extensionsRead()
if mIn != nil {
mOut := emOut.extensionsWrite()
muIn.Lock()
mergeExtension(mOut, mIn)
muIn.Unlock()
}
if emIn, ok := in.Addr().Interface().(extendableProto); ok {
emOut := out.Addr().Interface().(extendableProto)
mergeExtension(emOut.ExtensionMap(), emIn.ExtensionMap())
}
uf := in.FieldByName("XXX_unrecognized")

View File

@ -390,12 +390,11 @@ func (o *Buffer) unmarshalType(st reflect.Type, prop *StructProperties, is_group
if !ok {
// Maybe it's an extension?
if prop.extendable {
if e, _ := extendable(structPointer_Interface(base, st)); isExtensionField(e, int32(tag)) {
if e := structPointer_Interface(base, st).(extendableProto); isExtensionField(e, int32(tag)) {
if err = o.skip(st, tag, wire); err == nil {
extmap := e.extensionsWrite()
ext := extmap[int32(tag)] // may be missing
ext := e.ExtensionMap()[int32(tag)] // may be missing
ext.enc = append(ext.enc, o.buf[oi:o.index]...)
extmap[int32(tag)] = ext
e.ExtensionMap()[int32(tag)] = ext
}
continue
}
@ -769,11 +768,10 @@ func (o *Buffer) dec_new_map(p *Properties, base structPointer) error {
}
}
keyelem, valelem := keyptr.Elem(), valptr.Elem()
if !keyelem.IsValid() {
keyelem = reflect.Zero(p.mtype.Key())
}
if !valelem.IsValid() {
valelem = reflect.Zero(p.mtype.Elem())
if !keyelem.IsValid() || !valelem.IsValid() {
// We did not decode the key or the value in the map entry.
// Either way, it's an invalid map entry.
return fmt.Errorf("proto: bad map data: missing key/val")
}
v.SetMapIndex(keyelem, valelem)

View File

@ -64,16 +64,8 @@ var (
// a struct with a repeated field containing a nil element.
errRepeatedHasNil = errors.New("proto: repeated field has nil element")
// errOneofHasNil is the error returned if Marshal is called with
// a struct with a oneof field containing a nil element.
errOneofHasNil = errors.New("proto: oneof field has nil value")
// ErrNil is the error returned if Marshal is called with nil.
ErrNil = errors.New("proto: Marshal called with nil")
// ErrTooLarge is the error returned if Marshal is called with a
// message that encodes to >2GB.
ErrTooLarge = errors.New("proto: message encodes to over 2 GB")
)
// The fundamental encoders that put bytes on the wire.
@ -82,10 +74,6 @@ var (
const maxVarintBytes = 10 // maximum length of a varint
// maxMarshalSize is the largest allowed size of an encoded protobuf,
// since C++ and Java use signed int32s for the size.
const maxMarshalSize = 1<<31 - 1
// EncodeVarint returns the varint encoding of x.
// This is the format for the
// int32, int64, uint32, uint64, bool, and enum
@ -285,9 +273,6 @@ func (p *Buffer) Marshal(pb Message) error {
stats.Encode++
}
if len(p.buf) > maxMarshalSize {
return ErrTooLarge
}
return err
}
@ -1073,25 +1058,10 @@ func size_slice_struct_group(p *Properties, base structPointer) (n int) {
// Encode an extension map.
func (o *Buffer) enc_map(p *Properties, base structPointer) error {
exts := structPointer_ExtMap(base, p.field)
if err := encodeExtensionsMap(*exts); err != nil {
v := *structPointer_ExtMap(base, p.field)
if err := encodeExtensionMap(v); err != nil {
return err
}
return o.enc_map_body(*exts)
}
func (o *Buffer) enc_exts(p *Properties, base structPointer) error {
exts := structPointer_Extensions(base, p.field)
if err := encodeExtensions(exts); err != nil {
return err
}
v, _ := exts.extensionsRead()
return o.enc_map_body(v)
}
func (o *Buffer) enc_map_body(v map[int32]Extension) error {
// Fast-path for common cases: zero or one extensions.
if len(v) <= 1 {
for _, e := range v {
@ -1114,13 +1084,8 @@ func (o *Buffer) enc_map_body(v map[int32]Extension) error {
}
func size_map(p *Properties, base structPointer) int {
v := structPointer_ExtMap(base, p.field)
return extensionsMapSize(*v)
}
func size_exts(p *Properties, base structPointer) int {
v := structPointer_Extensions(base, p.field)
return extensionsSize(v)
v := *structPointer_ExtMap(base, p.field)
return sizeExtensionMap(v)
}
// Encode a map field.
@ -1149,7 +1114,7 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
if err := p.mkeyprop.enc(o, p.mkeyprop, keybase); err != nil {
return err
}
if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil && err != ErrNil {
if err := p.mvalprop.enc(o, p.mvalprop, valbase); err != nil {
return err
}
return nil
@ -1159,6 +1124,11 @@ func (o *Buffer) enc_new_map(p *Properties, base structPointer) error {
for _, key := range v.MapKeys() {
val := v.MapIndex(key)
// The only illegal map entry values are nil message pointers.
if val.Kind() == reflect.Ptr && val.IsNil() {
return errors.New("proto: map has nil element")
}
keycopy.Set(key)
valcopy.Set(val)
@ -1246,18 +1216,13 @@ func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error {
return err
}
}
if len(o.buf) > maxMarshalSize {
return ErrTooLarge
}
}
}
// Do oneof fields.
if prop.oneofMarshaler != nil {
m := structPointer_Interface(base, prop.stype).(Message)
if err := prop.oneofMarshaler(m, o); err == ErrNil {
return errOneofHasNil
} else if err != nil {
if err := prop.oneofMarshaler(m, o); err != nil {
return err
}
}
@ -1265,9 +1230,6 @@ func (o *Buffer) enc_struct(prop *StructProperties, base structPointer) error {
// Add unrecognized fields at the end.
if prop.unrecField.IsValid() {
v := *structPointer_Bytes(base, prop.unrecField)
if len(o.buf)+len(v) > maxMarshalSize {
return ErrTooLarge
}
if len(v) > 0 {
o.buf = append(o.buf, v...)
}

View File

@ -121,16 +121,9 @@ func equalStruct(v1, v2 reflect.Value) bool {
}
}
if em1 := v1.FieldByName("XXX_InternalExtensions"); em1.IsValid() {
em2 := v2.FieldByName("XXX_InternalExtensions")
if !equalExtensions(v1.Type(), em1.Interface().(XXX_InternalExtensions), em2.Interface().(XXX_InternalExtensions)) {
return false
}
}
if em1 := v1.FieldByName("XXX_extensions"); em1.IsValid() {
em2 := v2.FieldByName("XXX_extensions")
if !equalExtMap(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
if !equalExtensions(v1.Type(), em1.Interface().(map[int32]Extension), em2.Interface().(map[int32]Extension)) {
return false
}
}
@ -191,13 +184,6 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool {
}
return true
case reflect.Ptr:
// Maps may have nil values in them, so check for nil.
if v1.IsNil() && v2.IsNil() {
return true
}
if v1.IsNil() != v2.IsNil() {
return false
}
return equalAny(v1.Elem(), v2.Elem(), prop)
case reflect.Slice:
if v1.Type().Elem().Kind() == reflect.Uint8 {
@ -237,14 +223,8 @@ func equalAny(v1, v2 reflect.Value, prop *Properties) bool {
}
// base is the struct type that the extensions are based on.
// x1 and x2 are InternalExtensions.
func equalExtensions(base reflect.Type, x1, x2 XXX_InternalExtensions) bool {
em1, _ := x1.extensionsRead()
em2, _ := x2.extensionsRead()
return equalExtMap(base, em1, em2)
}
func equalExtMap(base reflect.Type, em1, em2 map[int32]Extension) bool {
// em1 and em2 are extension maps.
func equalExtensions(base reflect.Type, em1, em2 map[int32]Extension) bool {
if len(em1) != len(em2) {
return false
}

View File

@ -52,99 +52,14 @@ type ExtensionRange struct {
Start, End int32 // both inclusive
}
// extendableProto is an interface implemented by any protocol buffer generated by the current
// proto compiler that may be extended.
// extendableProto is an interface implemented by any protocol buffer that may be extended.
type extendableProto interface {
Message
ExtensionRangeArray() []ExtensionRange
extensionsWrite() map[int32]Extension
extensionsRead() (map[int32]Extension, sync.Locker)
}
// extendableProtoV1 is an interface implemented by a protocol buffer generated by the previous
// version of the proto compiler that may be extended.
type extendableProtoV1 interface {
Message
ExtensionRangeArray() []ExtensionRange
ExtensionMap() map[int32]Extension
}
// extensionAdapter is a wrapper around extendableProtoV1 that implements extendableProto.
type extensionAdapter struct {
extendableProtoV1
}
func (e extensionAdapter) extensionsWrite() map[int32]Extension {
return e.ExtensionMap()
}
func (e extensionAdapter) extensionsRead() (map[int32]Extension, sync.Locker) {
return e.ExtensionMap(), notLocker{}
}
// notLocker is a sync.Locker whose Lock and Unlock methods are nops.
type notLocker struct{}
func (n notLocker) Lock() {}
func (n notLocker) Unlock() {}
// extendable returns the extendableProto interface for the given generated proto message.
// If the proto message has the old extension format, it returns a wrapper that implements
// the extendableProto interface.
func extendable(p interface{}) (extendableProto, bool) {
if ep, ok := p.(extendableProto); ok {
return ep, ok
}
if ep, ok := p.(extendableProtoV1); ok {
return extensionAdapter{ep}, ok
}
return nil, false
}
// XXX_InternalExtensions is an internal representation of proto extensions.
//
// Each generated message struct type embeds an anonymous XXX_InternalExtensions field,
// thus gaining the unexported 'extensions' method, which can be called only from the proto package.
//
// The methods of XXX_InternalExtensions are not concurrency safe in general,
// but calls to logically read-only methods such as has and get may be executed concurrently.
type XXX_InternalExtensions struct {
// The struct must be indirect so that if a user inadvertently copies a
// generated message and its embedded XXX_InternalExtensions, they
// avoid the mayhem of a copied mutex.
//
// The mutex serializes all logically read-only operations to p.extensionMap.
// It is up to the client to ensure that write operations to p.extensionMap are
// mutually exclusive with other accesses.
p *struct {
mu sync.Mutex
extensionMap map[int32]Extension
}
}
// extensionsWrite returns the extension map, creating it on first use.
func (e *XXX_InternalExtensions) extensionsWrite() map[int32]Extension {
if e.p == nil {
e.p = new(struct {
mu sync.Mutex
extensionMap map[int32]Extension
})
e.p.extensionMap = make(map[int32]Extension)
}
return e.p.extensionMap
}
// extensionsRead returns the extensions map for read-only use. It may be nil.
// The caller must hold the returned mutex's lock when accessing Elements within the map.
func (e *XXX_InternalExtensions) extensionsRead() (map[int32]Extension, sync.Locker) {
if e.p == nil {
return nil, nil
}
return e.p.extensionMap, &e.p.mu
}
var extendableProtoType = reflect.TypeOf((*extendableProto)(nil)).Elem()
var extendableProtoV1Type = reflect.TypeOf((*extendableProtoV1)(nil)).Elem()
// ExtensionDesc represents an extension specification.
// Used in generated code from the protocol compiler.
@ -177,13 +92,8 @@ type Extension struct {
}
// SetRawExtension is for testing only.
func SetRawExtension(base Message, id int32, b []byte) {
epb, ok := extendable(base)
if !ok {
return
}
extmap := epb.extensionsWrite()
extmap[id] = Extension{enc: b}
func SetRawExtension(base extendableProto, id int32, b []byte) {
base.ExtensionMap()[id] = Extension{enc: b}
}
// isExtensionField returns true iff the given field number is in an extension range.
@ -198,12 +108,8 @@ func isExtensionField(pb extendableProto, field int32) bool {
// checkExtensionTypes checks that the given extension is valid for pb.
func checkExtensionTypes(pb extendableProto, extension *ExtensionDesc) error {
var pbi interface{} = pb
// Check the extended type.
if ea, ok := pbi.(extensionAdapter); ok {
pbi = ea.extendableProtoV1
}
if a, b := reflect.TypeOf(pbi), reflect.TypeOf(extension.ExtendedType); a != b {
if a, b := reflect.TypeOf(pb), reflect.TypeOf(extension.ExtendedType); a != b {
return errors.New("proto: bad extended type; " + b.String() + " does not extend " + a.String())
}
// Check the range.
@ -249,19 +155,8 @@ func extensionProperties(ed *ExtensionDesc) *Properties {
return prop
}
// encode encodes any unmarshaled (unencoded) extensions in e.
func encodeExtensions(e *XXX_InternalExtensions) error {
m, mu := e.extensionsRead()
if m == nil {
return nil // fast path
}
mu.Lock()
defer mu.Unlock()
return encodeExtensionsMap(m)
}
// encode encodes any unmarshaled (unencoded) extensions in e.
func encodeExtensionsMap(m map[int32]Extension) error {
// encodeExtensionMap encodes any unmarshaled (unencoded) extensions in m.
func encodeExtensionMap(m map[int32]Extension) error {
for k, e := range m {
if e.value == nil || e.desc == nil {
// Extension is only in its encoded form.
@ -289,17 +184,7 @@ func encodeExtensionsMap(m map[int32]Extension) error {
return nil
}
func extensionsSize(e *XXX_InternalExtensions) (n int) {
m, mu := e.extensionsRead()
if m == nil {
return 0
}
mu.Lock()
defer mu.Unlock()
return extensionsMapSize(m)
}
func extensionsMapSize(m map[int32]Extension) (n int) {
func sizeExtensionMap(m map[int32]Extension) (n int) {
for _, e := range m {
if e.value == nil || e.desc == nil {
// Extension is only in its encoded form.
@ -324,51 +209,26 @@ func extensionsMapSize(m map[int32]Extension) (n int) {
}
// HasExtension returns whether the given extension is present in pb.
func HasExtension(pb Message, extension *ExtensionDesc) bool {
func HasExtension(pb extendableProto, extension *ExtensionDesc) bool {
// TODO: Check types, field numbers, etc.?
epb, ok := extendable(pb)
if !ok {
return false
}
extmap, mu := epb.extensionsRead()
if extmap == nil {
return false
}
mu.Lock()
_, ok = extmap[extension.Field]
mu.Unlock()
_, ok := pb.ExtensionMap()[extension.Field]
return ok
}
// ClearExtension removes the given extension from pb.
func ClearExtension(pb Message, extension *ExtensionDesc) {
epb, ok := extendable(pb)
if !ok {
return
}
func ClearExtension(pb extendableProto, extension *ExtensionDesc) {
// TODO: Check types, field numbers, etc.?
extmap := epb.extensionsWrite()
delete(extmap, extension.Field)
delete(pb.ExtensionMap(), extension.Field)
}
// GetExtension parses and returns the given extension of pb.
// If the extension is not present and has no default value it returns ErrMissingExtension.
func GetExtension(pb Message, extension *ExtensionDesc) (interface{}, error) {
epb, ok := extendable(pb)
if !ok {
return nil, errors.New("proto: not an extendable proto")
}
if err := checkExtensionTypes(epb, extension); err != nil {
func GetExtension(pb extendableProto, extension *ExtensionDesc) (interface{}, error) {
if err := checkExtensionTypes(pb, extension); err != nil {
return nil, err
}
emap, mu := epb.extensionsRead()
if emap == nil {
return defaultExtensionValue(extension)
}
mu.Lock()
defer mu.Unlock()
emap := pb.ExtensionMap()
e, ok := emap[extension.Field]
if !ok {
// defaultExtensionValue returns the default value or
@ -472,9 +332,10 @@ func decodeExtension(b []byte, extension *ExtensionDesc) (interface{}, error) {
// GetExtensions returns a slice of the extensions present in pb that are also listed in es.
// The returned slice has the same length as es; missing extensions will appear as nil elements.
func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, err error) {
epb, ok := extendable(pb)
epb, ok := pb.(extendableProto)
if !ok {
return nil, errors.New("proto: not an extendable proto")
err = errors.New("proto: not an extendable proto")
return
}
extensions = make([]interface{}, len(es))
for i, e := range es {
@ -490,12 +351,8 @@ func GetExtensions(pb Message, es []*ExtensionDesc) (extensions []interface{}, e
}
// SetExtension sets the specified extension of pb to the specified value.
func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error {
epb, ok := extendable(pb)
if !ok {
return errors.New("proto: not an extendable proto")
}
if err := checkExtensionTypes(epb, extension); err != nil {
func SetExtension(pb extendableProto, extension *ExtensionDesc, value interface{}) error {
if err := checkExtensionTypes(pb, extension); err != nil {
return err
}
typ := reflect.TypeOf(extension.ExtensionType)
@ -511,23 +368,10 @@ func SetExtension(pb Message, extension *ExtensionDesc, value interface{}) error
return fmt.Errorf("proto: SetExtension called with nil value of type %T", value)
}
extmap := epb.extensionsWrite()
extmap[extension.Field] = Extension{desc: extension, value: value}
pb.ExtensionMap()[extension.Field] = Extension{desc: extension, value: value}
return nil
}
// ClearAllExtensions clears all extensions from pb.
func ClearAllExtensions(pb Message) {
epb, ok := extendable(pb)
if !ok {
return
}
m := epb.extensionsWrite()
for k := range m {
delete(m, k)
}
}
// A global registry of extensions.
// The generated code will register the generated descriptors by calling RegisterExtension.

Some files were not shown because too many files have changed in this diff Show More