diff --git a/tools/build_protobuf b/tools/build_protobuf index 791ddf63c..d85079e6a 100755 --- a/tools/build_protobuf +++ b/tools/build_protobuf @@ -1,34 +1,72 @@ #!/bin/bash -CURDIR=$(pwd) -PB2DIR=$CURDIR/pb2 -OUTDIR=../src/trezor/messages -INDEX=$OUTDIR/wire_types.py - -set -x - -rm -f $OUTDIR/[A-Z]*.py -mkdir -p $OUTDIR -mkdir -p $PB2DIR -touch $PB2DIR/__init__.py - -rm -f $INDEX -echo '# Automatically generated by pb2py' >> $INDEX -echo 'from micropython import const' >> $INDEX -echo '' >> $INDEX - -for i in types messages ; do - # Compile .proto files to python2 modules using google protobuf library - cd $CURDIR/../vendor/trezor-common/protob - protoc --python_out=$PB2DIR -I/usr/include -I. $i.proto +set -e + +IS_CORE="" + +if [ "$1" == "--core" ]; then + shift + IS_CORE=yes +elif [ "$1" == "--no-core" ]; then + shift +elif echo $PWD | grep -q "trezor-core"; then + IS_CORE=yes +fi + +if [ -n "$1" ]; then + OUTDIR=`readlink -f "$1"` +fi + +cd "$(dirname "$0")" + +# set up paths +INDEX="__init__.py" +GENPATH="${OUTDIR:-../src/trezor/messages}" +PROTO_PATH="../../trezor-common/protob" +PROTO_FILES="types messages" + +# set up temporary directory & cleanup +TMPDIR=$(mktemp -d) +function cleanup { + rm -r $TMPDIR +} +trap cleanup EXIT + +# set up pb2 outdir +PB2_OUT="$TMPDIR/pb2" +mkdir -p "$PB2_OUT" + +# compile .proto files to python2 modules using google protobuf library +for file in $PROTO_FILES; do + protoc --python_out="$PB2_OUT" -I/usr/include -I"$PROTO_PATH" "$PROTO_PATH/$file.proto" done -# hack to make output python 3 compatible -sed -i 's/^import types_pb2/from . import types_pb2/g' $CURDIR/pb2/messages_pb2.py -for i in types messages ; do - # Convert google protobuf library to trezor's internal format - cd $CURDIR - ./pb2py -m -p $CURDIR -i $INDEX $i $OUTDIR +if [ -n "$IS_CORE" ]; then + # generate for micropython + PB2PY_OPTS="-m" +else + # create index (__init__.py) + echo "# Automatically generated by pb2py" > $TMPDIR/$INDEX + echo >> $TMPDIR/$INDEX + PB2PY_OPTS="-l $TMPDIR/$INDEX" +fi + +# convert google protobuf library to trezor's internal format +for file in $PROTO_FILES; do + ./pb2py $PB2PY_OPTS -P "trezorlib.protobuf" -p "$PB2_OUT" "$file" "$TMPDIR" done -rm -rf $PB2DIR +if [ -n "$IS_CORE" ]; then + cp "$TMPDIR/MessageType.py" "$TMPDIR/wire_types.py" +fi + +# ensure $GENPATH exists and is empty of messages +mkdir -p "$GENPATH" +# only remove messages - there could possibly be other files not starting with capital letter +rm -f "$GENPATH"/[A-Z]*.py + +# move generated files to the destination +# (this assumes $INDEX is *.py, otherwise we'd have to add $INDEX separately) +mv "$TMPDIR"/*.py "$GENPATH" + +# the exit trap handles removing the tmp directory diff --git a/tools/pb2py b/tools/pb2py index cd2efd340..c9bedbaf5 100755 --- a/tools/pb2py +++ b/tools/pb2py @@ -2,239 +2,218 @@ # Converts Google's protobuf python definitions of TREZOR wire messages # to plain-python objects as used in TREZOR Core and python-trezor -import sys -import os import argparse +import importlib +import logging +import os +import sys +from collections import namedtuple + +ProtoField = namedtuple('ProtoField', 'name, number, proto_type, py_type, repeated, required, orig') + + +def parse_field(number, field): + FIELD_TYPES = { + field.TYPE_UINT64: ('p.UVarintType', 'int'), + field.TYPE_UINT32: ('p.UVarintType', 'int'), + field.TYPE_ENUM: ('p.UVarintType', 'int'), + field.TYPE_SINT32: ('p.SVarintType', 'int'), + field.TYPE_SINT64: ('p.SVarintType', 'int'), + field.TYPE_STRING: ('p.UnicodeType', 'str'), + field.TYPE_BOOL: ('p.BoolType', 'bool'), + field.TYPE_BYTES: ('p.BytesType', 'bytes'), + } + repeated = (field.label == field.LABEL_REPEATED) + required = (field.label == field.LABEL_REQUIRED) + if field.type == field.TYPE_MESSAGE: + proto_type = py_type = field.message_type.name + else: + try: + proto_type, py_type = FIELD_TYPES[field.type] + except KeyError: + raise ValueError("Unknown field type %d for field %s" % (field.type, field.name)) from None -from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper - - -def process_type(t, cls, msg_id, indexfile, is_upy): - print(" * type %s" % t) - - imports = [] - out = ["", "", "class %s(p.MessageType):" % t, ] - args = [] - assigns = [] - - if cls.DESCRIPTOR.fields_by_name: - out.append(" FIELDS = {") - elif msg_id is None: - out.append(" pass") - - for v in sorted(cls.DESCRIPTOR.fields_by_name.values(), key=lambda x: x.number): - number = v.number - fieldname = v.name - type = None - repeated = v.label == 3 - required = v.label == 2 - - # print v.has_default_value, v.default_value - - if v.type in (4, 13, 14): - # TYPE_UINT64 = 4 - # TYPE_UINT32 = 13 - # TYPE_ENUM = 14 - type = 'p.UVarintType' - pytype = 'int' - - elif v.type in (17,): - # TYPE_SINT32 = 17 - type = 'p.Sint32Type' - pytype = 'int' - - elif v.type in (18,): - # TYPE_SINT64 = 18 - type = 'p.Sint64Type' - pytype = 'int' - - elif v.type == 9: - # TYPE_STRING = 9 - type = 'p.UnicodeType' - pytype = 'str' - - elif v.type == 8: - # TYPE_BOOL = 8 - type = 'p.BoolType' - pytype = 'bool' - - elif v.type == 12: - # TYPE_BYTES = 12 - type = 'p.BytesType' - pytype = 'bytes' - - elif v.type == 11: - # TYPE_MESSAGE = 1 - type = v.message_type.name - pytype = v.message_type.name - imports.append("from .%s import %s" % - (v.message_type.name, v.message_type.name)) + if repeated: + py_type = "List[%s]" % py_type - else: - raise Exception("Unknown field type %s for field %s" % - (v.type, fieldname)) + return ProtoField( + name=field.name, + number=number, + proto_type=proto_type, + py_type=py_type, + repeated=repeated, + required=required, + orig=field, + ) - if required: - comment = ' # required' - elif v.has_default_value: - comment = ' # default=%s' % repr(v.default_value) - else: - comment = '' - if repeated: - flags = 'p.FLAG_REPEATED' - pytype = "list" - value = None - else: - flags = '0' - value = None +def import_pb2(name): + return importlib.import_module("%s_pb2" % name) - out.append(" %d: ('%s', %s, %s),%s" % - (number, fieldname, type, flags, comment)) - args.append(" %s: %s = %s," % (fieldname, pytype, value)) +def create_message_import(name): + return "from .%s import %s" % (name, name) - if repeated: - assigns.append(" self.%s = [] if %s is None else %s" % (fieldname, fieldname, fieldname)) - else: - assigns.append(" self.%s = %s" % (fieldname, fieldname)) - # print fieldname, number, type, repeated, comment - # print v.__dict__ - # print v.CPPTYPE_STRING - # print v.LABEL_REPEATED - # print v.enum_type - # v.has_default_value, v.default_value - # v.label == 3 # repeated - # print v.number +def create_const(name, value, is_upy): + if is_upy: + return "%s = const(%s)" % (name, value) + else: + return "%s = %s" % (name, value) - if cls.DESCRIPTOR.fields_by_name: - out.append(" }") - if msg_id is not None: - out.append(" MESSAGE_WIRE_TYPE = %d" % msg_id) - if indexfile is not None: - if is_upy: - indexfile.write("%s = const(%d)\n" % (t, msg_id)) - else: - indexfile.write("%s = %d\n" % (t, msg_id)) +def remove_from_start(s, prefix): + if s.startswith(prefix): + return s[len(prefix):] + else: + return s + + +def process_message_imports(descriptor): + imports = set() + + for field in descriptor.fields: + if field.type == field.TYPE_MESSAGE: + imports.add(field.message_type.name) + + for name in sorted(imports): + yield create_message_import(name) + + +def create_init_method(fields): + yield " def __init__(" + yield " self," + for field in fields: + yield " %s: Optional[%s] = None," % (field.name, field.py_type) + yield " **kwargs" + yield " ) -> None:" + for field in fields: + if field.repeated: + yield " self.{0} = {0} if {0} is not None else []".format(field.name) + else: + yield " self.{0} = {0}".format(field.name) + yield " super().__init__(**kwargs)" - # Remove duplicate imports - imports = sorted(list(set(imports))) + +def process_message(descriptor, protobuf_module, msg_id, is_upy): + logging.debug("Processing message %s", descriptor.name) if is_upy: - imports = ['import protobuf as p'] + imports + yield "import protobuf as p" else: - imports = ['from __future__ import absolute_import', - 'from .. import protobuf as p'] + imports + yield "from .. import protobuf as p" + yield "from typing import List, Optional" + + yield from process_message_imports(descriptor) + + yield "" + yield "" + yield "class %s(p.MessageType):" % descriptor.name + + fields = list(parse_field(number, field) + for number, field + in descriptor.fields_by_number.items()) + + if fields: + yield " FIELDS = {" + for field in fields: + comments = [] + if field.required: + comments.append('required') + if field.orig.has_default_value: + comments.append("default=%s" % repr(field.orig.default_value)) + + if comments: + comment = " # %s" % ' '.join(comments) + else: + comment = '' - args.append(" **kwargs,") - assigns.append(" p.MessageType.__init__(self, **kwargs)") + if field.repeated: + flags = 'p.FLAG_REPEATED' + else: + flags = '0' - init = ["", " def __init__(", " self,"] + args + [" ):"] + assigns + yield " %d: ('%s', %s, %s),%s" % (field.number, field.name, field.proto_type, flags, comment) - return imports + out + init + yield " }" + if msg_id is not None: + yield " MESSAGE_WIRE_TYPE = %d" % msg_id -def process_enum(t, cls, is_upy): - out = [] + yield "" + yield from create_init_method(fields) - if is_upy: - out += ("from micropython import const", "") - print(" * enum %s" % t) +def process_enum(descriptor, is_upy): + logging.debug("Processing enum %s", descriptor.name) - for k, v in cls.items(): + if is_upy: + yield "from micropython import const" + yield "" + + for name, value in descriptor.values_by_name.items(): # Remove type name from the beginning of the constant # For example "PinMatrixRequestType_Current" -> "Current" - if k.startswith("%s_" % t): - k = k.replace("%s_" % t, '') + enum_prefix = descriptor.name + name = remove_from_start(name, "%s_" % enum_prefix) # If type ends with *Type, but constant use type name without *Type, remove it too :) # For example "ButtonRequestType & ButtonRequest_Other" => "Other" - if t.endswith("Type") and k.startswith("%s_" % t.replace("Type", '')): - k = k.replace("%s_" % t.replace("Type", ''), '') - - if is_upy: - out.append("%s = const(%s)" % (k, v)) - else: - out.append("%s = %s" % (k, v)) - - return out - + if enum_prefix.endswith("Type"): + enum_prefix, _ = enum_prefix.rsplit("Type", 1) + name = remove_from_start(name, "%s_" % enum_prefix) -def find_msg_type(msg_types, t): - for k, v in msg_types: - msg_name = k.replace('MessageType_', '') - if msg_name == t: - return v + yield create_const(name, value.number, is_upy) -def process_module(mod, genpath, indexfile, modlist, is_upy): +def process_file(descriptor, protobuf_module, genpath, modlist, is_upy): + logging.info("Processing module %s", descriptor.name) - print("Processing module %s" % mod.__name__) - types = dict([(name, cls) - for name, cls in mod.__dict__.items() if isinstance(cls, type)]) + msg_types = import_pb2('messages').MessageType - msg_types = __import__('pb2', globals(), locals(), [ - 'messages_pb2', ]).messages_pb2.MessageType.items() - - for t, cls in sorted(types.items()): + for name, message_descriptor in sorted(descriptor.message_types_by_name.items()): # Find message type for given class - msg_id = find_msg_type(msg_types, t) - - out = process_type(t, cls, msg_id, indexfile, is_upy) + try: + msg_id = msg_types.Value("MessageType_%s" % name) + except ValueError: + msg_id = None - write_to_file(genpath, t, out) + out = process_message(message_descriptor, protobuf_module, msg_id, is_upy) + write_to_file(genpath, name, out) if modlist: - modlist.write("from .%s import *\n" % t) - - enums = dict([(name, cls) for name, cls in mod.__dict__.items() - if isinstance(cls, EnumTypeWrapper)]) + modlist.write(create_message_import(name) + "\n") - for t, cls in enums.items(): - out = process_enum(t, cls, is_upy) - write_to_file(genpath, t, out) + for name, enum_descriptor in descriptor.enum_types_by_name.items(): + out = process_enum(enum_descriptor, is_upy) + write_to_file(genpath, name, out) if modlist: - modlist.write("from . import %s\n" % t) + modlist.write("from . import %s\n" % name) def write_to_file(genpath, t, out): # Write generated sourcecode to given file - f = open(os.path.join(genpath, "%s.py" % t), 'w') - out = ["# Automatically generated by pb2py"] + out - - data = "\n".join(out) + "\n" - - f.write(data) - f.close() + with open(os.path.join(genpath, "%s.py" % t), 'w') as f: + f.write("# Automatically generated by pb2py\n") + for line in out: + f.write(line + "\n") if __name__ == '__main__': + logging.basicConfig(level=logging.DEBUG) + parser = argparse.ArgumentParser() - parser.add_argument('modulename', type=str, help="Name of module to generate") - parser.add_argument('genpath', type=str, help="Directory for generated source code") - parser.add_argument('-i', '--indexfile', type=str, help="[optional] Generate index file of wire types") - parser.add_argument('-l', '--modlist', type=str, help="[optional] Generate list of modules") - parser.add_argument('-p', '--protopath', type=str, help="[optional] Path to search for pregenerated Google's python sources") + parser.add_argument('module', help="Name of module to generate") + parser.add_argument('genpath', help="Directory for generated source code") + parser.add_argument('-P', '--protobuf-module', default="protobuf", help="Name of protobuf module") + parser.add_argument('-l', '--modlist', type=argparse.FileType('a'), help="Generate list of modules") + parser.add_argument('-p', '--protopath', type=str, help="Path to search for pregenerated Google's python sources") parser.add_argument('-m', '--micropython', action='store_true', help="Use micropython-favoured source code") args = parser.parse_args() - if args.indexfile: - indexfile = open(args.indexfile, 'a') - else: - indexfile = None - - if args.modlist: - modlist = open(args.modlist, 'a') - else: - modlist = None - if args.protopath: sys.path.append(args.protopath) - # Dynamically load module from argv[1] - tmp = __import__('pb2', globals(), locals(), ['%s_pb2' % args.modulename]) - mod = getattr(tmp, "%s_pb2" % args.modulename) + # This must be done after sys.path.append + module = import_pb2(args.module) - process_module(mod, args.genpath, indexfile, modlist, args.micropython) + process_file(module.DESCRIPTOR, args.protobuf_module, args.genpath, args.modlist, args.micropython)