clair/vendor/github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/descriptor/services.go
2017-06-13 15:58:11 -04:00

267 lines
7.0 KiB
Go

package descriptor
import (
"fmt"
"strings"
"github.com/golang/glog"
"github.com/golang/protobuf/proto"
descriptor "github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/grpc-ecosystem/grpc-gateway/protoc-gen-grpc-gateway/httprule"
options "google.golang.org/genproto/googleapis/api/annotations"
)
// loadServices builds a Service, including its Methods, for every service
// declared in file and stores the resulting list in file.Services.
//
// It must be called only after loadFile has been called for all files, so
// that the names of message types and their fields can be resolved.
// Services that end up without any method are left out of the list. On the
// first error the function returns immediately and file.Services is not
// modified.
func (r *Registry) loadServices(file *File) error {
	glog.V(1).Infof("Loading services from %s", file.GetName())

	var loaded []*Service
	for _, svcDesc := range file.GetService() {
		glog.V(2).Infof("Registering %s", svcDesc.GetName())
		svc := &Service{
			File:                   file,
			ServiceDescriptorProto: svcDesc,
		}

		for _, methDesc := range svcDesc.GetMethod() {
			glog.V(2).Infof("Processing %s.%s", svcDesc.GetName(), methDesc.GetName())

			rule, err := extractAPIOptions(methDesc)
			if err != nil {
				glog.Errorf("Failed to extract ApiMethodOptions from %s.%s: %v", svc.GetName(), methDesc.GetName(), err)
				return err
			}
			// A method without an HTTP rule is still handed to newMethod
			// (with a nil rule); it is only reported here.
			if rule == nil {
				glog.V(1).Infof("Found non-target method: %s.%s", svc.GetName(), methDesc.GetName())
			}

			method, err := r.newMethod(svc, methDesc, rule)
			if err != nil {
				return err
			}
			svc.Methods = append(svc.Methods, method)
		}

		if len(svc.Methods) > 0 {
			glog.V(2).Infof("Registered %s with %d method(s)", svc.GetName(), len(svc.Methods))
			loaded = append(loaded, svc)
		}
	}

	file.Services = loaded
	return nil
}
// newMethod builds the Method description for md, a method of svc.
//
// The request and response message types are resolved through the registry,
// relative to the package of svc.File. One Binding is created for the primary
// pattern of opts (index 0) and one for each of its additional bindings
// (index i+1). A rule that specifies no pattern produces no Binding, so
// meth.Bindings never contains a nil entry. Additional bindings nested inside
// an additional binding are rejected with an error.
//
// opts may be nil: loadServices passes nil for methods without an HTTP rule,
// in which case the returned Method has no bindings.
// NOTE(review): the nil case presumably relies on the generated HttpRule
// getters being nil-safe — confirm against the generated code.
func (r *Registry) newMethod(svc *Service, md *descriptor.MethodDescriptorProto, opts *options.HttpRule) (*Method, error) {
	requestType, err := r.LookupMsg(svc.File.GetPackage(), md.GetInputType())
	if err != nil {
		return nil, err
	}
	responseType, err := r.LookupMsg(svc.File.GetPackage(), md.GetOutputType())
	if err != nil {
		return nil, err
	}
	meth := &Method{
		Service:               svc,
		MethodDescriptorProto: md,
		RequestType:           requestType,
		ResponseType:          responseType,
	}

	// newBinding converts a single HttpRule into a Binding with the given
	// index. It returns (nil, nil) when the rule specifies no pattern.
	newBinding := func(opts *options.HttpRule, idx int) (*Binding, error) {
		var (
			httpMethod   string
			pathTemplate string
		)
		switch {
		case opts.GetGet() != "":
			httpMethod = "GET"
			pathTemplate = opts.GetGet()
			// A body field is rejected for GET.
			if opts.Body != "" {
				return nil, fmt.Errorf("needs request body even though http method is GET: %s", md.GetName())
			}

		case opts.GetPut() != "":
			httpMethod = "PUT"
			pathTemplate = opts.GetPut()

		case opts.GetPost() != "":
			httpMethod = "POST"
			pathTemplate = opts.GetPost()

		case opts.GetDelete() != "":
			httpMethod = "DELETE"
			pathTemplate = opts.GetDelete()
			// A body field is rejected for DELETE unless the registry has
			// been configured to allow it.
			if opts.Body != "" && !r.allowDeleteBody {
				return nil, fmt.Errorf("needs request body even though http method is DELETE: %s", md.GetName())
			}

		case opts.GetPatch() != "":
			httpMethod = "PATCH"
			pathTemplate = opts.GetPatch()

		case opts.GetCustom() != nil:
			custom := opts.GetCustom()
			httpMethod = custom.Kind
			pathTemplate = custom.Path

		default:
			glog.V(1).Infof("No pattern specified in google.api.HttpRule: %s", md.GetName())
			return nil, nil
		}

		parsed, err := httprule.Parse(pathTemplate)
		if err != nil {
			return nil, err
		}
		tmpl := parsed.Compile()

		if md.GetClientStreaming() && len(tmpl.Fields) > 0 {
			return nil, fmt.Errorf("cannot use path parameter in client streaming")
		}

		b := &Binding{
			Method:     meth,
			Index:      idx,
			PathTmpl:   tmpl,
			HTTPMethod: httpMethod,
		}

		// Every field referenced by the path template becomes a path
		// parameter resolved against the request message.
		for _, f := range tmpl.Fields {
			param, err := r.newParam(meth, f)
			if err != nil {
				return nil, err
			}
			b.PathParams = append(b.PathParams, param)
		}

		// TODO(yugui) Handle query params

		b.Body, err = r.newBody(meth, opts.Body)
		if err != nil {
			return nil, err
		}

		return b, nil
	}

	b, err := newBinding(opts, 0)
	if err != nil {
		return nil, err
	}
	if b != nil {
		meth.Bindings = append(meth.Bindings, b)
	}

	for i, additional := range opts.GetAdditionalBindings() {
		if len(additional.GetAdditionalBindings()) > 0 {
			return nil, fmt.Errorf("additional_binding in additional_binding not allowed: %s.%s", svc.GetName(), meth.GetName())
		}
		b, err := newBinding(additional, i+1)
		if err != nil {
			return nil, err
		}
		// newBinding returns (nil, nil) for a rule without a pattern. Skip
		// it, exactly as for the primary binding, so that a nil *Binding is
		// never stored in meth.Bindings.
		if b != nil {
			meth.Bindings = append(meth.Bindings, b)
		}
	}

	return meth, nil
}
// extractAPIOptions returns the HttpRule stored in the options.E_Http
// extension of meth's options.
//
// It returns (nil, nil) when meth has no options at all or when the options
// do not carry the extension. An error is returned when the extension cannot
// be read or when its value is not an *options.HttpRule.
func extractAPIOptions(meth *descriptor.MethodDescriptorProto) (*options.HttpRule, error) {
	// The nil check comes first so that HasExtension is never called on
	// missing options.
	if meth.Options == nil || !proto.HasExtension(meth.Options, options.E_Http) {
		return nil, nil
	}

	ext, err := proto.GetExtension(meth.Options, options.E_Http)
	if err != nil {
		return nil, err
	}

	if rule, ok := ext.(*options.HttpRule); ok {
		return rule, nil
	}
	return nil, fmt.Errorf("extension is %T; want an HttpRule", ext)
}
// newParam builds the Parameter that binds the field path "path" of meth's
// request message.
//
// The path is resolved with resolveFiledPath. An error is returned when the
// resolution fails, when it yields no components, or when the last component
// refers to a message or group field.
func (r *Registry) newParam(meth *Method, path string) (Parameter, error) {
	components, err := r.resolveFiledPath(meth.RequestType, path)
	if err != nil {
		return Parameter{}, err
	}
	if len(components) == 0 {
		return Parameter{}, fmt.Errorf("invalid field access list for %s", path)
	}

	// The last component is the field the parameter is actually bound to.
	leaf := components[len(components)-1].Target
	if kind := leaf.GetType(); kind == descriptor.FieldDescriptorProto_TYPE_MESSAGE || kind == descriptor.FieldDescriptorProto_TYPE_GROUP {
		return Parameter{}, fmt.Errorf("aggregate type %s in parameter of %s.%s: %s", leaf.Type, meth.Service.GetName(), meth.GetName(), path)
	}

	param := Parameter{
		FieldPath: FieldPath(components),
		Method:    meth,
		Target:    leaf,
	}
	return param, nil
}
// newBody builds the Body description for the body field path "path" of
// meth's request message.
//
//   - ""  yields a nil *Body.
//   - "*" yields a Body whose FieldPath is nil.
//   - any other value is resolved with resolveFiledPath and stored as the
//     Body's FieldPath; resolution errors are returned as is.
func (r *Registry) newBody(meth *Method, path string) (*Body, error) {
	request := meth.RequestType

	if path == "" {
		return nil, nil
	}
	if path == "*" {
		// The zero Body has a nil FieldPath.
		return &Body{}, nil
	}

	components, err := r.resolveFiledPath(request, path)
	if err != nil {
		return nil, err
	}
	return &Body{FieldPath: FieldPath(components)}, nil
}
// lookupField scans msg.Fields in order and returns the first field whose
// name equals name, or nil when there is none.
func lookupField(msg *Message, name string) *Field {
	for i := range msg.Fields {
		if candidate := msg.Fields[i]; candidate.GetName() == name {
			return candidate
		}
	}
	return nil
}
// resolveFiledPath resolves the dot-separated field path "path" into a list
// of FieldPathComponent, starting the lookup at msg. (The identifier is
// spelled "Filed"; it is kept as is for the existing callers.)
//
// An empty path yields (nil, nil). For every component after the first, the
// previous component must be a message or group field; its type is looked up
// through the registry and the search continues inside that message. An
// error is returned when an intermediate field is not an aggregate type,
// when a component names no field, or when a component is a repeated field.
func (r *Registry) resolveFiledPath(msg *Message, path string) ([]FieldPathComponent, error) {
	if path == "" {
		return nil, nil
	}

	root := msg
	var components []FieldPathComponent
	for idx, name := range strings.Split(path, ".") {
		if idx > 0 {
			// Descend into the message type of the previously resolved field.
			parent := components[idx-1].Target
			kind := parent.GetType()
			if kind != descriptor.FieldDescriptorProto_TYPE_MESSAGE && kind != descriptor.FieldDescriptorProto_TYPE_GROUP {
				return nil, fmt.Errorf("not an aggregate type: %s in %s", parent.GetName(), path)
			}
			next, err := r.LookupMsg(msg.FQMN(), parent.GetTypeName())
			if err != nil {
				return nil, err
			}
			msg = next
		}

		glog.V(2).Infof("Lookup %s in %s", name, msg.FQMN())
		field := lookupField(msg, name)
		if field == nil {
			return nil, fmt.Errorf("no field %q found in %s", path, root.GetName())
		}
		if field.GetLabel() == descriptor.FieldDescriptorProto_LABEL_REPEATED {
			return nil, fmt.Errorf("repeated field not allowed in field path: %s in %s", field.GetName(), path)
		}
		components = append(components, FieldPathComponent{Name: name, Target: field})
	}
	return components, nil
}