diff --git a/tools/build_protobuf b/tools/build_protobuf index 778d7e1004..2bc74c0e31 100755 --- a/tools/build_protobuf +++ b/tools/build_protobuf @@ -8,7 +8,6 @@ mkdir -p ../trezorlib/messages INDEX=../trezorlib/messages/__init__.py rm -f $INDEX echo '# Automatically generated by pb2py' >> $INDEX -echo 'from __future__ import absolute_import' >> $INDEX echo '' >> $INDEX for i in types messages storage ; do @@ -24,7 +23,7 @@ sed -i 's/^import types_pb2/from . import types_pb2/g' $CURDIR/pb2/storage_pb2.p for i in types messages storage ; do # Convert google protobuf library to trezor's internal format cd $CURDIR - ./pb2py -p $CURDIR -l $INDEX $i ../trezorlib/messages/ + ./pb2py -P "trezorlib.protobuf" -p $CURDIR -l $INDEX $i ../trezorlib/messages/ done rm -rf $CURDIR/pb2/ diff --git a/tools/pb2py b/tools/pb2py index c897af4712..8df4abcad0 100755 --- a/tools/pb2py +++ b/tools/pb2py @@ -7,69 +7,73 @@ import importlib import os import argparse -from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper - def import_pb2(name): return importlib.import_module("pb2.%s_pb2" % name) -def process_type(t, cls, msg_id, indexfile, is_upy): - print(" * type %s" % t) +def create_message_import(name): + return "from .%s import %s" % (name, name) + + +def create_const(name, value, is_upy): + if is_upy: + return "%s = const(%s)" % (name, value) + else: + return "%s = %s" % (name, value) + + +def remove_from_start(s, prefix): + if s.startswith(prefix): + return s[len(prefix):] + else: + return s + + +def process_message(descriptor, protobuf_module, msg_id, indexfile, is_upy): + print(" * type %s" % descriptor.name) imports = [] - out = ["", "", "class %s(p.MessageType):" % t, ] + out = ["", "", "class %s(p.MessageType):" % descriptor.name, ] - if cls.DESCRIPTOR.fields_by_name: + if descriptor.fields_by_number: out.append(" FIELDS = {") elif msg_id is None: out.append(" pass") - for v in sorted(cls.DESCRIPTOR.fields_by_name.values(), key=lambda x: x.number): - number = v.number - fieldname = v.name - type = None - repeated = v.label == 3 - required = v.label == 2 + for number, field in descriptor.fields_by_number.items(): + field_name = field.name + field_type = None + repeated = (field.label == field.LABEL_REPEATED) + required = (field.label == field.LABEL_REQUIRED) - # print v.has_default_value, v.default_value - - if v.type in (4, 13, 14): - # TYPE_UINT64 = 4 - # TYPE_UINT32 = 13 - # TYPE_ENUM = 14 - type = 'p.UVarintType' - - elif v.type in (17,): - # TYPE_SINT32 = 17 - type = 'p.Sint32Type' - - elif v.type == 9: - # TYPE_STRING = 9 - type = 'p.UnicodeType' - - elif v.type == 8: - # TYPE_BOOL = 8 - type = 'p.BoolType' - - elif v.type == 12: - # TYPE_BYTES = 12 - type = 'p.BytesType' - - elif v.type == 11: - # TYPE_MESSAGE = 1 - type = v.message_type.name - imports.append("from .%s import %s" % - (v.message_type.name, v.message_type.name)) + types = { + field.TYPE_UINT64: 'p.UVarintType', + field.TYPE_UINT32: 'p.UVarintType', + field.TYPE_ENUM: 'p.UVarintType', + field.TYPE_SINT32: 'p.Sint32Type', + field.TYPE_STRING: 'p.UnicodeType', + field.TYPE_BOOL: 'p.BoolType', + field.TYPE_BYTES: 'p.BytesType' + } + if field.type == field.TYPE_MESSAGE: + field_type = field.message_type.name + imports.append(create_message_import(field_type)) else: - raise Exception("Unknown field type %s for field %s" % - (v.type, fieldname)) + try: + field_type = types[field.type] + except KeyError: + raise ValueError("Unknown field type %d for field %s" % (field.type, field_name)) + comments = [] if required: - comment = ' # required' - elif v.has_default_value: - comment = ' # default=%s' % repr(v.default_value) + comments.append('required') + if field.has_default_value: + comments.append("default=%s" % repr(field.default_value)) + + if comments: + comment = " # %s" % ' '.join(comments) else: comment = '' @@ -79,101 +83,71 @@ def process_type(t, cls, msg_id, indexfile, is_upy): flags = '0' out.append(" %d: ('%s', %s, %s),%s" % - (number, fieldname, type, flags, comment)) + (number, field_name, field_type, flags, comment)) - # print fieldname, number, type, repeated, comment - # print v.__dict__ - # print v.CPPTYPE_STRING - # print v.LABEL_REPEATED - # print v.enum_type - # v.has_default_value, v.default_value - # v.label == 3 # repeated - # print v.number - - if cls.DESCRIPTOR.fields_by_name: + if descriptor.fields_by_name: out.append(" }") if msg_id is not None: out.append(" MESSAGE_WIRE_TYPE = %d" % msg_id) if indexfile is not None: - if is_upy: - indexfile.write("%s = const(%d)\n" % (t, msg_id)) - else: - indexfile.write("%s = %d\n" % (t, msg_id)) + indexfile.write(create_const(t, msg_id, is_upy)) # Remove duplicate imports - imports = list(set(imports)) - - if is_upy: - imports = ['import protobuf as p'] + imports - else: - imports = ['from __future__ import absolute_import', - 'from .. import protobuf as p'] + imports - + imports = ["import %s as p" % protobuf_module, *set(imports)] return imports + out -def process_enum(t, cls, is_upy): +def process_enum(descriptor, is_upy): out = [] if is_upy: out += ("from micropython import const", "") - print(" * enum %s" % t) + print(" * enum %s" % descriptor.name) - for k, v in cls.items(): + for name, value in descriptor.values_by_name.items(): # Remove type name from the beginning of the constant # For example "PinMatrixRequestType_Current" -> "Current" - if k.startswith("%s_" % t): - k = k.replace("%s_" % t, '') + enum_prefix = descriptor.name + name = remove_from_start(name, "%s_" % enum_prefix) # If type ends with *Type, but constant use type name without *Type, remove it too :) # For example "ButtonRequestType & ButtonRequest_Other" => "Other" - if t.endswith("Type") and k.startswith("%s_" % t.replace("Type", '')): - k = k.replace("%s_" % t.replace("Type", ''), '') + if enum_prefix.endswith("Type"): + enum_prefix, _ = enum_prefix.rsplit("Type", 1) + name = remove_from_start(name, "%s_" % enum_prefix) - if is_upy: - out.append("%s = const(%s)" % (k, v)) - else: - out.append("%s = %s" % (k, v)) + out.append(create_const(name, value.number, is_upy)) return out -def find_msg_type(msg_types, t): - for k, v in msg_types: - msg_name = k.replace('MessageType_', '') - if msg_name == t: - return v +def process_file(descriptor, protobuf_module, genpath, indexfile, modlist, is_upy): + print("Processing module %s" % descriptor.name) -def process_module(mod, genpath, indexfile, modlist, is_upy): + msg_types = import_pb2('messages').MessageType - print("Processing module %s" % mod.__name__) - types = dict([(name, cls) - for name, cls in mod.__dict__.items() if isinstance(cls, type)]) - - msg_types = import_pb2('messages').MessageType.items() - - for t, cls in sorted(types.items()): + for name, message_descriptor in descriptor.message_types_by_name.items(): # Find message type for given class - msg_id = find_msg_type(msg_types, t) + try: + msg_id = msg_types.Value("MessageType_%s" % name) + except ValueError: + msg_id = None - out = process_type(t, cls, msg_id, indexfile, is_upy) + out = process_message(message_descriptor, protobuf_module, msg_id, indexfile, is_upy) - write_to_file(genpath, t, out) + write_to_file(genpath, name, out) if modlist: - modlist.write("from .%s import *\n" % t) + modlist.write(create_message_import(name) + "\n") - enums = dict([(name, cls) for name, cls in mod.__dict__.items() - if isinstance(cls, EnumTypeWrapper)]) - - for t, cls in enums.items(): - out = process_enum(t, cls, is_upy) - write_to_file(genpath, t, out) + for name, enum_descriptor in descriptor.enum_types_by_name.items(): + out = process_enum(enum_descriptor, is_upy) + write_to_file(genpath, name, out) if modlist: - modlist.write("from . import %s\n" % t) + modlist.write("from . import %s\n" % name) def write_to_file(genpath, t, out): @@ -189,8 +163,9 @@ def write_to_file(genpath, t, out): if __name__ == '__main__': parser = argparse.ArgumentParser() - parser.add_argument('modulename', type=import_pb2, help="Name of module to generate") + parser.add_argument('module', type=import_pb2, help="Name of module to generate") parser.add_argument('genpath', type=str, help="Directory for generated source code") + parser.add_argument('-P', '--protobuf-module', default="protobuf", help="Name of protobuf module") parser.add_argument('-i', '--indexfile', type=argparse.FileType('a'), help="Generate index file of wire types") parser.add_argument('-l', '--modlist', type=argparse.FileType('a'), help="Generate list of modules") parser.add_argument('-p', '--protopath', type=str, help="Path to search for pregenerated Google's python sources") @@ -200,4 +175,4 @@ if __name__ == '__main__': if args.protopath: sys.path.append(args.protopath) - process_module(args.modulename, args.genpath, args.indexfile, args.modlist, args.micropython) + process_file(args.module.DESCRIPTOR, args.protobuf_module, args.genpath, args.indexfile, args.modlist, args.micropython)