package runtime import ( "fmt" "io" "net/http" "net/textproto" "github.com/golang/protobuf/proto" "github.com/grpc-ecosystem/grpc-gateway/runtime/internal" "golang.org/x/net/context" "google.golang.org/grpc/codes" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/status" ) // ForwardResponseStream forwards the stream from gRPC server to REST client. func ForwardResponseStream(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, recv func() (proto.Message, error), opts ...func(context.Context, http.ResponseWriter, proto.Message) error) { f, ok := w.(http.Flusher) if !ok { grpclog.Printf("Flush not supported in %T", w) http.Error(w, "unexpected type of web server", http.StatusInternalServerError) return } md, ok := ServerMetadataFromContext(ctx) if !ok { grpclog.Printf("Failed to extract ServerMetadata from context") http.Error(w, "unexpected error", http.StatusInternalServerError) return } handleForwardResponseServerMetadata(w, mux, md) w.Header().Set("Transfer-Encoding", "chunked") w.Header().Set("Content-Type", marshaler.ContentType()) if err := handleForwardResponseOptions(ctx, w, nil, opts); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return } w.WriteHeader(http.StatusOK) f.Flush() for { resp, err := recv() if err == io.EOF { return } if err != nil { handleForwardResponseStreamError(marshaler, w, err) return } if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { handleForwardResponseStreamError(marshaler, w, err) return } buf, err := marshaler.Marshal(streamChunk(resp, nil)) if err != nil { grpclog.Printf("Failed to marshal response chunk: %v", err) return } if _, err = fmt.Fprintf(w, "%s\n", buf); err != nil { grpclog.Printf("Failed to send response chunk: %v", err) return } f.Flush() } } func handleForwardResponseServerMetadata(w http.ResponseWriter, mux *ServeMux, md ServerMetadata) { for k, vs := range md.HeaderMD { hKey := fmt.Sprintf("%s%s", MetadataHeaderPrefix, k) if mux.outgoingHeaderMatcher != nil { if h, ok := mux.outgoingHeaderMatcher(k); ok { hKey = h } } for i := range vs { w.Header().Add(hKey, vs[i]) } } } func handleForwardResponseTrailerHeader(w http.ResponseWriter, md ServerMetadata) { for k := range md.TrailerMD { tKey := textproto.CanonicalMIMEHeaderKey(fmt.Sprintf("%s%s", MetadataTrailerPrefix, k)) w.Header().Add("Trailer", tKey) } } func handleForwardResponseTrailer(w http.ResponseWriter, md ServerMetadata) { for k, vs := range md.TrailerMD { tKey := fmt.Sprintf("%s%s", MetadataTrailerPrefix, k) for i := range vs { w.Header().Add(tKey, vs[i]) } } } // ForwardResponseMessage forwards the message "resp" from gRPC server to REST client. func ForwardResponseMessage(ctx context.Context, mux *ServeMux, marshaler Marshaler, w http.ResponseWriter, req *http.Request, resp proto.Message, opts ...func(context.Context, http.ResponseWriter, proto.Message) error) { md, ok := ServerMetadataFromContext(ctx) if !ok { grpclog.Printf("Failed to extract ServerMetadata from context") } handleForwardResponseServerMetadata(w, mux, md) handleForwardResponseTrailerHeader(w, md) w.Header().Set("Content-Type", marshaler.ContentType()) if err := handleForwardResponseOptions(ctx, w, resp, opts); err != nil { HTTPError(ctx, mux, marshaler, w, req, err) return } buf, err := marshaler.Marshal(resp) if err != nil { grpclog.Printf("Marshal error: %v", err) HTTPError(ctx, mux, marshaler, w, req, err) return } if _, err = w.Write(buf); err != nil { grpclog.Printf("Failed to write response: %v", err) } handleForwardResponseTrailer(w, md) } func handleForwardResponseOptions(ctx context.Context, w http.ResponseWriter, resp proto.Message, opts []func(context.Context, http.ResponseWriter, proto.Message) error) error { if len(opts) == 0 { return nil } for _, opt := range opts { if err := opt(ctx, w, resp); err != nil { grpclog.Printf("Error handling ForwardResponseOptions: %v", err) return err } } return nil } func handleForwardResponseStreamError(marshaler Marshaler, w http.ResponseWriter, err error) { buf, merr := marshaler.Marshal(streamChunk(nil, err)) if merr != nil { grpclog.Printf("Failed to marshal an error: %v", merr) return } if _, werr := fmt.Fprintf(w, "%s\n", buf); werr != nil { grpclog.Printf("Failed to notify error to client: %v", werr) return } } func streamChunk(result proto.Message, err error) map[string]proto.Message { if err != nil { grpcCode := codes.Unknown if s, ok := status.FromError(err); ok { grpcCode = s.Code() } httpCode := HTTPStatusFromCode(grpcCode) return map[string]proto.Message{ "error": &internal.StreamError{ GrpcCode: int32(grpcCode), HttpCode: int32(httpCode), Message: err.Error(), HttpStatus: http.StatusText(httpCode), }, } } if result == nil { return streamChunk(nil, fmt.Errorf("empty response")) } return map[string]proto.Message{"result": result} }