diff --git a/tools/build_protobuf b/tools/build_protobuf old mode 100755 new mode 100644 index 867d48ac4..1a0a79f4c --- a/tools/build_protobuf +++ b/tools/build_protobuf @@ -1,32 +1,32 @@ #!/bin/bash -CURDIR=$(pwd) -PB2DIR=$CURDIR/pb2 -OUTDIR=../trezorlib/messages -INDEX=$OUTDIR/__init__.py - -rm -f $OUTDIR/[A-Z]*.py -mkdir -p $OUTDIR -mkdir -p $PB2DIR -touch $PB2DIR/__init__.py - -rm -f $INDEX -echo '# Automatically generated by pb2py' >> $INDEX -echo 'from __future__ import absolute_import' >> $INDEX -echo '' >> $INDEX - -for i in types messages ; do +set -e + +cd "$(dirname "$0")" + +GENPATH="../trezorlib/messages" +INDEX="$GENPATH/__init__.py" +PROTO_PATH="../../trezor-common/protob" +PROTO_FILES="types messages" +PB2_OUT="pb2" + +rm -f "$GENPATH"/[A-Z]*.py +mkdir -p "$GENPATH" + +cat > "$INDEX" << EOF +# Automatically generated by pb2py + +EOF + +mkdir -p "$PB2_OUT" + +for file in $PROTO_FILES; do # Compile .proto files to python2 modules using google protobuf library - cd $CURDIR/../../trezor-common/protob - protoc --python_out=$PB2DIR -I/usr/include -I. $i.proto + protoc --python_out="$PB2_OUT" -I"$PROTO_PATH" "$file.proto" done -# hack to make output python 3 compatible -sed -i 's/^import types_pb2/from . 
import types_pb2/g' $CURDIR/pb2/messages_pb2.py - -for i in types messages ; do +for file in $PROTO_FILES; do # Convert google protobuf library to trezor's internal format - cd $CURDIR - ./pb2py -p $CURDIR -l $INDEX $i $OUTDIR + ./pb2py -P "trezorlib.protobuf" -p "$PB2_OUT" -l "$INDEX" "$file" "$GENPATH" done -rm -rf $PB2DIR +rm -rf "$PB2_OUT" diff --git a/tools/pb2py b/tools/pb2py index 50d7cebc7..4b8685341 100755 --- a/tools/pb2py +++ b/tools/pb2py @@ -3,72 +3,78 @@ # to plain-python objects as used in TREZOR Core and python-trezor import sys +import importlib import os import argparse -from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper +def import_pb2(name): + return importlib.import_module("%s_pb2" % name) -def process_type(t, cls, msg_id, indexfile, is_upy): - print(" * type %s" % t) - imports = [] - out = ["", "", "class %s(p.MessageType):" % t, ] - - if cls.DESCRIPTOR.fields_by_name: - out.append(" FIELDS = {") - elif msg_id is None: - out.append(" pass") - - for v in sorted(cls.DESCRIPTOR.fields_by_name.values(), key=lambda x: x.number): - number = v.number - fieldname = v.name - type = None - repeated = v.label == 3 - required = v.label == 2 +def create_message_import(name): + return "from .%s import %s" % (name, name) - # print v.has_default_value, v.default_value - if v.type in (4, 13, 14): - # TYPE_UINT64 = 4 - # TYPE_UINT32 = 13 - # TYPE_ENUM = 14 - type = 'p.UVarintType' +def create_const(name, value, is_upy): + if is_upy: + return "%s = const(%s)" % (name, value) + else: + return "%s = %s" % (name, value) - elif v.type in (17,): - # TYPE_SINT32 = 17 - type = 'p.Sint32Type' - elif v.type in (18,): - # TYPE_SINT64 = 18 - type = 'p.Sint64Type' +def remove_from_start(s, prefix): + if s.startswith(prefix): + return s[len(prefix):] + else: + return s - elif v.type == 9: - # TYPE_STRING = 9 - type = 'p.UnicodeType' - elif v.type == 8: - # TYPE_BOOL = 8 - type = 'p.BoolType' +def process_message(descriptor, protobuf_module, msg_id, 
indexfile, is_upy): + print(" * type %s" % descriptor.name) - elif v.type == 12: - # TYPE_BYTES = 12 - type = 'p.BytesType' + imports = [] + out = ["", "", "class %s(p.MessageType):" % descriptor.name, ] - elif v.type == 11: - # TYPE_MESSAGE = 1 - type = v.message_type.name - imports.append("from .%s import %s" % - (v.message_type.name, v.message_type.name)) + if descriptor.fields_by_number: + out.append(" FIELDS = {") + elif msg_id is None: + out.append(" pass") + for number, field in descriptor.fields_by_number.items(): + field_name = field.name + field_type = None + repeated = (field.label == field.LABEL_REPEATED) + required = (field.label == field.LABEL_REQUIRED) + + types = { + field.TYPE_UINT64: 'p.UVarintType', + field.TYPE_UINT32: 'p.UVarintType', + field.TYPE_ENUM: 'p.UVarintType', + field.TYPE_SINT32: 'p.Sint32Type', + field.TYPE_SINT64: 'p.Sint64Type', + field.TYPE_STRING: 'p.UnicodeType', + field.TYPE_BOOL: 'p.BoolType', + field.TYPE_BYTES: 'p.BytesType' + } + + if field.type == field.TYPE_MESSAGE: + field_type = field.message_type.name + imports.append(create_message_import(field_type)) else: - raise Exception("Unknown field type %s for field %s" % - (v.type, fieldname)) + try: + field_type = types[field.type] + except KeyError: + raise ValueError("Unknown field type %d for field %s" % (field.type, field_name)) + comments = [] if required: - comment = ' # required' - elif v.has_default_value: - comment = ' # default=%s' % repr(v.default_value) + comments.append('required') + if field.has_default_value: + comments.append("default=%s" % repr(field.default_value)) + + if comments: + comment = " # %s" % ' '.join(comments) else: comment = '' @@ -78,27 +84,15 @@ def process_type(t, cls, msg_id, indexfile, is_upy): flags = '0' out.append(" %d: ('%s', %s, %s),%s" % - (number, fieldname, type, flags, comment)) - - # print fieldname, number, type, repeated, comment - # print v.__dict__ - # print v.CPPTYPE_STRING - # print v.LABEL_REPEATED - # print v.enum_type - 
# v.has_default_value, v.default_value - # v.label == 3 # repeated - # print v.number - - if cls.DESCRIPTOR.fields_by_name: + (number, field_name, field_type, flags, comment)) + + if descriptor.fields_by_name: out.append(" }") if msg_id is not None: out.append(" MESSAGE_WIRE_TYPE = %d" % msg_id) if indexfile is not None: - if is_upy: - indexfile.write("%s = const(%d)\n" % (t, msg_id)) - else: - indexfile.write("%s = %d\n" % (t, msg_id)) + indexfile.write(create_const(descriptor.name, msg_id, is_upy) + "\n") # Remove duplicate imports imports = sorted(list(set(imports))) @@ -106,73 +100,60 @@ def process_type(t, cls, msg_id, indexfile, is_upy): if is_upy: imports = ['import protobuf as p'] + imports else: - imports = ['from __future__ import absolute_import', - 'from .. import protobuf as p'] + imports + imports = ['from .. import protobuf as p'] + imports return imports + out -def process_enum(t, cls, is_upy): +def process_enum(descriptor, is_upy): out = [] if is_upy: out += ("from micropython import const", "") - print(" * enum %s" % t) + print(" * enum %s" % descriptor.name) - for k, v in cls.items(): + for name, value in descriptor.values_by_name.items(): # Remove type name from the beginning of the constant # For example "PinMatrixRequestType_Current" -> "Current" - if k.startswith("%s_" % t): - k = k.replace("%s_" % t, '') + enum_prefix = descriptor.name + name = remove_from_start(name, "%s_" % enum_prefix) # If type ends with *Type, but constant use type name without *Type, remove it too :) # For example "ButtonRequestType & ButtonRequest_Other" => "Other" - if t.endswith("Type") and k.startswith("%s_" % t.replace("Type", '')): - k = k.replace("%s_" % t.replace("Type", ''), '') + if enum_prefix.endswith("Type"): + enum_prefix, _ = enum_prefix.rsplit("Type", 1) + name = remove_from_start(name, "%s_" % enum_prefix) - if is_upy: - out.append("%s = const(%s)" % (k, v)) - else: - out.append("%s = %s" % (k, v)) + out.append(create_const(name, value.number, is_upy)) return out -def 
find_msg_type(msg_types, t): - for k, v in msg_types: - msg_name = k.replace('MessageType_', '') - if msg_name == t: - return v - - -def process_module(mod, genpath, indexfile, modlist, is_upy): +def process_file(descriptor, protobuf_module, genpath, indexfile, modlist, is_upy): - print("Processing module %s" % mod.__name__) - types = dict([(name, cls) - for name, cls in mod.__dict__.items() if isinstance(cls, type)]) + print("Processing module %s" % descriptor.name) - msg_types = __import__('pb2', globals(), locals(), [ - 'messages_pb2', ]).messages_pb2.MessageType.items() + msg_types = import_pb2('messages').MessageType - for t, cls in sorted(types.items()): + for name, message_descriptor in descriptor.message_types_by_name.items(): # Find message type for given class - msg_id = find_msg_type(msg_types, t) + try: + msg_id = msg_types.Value("MessageType_%s" % name) + except ValueError: + msg_id = None - out = process_type(t, cls, msg_id, indexfile, is_upy) + out = process_message(message_descriptor, protobuf_module, msg_id, indexfile, is_upy) - write_to_file(genpath, t, out) + write_to_file(genpath, name, out) if modlist: - modlist.write("from .%s import *\n" % t) + modlist.write(create_message_import(name) + "\n") - enums = dict([(name, cls) for name, cls in mod.__dict__.items() - if isinstance(cls, EnumTypeWrapper)]) - - for t, cls in enums.items(): - out = process_enum(t, cls, is_upy) - write_to_file(genpath, t, out) + for name, enum_descriptor in descriptor.enum_types_by_name.items(): + out = process_enum(enum_descriptor, is_upy) + write_to_file(genpath, name, out) if modlist: - modlist.write("from . import %s\n" % t) + modlist.write("from . 
import %s\n" % name) def write_to_file(genpath, t, out): @@ -188,29 +169,19 @@ def write_to_file(genpath, t, out): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('modulename', type=str, help="Name of module to generate") - parser.add_argument('genpath', type=str, help="Directory for generated source code") - parser.add_argument('-i', '--indexfile', type=str, help="[optional] Generate index file of wire types") - parser.add_argument('-l', '--modlist', type=str, help="[optional] Generate list of modules") - parser.add_argument('-p', '--protopath', type=str, help="[optional] Path to search for pregenerated Google's python sources") + parser.add_argument('module', help="Name of module to generate") + parser.add_argument('genpath', help="Directory for generated source code") + parser.add_argument('-P', '--protobuf-module', default="protobuf", help="Name of protobuf module") + parser.add_argument('-i', '--indexfile', type=argparse.FileType('a'), help="Generate index file of wire types") + parser.add_argument('-l', '--modlist', type=argparse.FileType('a'), help="Generate list of modules") + parser.add_argument('-p', '--protopath', type=str, help="Path to search for pregenerated Google's python sources") parser.add_argument('-m', '--micropython', action='store_true', help="Use micropython-favoured source code") args = parser.parse_args() - if args.indexfile: - indexfile = open(args.indexfile, 'a') - else: - indexfile = None - - if args.modlist: - modlist = open(args.modlist, 'a') - else: - modlist = None - if args.protopath: sys.path.append(args.protopath) - # Dynamically load module from argv[1] - tmp = __import__('pb2', globals(), locals(), ['%s_pb2' % args.modulename]) - mod = getattr(tmp, "%s_pb2" % args.modulename) + # This must be done after sys.path.append + module = import_pb2(args.module) - process_module(mod, args.genpath, indexfile, modlist, args.micropython) + process_file(module.DESCRIPTOR, args.protobuf_module, 
args.genpath, args.indexfile, args.modlist, args.micropython)