diff --git a/common/protob/pb2py b/common/protob/pb2py index 56899f8dc..da2594d0b 100755 --- a/common/protob/pb2py +++ b/common/protob/pb2py @@ -2,36 +2,68 @@ # Converts Google's protobuf python definitions of Trezor wire messages # to plain-python objects as used in Trezor Core and python-trezor -import argparse -import importlib +import itertools import logging import os +import re import shutil import subprocess import sys import tempfile -from collections import namedtuple, defaultdict +from pathlib import Path + +from typing import List, Optional import attr +import click import construct as c +import mako +import mako.template from google.protobuf import descriptor_pb2 +FieldDescriptor = descriptor_pb2.FieldDescriptorProto -AUTO_HEADER = "# Automatically generated by pb2py\n" +FIELD_TYPES_PYTHON = { + FieldDescriptor.TYPE_UINT64: "int", + FieldDescriptor.TYPE_UINT32: "int", + FieldDescriptor.TYPE_SINT64: "int", + FieldDescriptor.TYPE_SINT32: "int", + FieldDescriptor.TYPE_BOOL: "bool", + FieldDescriptor.TYPE_BYTES: "bytes", + FieldDescriptor.TYPE_STRING: "str", +} -# fmt: off -FIELD_TYPES = { - descriptor_pb2.FieldDescriptorProto.TYPE_UINT64: ('p.UVarintType', 'int'), - descriptor_pb2.FieldDescriptorProto.TYPE_UINT32: ('p.UVarintType', 'int'), -# descriptor_pb2.FieldDescriptorProto.TYPE_ENUM: ('p.UVarintType', 'int'), - descriptor_pb2.FieldDescriptorProto.TYPE_SINT32: ('p.SVarintType', 'int'), - descriptor_pb2.FieldDescriptorProto.TYPE_SINT64: ('p.SVarintType', 'int'), - descriptor_pb2.FieldDescriptorProto.TYPE_STRING: ('p.UnicodeType', 'str'), - descriptor_pb2.FieldDescriptorProto.TYPE_BOOL: ('p.BoolType', 'bool'), - descriptor_pb2.FieldDescriptorProto.TYPE_BYTES: ('p.BytesType', 'bytes'), +TYPE_NAMES = { + FieldDescriptor.TYPE_UINT64: "uint64", + FieldDescriptor.TYPE_UINT32: "uint32", + FieldDescriptor.TYPE_SINT64: "sint64", + FieldDescriptor.TYPE_SINT32: "sint32", + FieldDescriptor.TYPE_BOOL: "bool", + FieldDescriptor.TYPE_BYTES: "bytes", + FieldDescriptor.TYPE_STRING: "string", } -# fmt: on + +FIELD_TYPES_RUST_BLOB = { + FieldDescriptor.TYPE_UINT64: 0, + FieldDescriptor.TYPE_UINT32: 0, + FieldDescriptor.TYPE_SINT64: 1, + FieldDescriptor.TYPE_SINT32: 1, + FieldDescriptor.TYPE_BOOL: 2, + FieldDescriptor.TYPE_BYTES: 3, + FieldDescriptor.TYPE_STRING: 4, + FieldDescriptor.TYPE_ENUM: 5, + FieldDescriptor.TYPE_MESSAGE: 6, +} + +INT_TYPES = ( + FieldDescriptor.TYPE_UINT64, + FieldDescriptor.TYPE_UINT32, + FieldDescriptor.TYPE_SINT64, + FieldDescriptor.TYPE_SINT32, +) + +MESSAGE_TYPE_ENUM = "MessageType" ListOfSimpleValues = c.GreedyRange( c.Struct( @@ -54,169 +86,350 @@ if not PROTOC: print("protoc command not found") sys.exit(1) -PROTOC_PREFIX = os.path.dirname(os.path.dirname(PROTOC)) -PROTOC_INCLUDE = os.path.join(PROTOC_PREFIX, "include") +PROTOC_PREFIX = Path(PROTOC).resolve().parent.parent + + +ENUM_ENTRY = c.PrefixedArray(c.Byte, c.Int16ul) + +FIELD_STRUCT = c.Struct( + "tag" / c.Byte, + "flags_and_type" + / c.BitStruct( + "is_required" / c.Flag, + "is_repeated" / c.Flag, + "is_experimental" / c.Flag, + c.Padding(1), + "type" / c.BitsInteger(4), + ), + "enum_or_msg_offset" / c.Int16ul, + "name" / c.Int16ul, +) + +MSG_ENTRY = c.Struct( + "fields_count" / c.Rebuild(c.Byte, c.len_(c.this.fields)), + "defaults_size" / c.Rebuild(c.Byte, c.len_(c.this.defaults)), + # highest bit = is_experimental + # the rest = wire_id, 0x7FFF iff unset + "flags_and_wire_type" / c.Int16ul, + "fields" / c.Array(c.this.fields_count, FIELD_STRUCT), + "defaults" / c.Bytes(c.this.defaults_size), +) + +DEFAULT_VARINT_ENTRY = c.Sequence(c.Byte, c.VarInt) +DEFAULT_LENGTH_ENTRY = c.Sequence(c.Byte, c.Prefixed(c.VarInt, c.GreedyRange(c.Byte))) + +NAME_ENTRY = c.Sequence( + "msg_name" / c.Int16ul, + "msg_offset" / c.Int16ul, +) + +WIRETYPE_ENTRY = c.Sequence( + "wire_id" / c.Int16ul, + "msg_offset" / c.Int16ul, +) + +# QDEF(MP_QSTR_copysign, (const byte*)"\x33\x14\x08" "copysign") +QDEF_RE = re.compile( + r'^QDEF\(MP_QSTR(\S+), \(const byte\*\)"(\\x..\\x..\\x..)" "(.*)"\)$' +) -@attr.s +@attr.s(auto_attribs=True) class ProtoField: - name = attr.ib() - number = attr.ib() - orig = attr.ib() - repeated = attr.ib() - required = attr.ib() - experimental = attr.ib() - type_name = attr.ib() - proto_type = attr.ib() - py_type = attr.ib() - default_value = attr.ib() + name: str + number: int + type: object + extensions: dict + orig: object + # name of type without package path + type_name: str + + descriptor: "Descriptor" + + @property + def repeated(self): + return self.orig.label == FieldDescriptor.LABEL_REPEATED + + @property + def required(self): + return self.orig.label == FieldDescriptor.LABEL_REQUIRED @property def optional(self): return not self.required and not self.repeated + @property + def experimental(self): + return bool(self.extensions.get("experimental")) + + @property + def is_message(self): + return self.type == FieldDescriptor.TYPE_MESSAGE + + @property + def is_enum(self): + return self.type == FieldDescriptor.TYPE_ENUM + + @property + def python_type(self): + return FIELD_TYPES_PYTHON.get(self.type, self.type_name) + + @property + def default_value(self): + if not self.orig.HasField("default_value"): + return None + return self.orig.default_value + + @property + def default_value_repr(self): + if self.default_value is None: + return "None" + + elif self.is_enum: + selected_enum_value = strip_enum_prefix(self.type_name, self.default_value) + return f"{self.type_name}.{selected_enum_value}" + + elif self.type == FieldDescriptor.TYPE_STRING: + return repr(self.default_value) + elif self.type == FieldDescriptor.TYPE_BYTES: + return "b" + repr(self.default_value) + elif self.type == FieldDescriptor.TYPE_BOOL: + return "True" if self.default_value == "true" else "False" + else: + return str(self.default_value) + + @property + def type_object(self): + if self.is_enum: + return find_by_name(self.descriptor.enums, self.type_name) + if self.is_message: + return find_by_name(self.descriptor.messages, self.type_name) + return None + @classmethod def from_field(cls, descriptor, field): - repeated = field.label == field.LABEL_REPEATED - required = field.label == field.LABEL_REQUIRED - experimental = bool(descriptor._get_extension(field, "experimental")) - # ignore package path - type_name = field.type_name.rsplit(".")[-1] - - if field.type == field.TYPE_MESSAGE: - proto_type = py_type = type_name - elif field.type == field.TYPE_ENUM: - value_dict = descriptor.enum_types[type_name] - valuestr = ", ".join(str(v) for v in value_dict.values()) - proto_type = 'p.EnumType("{}", ({},))'.format(type_name, valuestr) - py_type = "EnumType" + type_name + if not field.type_name: + type_name = TYPE_NAMES[field.type] else: - try: - proto_type, py_type = FIELD_TYPES[field.type] - except KeyError: - raise ValueError( - "Unknown field type {} for field {}".format(field.type, field.name) - ) from None - - if not field.HasField("default_value"): - default_value = None - elif field.type == field.TYPE_ENUM: - default_value = str(descriptor.enum_types[type_name][field.default_value]) - elif field.type == field.TYPE_STRING: - default_value = f'"{field.default_value}"' - elif field.type == field.TYPE_BYTES: - default_value = f'b"{field.default_value}"' - elif field.type == field.TYPE_BOOL: - default_value = "True" if field.default_value == "true" else "False" - else: - default_value = field.default_value - + type_name = field.type_name.rsplit(".")[-1] return cls( name=field.name, number=field.number, + type=field.type, orig=field, - repeated=repeated, - required=required, - experimental=experimental, + extensions=descriptor.get_extensions(field), type_name=type_name, - proto_type=proto_type, - py_type=py_type, - default_value=default_value, + descriptor=descriptor, + ) + + +@attr.s(auto_attribs=True) +class ProtoMessage: + name: str + wire_type: Optional[int] + orig: object + extensions: dict + + fields: List[ProtoField] + + @classmethod + def from_message(cls, descriptor: "Descriptor", message): + message_type = find_by_name(descriptor.message_type_enum.value, message.name) + # use extensions set on the message_type entry (if any) + extensions = descriptor.get_extensions(message_type) + # override with extensions set on the message itself + extensions.update(descriptor.get_extensions(message)) + + if "wire_type" in extensions: + wire_type = extensions["wire_type"] + elif message_type is not None: + wire_type = message_type.number + else: + wire_type = None + + return cls( + name=message.name, + wire_type=wire_type, + orig=message, + extensions=extensions, + fields=[ + ProtoField.from_field(descriptor, f) + for f in descriptor._filter_items(message.field) + ], ) -def protoc(files, additional_includes=()): +def protoc(files): """Compile code with protoc and return the data.""" include_dirs = set() - include_dirs.add(PROTOC_INCLUDE) - include_dirs.update(additional_includes) + include_dirs.add(str(PROTOC_PREFIX / "include")) + if "PROTOC_INCLUDE" in os.environ: + include_dirs.add(os.environ["PROTOC_INCLUDE"]) for file in files: - dirname = os.path.dirname(file) or "." - include_dirs.add(dirname) + include_dirs.add(os.path.dirname(file) or ".") protoc_includes = ["-I" + dir for dir in include_dirs if dir] - # Note that we could avoid creating temp files if protoc let us write to stdout - # directly. this is currently only possible on Unix, by passing /dev/stdout as - # the file name. Since there's no direct Windows equivalent, not counting - # being creative with named pipes, special-casing this is not worth the effort. - with tempfile.TemporaryDirectory() as tmpdir: - outfile = os.path.join(tmpdir, "DESCRIPTOR_SET") - subprocess.check_call( - [PROTOC, "--descriptor_set_out={}".format(outfile)] - + protoc_includes - + files - ) - with open(outfile, "rb") as f: - return f.read() + return subprocess.check_output( + [PROTOC, "--descriptor_set_out=/dev/stdout"] + protoc_includes + list(files) + ) -def strip_leader(s, prefix): - """Remove given prefix from underscored name.""" - leader = prefix + "_" - if s.startswith(leader): - return s[len(leader) :] - else: - return s +def strip_enum_prefix(enum_name, value_name): + """Generate stripped-down enum value name, given the enum type name. + There are three kinds of enums in the codebase: -def import_statement_from_path(path): - # separate leading dots - dot_prefix = "" - while path.startswith("."): - dot_prefix += "." - path = path[1:] + (1) New-style: - # split on remaining dots - split_path = path.rsplit(".", maxsplit=1) - leader, import_name = split_path[:-1], split_path[-1] + enum SomeEnum { + First_Value = 1; + SecondValue = 2; + } - if leader: - from_part = dot_prefix + leader - elif dot_prefix: - from_part = dot_prefix - else: - from_part = "" + (2) Old-style without "Type": - if from_part: - return "from {} import {}".format(from_part, import_name) - else: - return "import {}".format(import_name) + enum SomeEnum { + SomeEnum_First_Value = 1; + SomeEnum_SecondValue = 2; + } + + (3) Old-style with "Type": + + enum SomeEnumType { + SomeEnum_First_Value = 1; + SomeEnum_SecondValue = 2; + } + + This function accepts the name of the enum ("SomeEnum") and the name of the value, + and returns the name of the value as it would look in the new-style -- i.e., + for any variation of the above, the values returned would be "First_Value" and + "SecondValue". + """ + leader = enum_name + "_" + if value_name.startswith(leader): + return value_name[len(leader) :] + + if enum_name.endswith("Type"): + leader = enum_name[: -len("Type")] + "_" + if value_name.startswith(leader): + return value_name[len(leader) :] + + return value_name + + +def find_by_name(haystack, name, default=None): + return next((item for item in haystack if item.name == name), default) class Descriptor: - def __init__(self, data, message_type="MessageType", import_path="protobuf"): + def __init__(self, data, include_deprecated: bool, bitcoin_only: bool): self.descriptor = descriptor_pb2.FileDescriptorSet() self.descriptor.ParseFromString(data) + self.include_deprecated = include_deprecated + self.bitcoin_only = bitcoin_only self.files = self.descriptor.file - logging.debug("found {} files".format(len(self.files))) + logging.debug(f"found {len(self.files)} files") + + # collect extensions across all files + # this is required for self._get_extension() to work + self.extensions = { + ext.name: ext.number for file in self.files for ext in file.extension + } + + # find message_type enum + top_level_enums = itertools.chain.from_iterable(f.enum_type for f in self.files) + self.message_type_enum = find_by_name(top_level_enums, MESSAGE_TYPE_ENUM, ()) + self.convert_enum_value_names(self.message_type_enum) + + # top-level message inclusion filter that takes bitcoin_only into account + def should_include_message(message: ProtoMessage): + return ( + # include all messages when not in bitcoin_only mode + not self.bitcoin_only + # include all non-wire messages + or message.wire_type is None + # include messages that are marked bitcoin_only + or message.extensions.get("bitcoin_only") + ) # find messages and enums self.messages = [] self.enums = [] - self.enum_types = defaultdict(dict) - - self.extensions = {} for file in self.files: - self.messages += file.message_type - self.enums += file.enum_type - for message in file.message_type: - self._nested_types_from_message(message) - for extension in file.extension: - self.extensions[extension.name] = extension.number + messages = ( + ProtoMessage.from_message(self, m) + for m in self._filter_items(file.message_type) + ) + # use exclusion list on top-level messages + messages = [m for m in messages if should_include_message(m)] + self.messages += messages + self.enums += self._filter_items(file.enum_type) + + for message in messages: + # recursively search for nested types in newly added messages + self._nested_types_from_message(message.orig) if not self.messages and not self.enums: raise RuntimeError("No messages and no enums found.") - self.message_types = self.find_message_types(message_type) - self.protobuf_import = import_statement_from_path(import_path) + for enum in self.enums: + self.convert_enum_value_names(enum) + + self.depsort_messages() + + def depsort_messages(self): + # sort messages according to dependencies + messages_dict = {m.name: m for m in reversed(self.messages)} + messages_sorted = [] + seen_messages = set() - self.out_dir = None + # pop first message + _, message = messages_dict.popitem() + + while True: + dependent_type = next( + ( + f.type_name + for f in message.fields + if f.is_message and f.type_name not in seen_messages + ), + None, + ) + if dependent_type: + # return current message to unprocessed + messages_dict[message.name] = message + # pop the dependency + message = messages_dict.pop(dependent_type) + else: + # see message + seen_messages.add(message.name) + messages_sorted.append(message) + if not messages_dict: + break + else: + _, message = messages_dict.popitem() + + assert len(messages_sorted) == len(self.messages) + self.messages[:] = messages_sorted + + def _filter_items(self, iter): + return [ + item + for item in iter + # exclude deprecated items unless specified + if (self.include_deprecated or not item.options.deprecated) + ] def _get_extension(self, something, extension_name, default=None): + if something is None: + return default + if extension_name not in self.extensions: + return default + # There doesn't seem to be a sane way to access extensions on a descriptor # via the google.protobuf API. # We do have access to descriptors of the extensions... @@ -230,265 +443,359 @@ class Descriptor: # ...and extract the value corresponding to the extension we care about. return simple_values.get(extension_num, default) + def get_extensions(self, something): + return { + extension: self._get_extension(something, extension) + for extension in self.extensions + if self._get_extension(something, extension) is not None + } + def _nested_types_from_message(self, message): - self.messages += message.nested_type - self.enums += message.enum_type - for nested in message.nested_type: - self._nested_types_from_message(nested) - - def find_message_types(self, message_type): - message_types = {} - try: - message_type_enum = next(e for e in self.enums if e.name == message_type) - for value in message_type_enum.value: - name = strip_leader(value.name, message_type) - message_types[name] = value.number - - except StopIteration: - # No message type found. Oh well. - logging.warning( - "Message IDs not found under '{}'".format(args.message_type) - ) + nested_messages = [ + ProtoMessage.from_message(self, m) + for m in self._filter_items(message.nested_type) + ] + self.messages += nested_messages + self.enums += self._filter_items(message.enum_type) + for nested in nested_messages: + self._nested_types_from_message(nested.orig) + + def convert_enum_value_names(self, enum): + for value in enum.value: + value.name = strip_enum_prefix(enum.name, value.name) - return message_types - def create_message_import(self, name): - return "from .{0} import {0}".format(name) +class PythonRenderer: + def __init__(self, descriptor: Descriptor, out_dir="", python_extension="py"): + self.descriptor = descriptor + self.out_dir = Path(out_dir) + self.python_extension = python_extension - def process_subtype_imports(self, fields): - imports = set( - field.proto_type - for field in fields - if field.orig.type == field.orig.TYPE_MESSAGE + def process_message(self, template, message): + logging.debug(f"Processing message {message.name}") + return template.render(message=message) + + def process_enum(self, template, enum): + logging.debug(f"Processing enum {enum.name}") + + all_values = self.descriptor._filter_items(enum.value) + + has_bitcoin_only_values = self.descriptor._get_extension( + enum, "has_bitcoin_only_values" ) + if has_bitcoin_only_values: + values_always = [ + v + for v in all_values + if self.descriptor._get_extension(v, "bitcoin_only") + ] + values_altcoin = [v for v in all_values if v not in values_always] - if len(imports) > 0: - yield "" # make isort happy - for name in sorted(imports): - yield self.create_message_import(name) - - def create_init_method(self, fields): - required_fields = [f for f in fields if f.required] - repeated_fields = [f for f in fields if f.repeated] - optional_fields = [f for f in fields if f.optional] - # please keep the yields aligned - # fmt: off - yield " def __init__(" - yield " self," - yield " *," - for field in required_fields: - yield f" {field.name}: {field.py_type}," - for field in repeated_fields: - yield f" {field.name}: Optional[List[{field.py_type}]] = None," - for field in optional_fields: - if field.default_value is None: - yield f" {field.name}: Optional[{field.py_type}] = None," - else: - yield f" {field.name}: {field.py_type} = {field.default_value}," - yield " ) -> None:" - - for field in repeated_fields: - yield f" self.{field.name} = {field.name} if {field.name} is not None else []" - for field in required_fields + optional_fields: - yield f" self.{field.name} = {field.name}" - # fmt: on - - def create_fields_method(self, fields): - # fmt: off - yield " @classmethod" - yield " def get_fields(cls) -> Dict:" - yield " return {" - for field in fields: - comments = [] - if field.default_value is not None: - comments.append(f"default={field.orig.default_value}") - - if comments: - comment = " # " + " ".join(comments) - else: - comment = "" - - if field.repeated: - if field.default_value is not None: - raise ValueError("Repeated fields can't have default values.") - if field.experimental: - raise ValueError("Repeated experimental fields are currently not supported.") - flags = "p.FLAG_REPEATED" - elif field.required: - if field.default_value is not None: - raise ValueError("Required fields can't have default values.") - if field.experimental: - raise ValueError("Required fields can't be experimental.") - flags = "p.FLAG_REQUIRED" - elif field.experimental: - if field.default_value is not None: - raise ValueError("Experimental fields can't have default values.") - flags = "p.FLAG_EXPERIMENTAL" - else: - flags = field.default_value - - yield " {num}: ('{name}', {type}, {flags}),{comment}".format( - num=field.number, - name=field.name, - type=field.proto_type, - flags=flags, - comment=comment, - ) + else: + values_always = all_values + values_altcoin = [] - yield " }" - # fmt: on + return template.render( + enum=enum, + values_always=values_always, + values_altcoin=values_altcoin, + ) - def process_message(self, message, include_deprecated=False): - logging.debug("Processing message {}".format(message.name)) + def write_to_file(self, item_name, content): + dest = self.out_dir / (item_name + "." + self.python_extension) + dest.write_text(content) + + def generate_messages(self, template_src): + template = mako.template.Template(filename=str(template_src)) + for message in self.descriptor.messages: + self.write_to_file(message.name, self.process_message(template, message)) + + def generate_enums(self, template_src): + template = mako.template.Template(filename=str(template_src)) + for enum in self.descriptor.enums: + self.write_to_file(enum.name, self.process_enum(template, enum)) + + def render_singlefile(self, template_src): + template = mako.template.Template(filename=str(template_src)) + return template.render( + messages=self.descriptor.messages, + enums=self.descriptor.enums, + ) - msg_id = self._get_extension(message, "wire_type") - if msg_id is None: - msg_id = self.message_types.get(message.name) + def generate_python(self): + enum_template = self.out_dir / "_proto_enum_class.mako" + message_template = self.out_dir / "_proto_message_class.mako" + init_template = self.out_dir / "_proto_init.mako" + if enum_template.exists(): + self.generate_enums(enum_template) + if message_template.exists(): + self.generate_messages(message_template) + if init_template.exists(): + init_py = self.render_singlefile(init_template) + self.write_to_file("__init__", init_py) + + +class RustBlobRenderer: + def __init__(self, descriptor: Descriptor, qstr_defs: str = None): + self.descriptor = descriptor + + self.qstr_map = {} + self.enum_map = {} + self.msg_map = {} + + if qstr_defs: + self.build_qstr_map(qstr_defs) + + def write_qstrs(self, qstr_path): + logging.debug(f"Writing qstrings to {qstr_path}") + message_names = {m.name for m in self.descriptor.messages} + field_names = { + f.name for message in self.descriptor.messages for f in message.fields + } + with open(qstr_path, "w") as f: + for name in sorted(message_names | field_names): + f.write(f"Q({name})\n") + + def write_blobs(self, blob_dir): + logging.debug(f"Writing blobs to {blob_dir}") + blob_dir = Path(blob_dir) + + enum_blob = self.build_enums_with_offsets() + + # build msg entries and fill out map + msg_entries = self.build_message_entries() + # fill message offsets + self.fill_enum_or_msg_offsets(msg_entries) + # encode blob + msg_blob = self.build_message_blob(msg_entries) + + name_blob = self.build_blob_names() + wire_blob = self.build_blob_wire() + + (blob_dir / "proto_enums.data").write_bytes(enum_blob) + (blob_dir / "proto_msgs.data").write_bytes(msg_blob) + (blob_dir / "proto_names.data").write_bytes(name_blob) + (blob_dir / "proto_wire.data").write_bytes(wire_blob) + + def build_qstr_map(self, qstr_defs): + # QSTR defs are rolled out into an enum in py/qstr.h, the numeric + # value is simply an incremented integer. + qstr_counter = 0 + with open(qstr_defs, "r") as f: + for line in f: + match = QDEF_RE.match(line) + if not match: + continue + line = match.group(0) + string = match.group(3) + self.qstr_map[string] = qstr_counter + qstr_counter += 1 + logging.debug(f"Found {qstr_counter} Qstr defs") + + def build_enums_with_offsets(self): + enums = [] + cursor = 0 + for enum in sorted(self.descriptor.enums, key=lambda e: e.name): + self.enum_map[enum.name] = cursor + enum_blob = ENUM_ENTRY.build(sorted(v.number for v in enum.value)) + enums.append(enum_blob) + cursor += len(enum_blob) + + return b"".join(enums) + + def encode_flags_and_wire_type(self, message): + wire_type = message.wire_type + if wire_type is None: + wire_type = 0x7FFF + if wire_type > 0x7FFF: + raise ValueError("Unsupported wire type") + + flags_and_wire_type = wire_type + if message.extensions.get("unstable"): + flags_and_wire_type |= 0x8000 + + return flags_and_wire_type + + def encode_field(self, field): + return dict( + tag=field.number, + flags_and_type=dict( + is_required=field.required, + is_repeated=field.repeated, + is_experimental=field.experimental, + type=FIELD_TYPES_RUST_BLOB[field.type], + ), + enum_or_msg_offset=0, + name=self.qstr_map[field.name], + orig_field=field, + ) - unstable = self._get_extension(message, "unstable") + def fill_enum_or_msg_offsets(self, msg_entries): + for msg_dict in msg_entries: + for field_dict in msg_dict["fields"]: + field = field_dict["orig_field"] + if field.is_enum: + field_dict["enum_or_msg_offset"] = self.enum_map[field.type_name] + elif field.is_message: + field_dict["enum_or_msg_offset"] = self.msg_map[field.type_name] + + def build_message_entries(self): + messages = [] + cursor = 0 + + for message in sorted(self.descriptor.messages, key=lambda m: m.name): + self.msg_map[message.name] = cursor + fields = sorted(message.fields, key=lambda f: f.number) + + defaults = b"".join(self.encode_field_default(f) for f in fields) + flags_and_wire_type = self.encode_flags_and_wire_type(message) + entry = dict( + flags_and_wire_type=flags_and_wire_type, + fields=[self.encode_field(f) for f in fields], + defaults=defaults, + ) - # "from .. import protobuf as p" - yield self.protobuf_import + " as p" + messages.append(entry) + cursor += len(MSG_ENTRY.build(entry)) - fields = [ProtoField.from_field(self, field) for field in message.field] - if not include_deprecated: - fields = [field for field in fields if not field.orig.options.deprecated] + return messages - yield from self.process_subtype_imports(fields) + def build_message_blob(self, msg_entries): + return b"".join(MSG_ENTRY.build(entry) for entry in msg_entries) - yield "" - yield "if __debug__:" - yield " try:" - yield " from typing import Dict, List, Optional # noqa: F401" - yield " from typing_extensions import Literal # noqa: F401" + def encode_field_default(self, field): + if field.number > 0xFF: + raise ValueError("Invalid field number") - all_enums = [field for field in fields if field.type_name in self.enum_types] - for field in all_enums: - allowed_values = self.enum_types[field.type_name].values() - valuestr = ", ".join(str(v) for v in sorted(allowed_values)) - yield " {} = Literal[{}]".format(field.py_type, valuestr) + default = field.default_value - yield " except ImportError:" - yield " pass" + if default is None: + return b"" - yield "" - yield "" - yield "class {}(p.MessageType):".format(message.name) + elif field.type in INT_TYPES: + return DEFAULT_VARINT_ENTRY.build((field.number, int(default))) - if msg_id is not None: - yield " MESSAGE_WIRE_TYPE = {}".format(msg_id) + elif field.type == FieldDescriptor.TYPE_BOOL: + return DEFAULT_VARINT_ENTRY.build((field.number, int(default == "True"))) - if unstable is not None: - yield " UNSTABLE = True" + elif field.type == FieldDescriptor.TYPE_BYTES: + if default != "": + raise ValueError( + "Bytes fields can only have empty bytes for default value" + ) + return DEFAULT_LENGTH_ENTRY.build((field.number, b"")) - if fields: - yield "" - yield from self.create_init_method(fields) - yield "" - yield from self.create_fields_method(fields) + elif field.type == FieldDescriptor.TYPE_STRING: + return DEFAULT_LENGTH_ENTRY.build((field.number, default.encode())) - if not fields and not msg_id: - yield " pass" + elif field.is_enum: + # find the right value + value = find_by_name(field.type_object.value, default) + if value is None: + raise ValueError(f"Default not found for field {field.name}") + return DEFAULT_VARINT_ENTRY.build((field.number, value.number)) - def process_enum(self, enum): - logging.debug("Processing enum {}".format(enum.name)) + else: + raise ValueError(f"Cannot encode default value for field {field.name}") + + def build_blob_names(self): + # sorting by Qstr value of the message name + messages = sorted(self.descriptor.messages, key=lambda m: self.qstr_map[m.name]) + return b"".join( + NAME_ENTRY.build((self.qstr_map[m.name], self.msg_map[m.name])) + for m in messages + ) - # file header - yield "if __debug__:" - yield " try:" - yield " from typing_extensions import Literal # noqa: F401" - yield " except ImportError:" - yield " pass" - yield "" + def build_blob_wire(self): + # create wire-type -> message mapping + wire_messages = [m for m in self.descriptor.messages if m.wire_type is not None] + # sorting by wire-type + wire_messages.sort(key=lambda m: m.wire_type) + return b"".join( + WIRETYPE_ENTRY.build((m.wire_type, self.msg_map[m.name])) + for m in wire_messages + ) - for value in enum.value: - # Remove type name from the beginning of the constant - # For example "PinMatrixRequestType_Current" -> "Current" - enum_prefix = enum.name - name = value.name - name = strip_leader(name, enum_prefix) - - # If type ends with *Type, but constant use type name without *Type, remove it too :) - # For example "ButtonRequestType & ButtonRequest_Other" => "Other" - if enum_prefix.endswith("Type"): - enum_prefix, _ = enum_prefix.rsplit("Type", 1) - name = strip_leader(name, enum_prefix) - - self.enum_types[enum.name][value.name] = value.number - yield f"{name}: Literal[{value.number}] = {value.number}" - - def process_messages(self, messages, include_deprecated=False): - for message in sorted(messages, key=lambda m: m.name): - self.write_to_file( - message.name, self.process_message(message, include_deprecated) - ) - def process_enums(self, enums): - for enum in sorted(enums, key=lambda e: e.name): - self.write_to_file(enum.name, self.process_enum(enum)) - - def write_to_file(self, name, out): - # Write generated sourcecode to given file - logging.debug("Writing file {}.py".format(name)) - with open(os.path.join(self.out_dir, name + ".py"), "w") as f: - f.write(AUTO_HEADER) - f.write("# fmt: off\n") - f.write("# isort:skip_file\n") - for line in out: - f.write(line + "\n") - - def write_init_py(self): - filename = os.path.join(self.out_dir, "__init__.py") - with open(filename, "w") as init_py: - init_py.write(AUTO_HEADER) - init_py.write("# fmt: off\n") - init_py.write("# isort:skip_file\n\n") - for message in sorted(self.messages, key=lambda m: m.name): - init_py.write(self.create_message_import(message.name) + "\n") - for enum in sorted(self.enums, key=lambda m: m.name): - init_py.write("from . import {}\n".format(enum.name)) - - def write_classes(self, out_dir, init_py=True, include_deprecated=False): - self.out_dir = out_dir - self.process_enums(self.enums) - self.process_messages(self.messages, include_deprecated) - if init_py: - self.write_init_py() +ReadableFile = click.Path(exists=True, dir_okay=False, readable=True) +WritableFile = click.Path(dir_okay=False, writable=True) +WritableDirectory = click.Path(exists=True, file_okay=False, writable=True) -if __name__ == "__main__": - parser = argparse.ArgumentParser() - # fmt: off - parser.add_argument("proto", nargs="+", help="Protobuf definition files") - parser.add_argument("-o", "--out-dir", help="Directory for generated source code") - parser.add_argument("-P", "--protobuf-module", default="protobuf", help="Name of protobuf module") - parser.add_argument("-l", "--no-init-py", action="store_true", help="Do not generate __init__.py with list of modules") - parser.add_argument("--message-type", default="MessageType", help="Name of enum with message IDs") - parser.add_argument("-I", "--protoc-include", action="append", help="protoc include path") - parser.add_argument("-v", "--verbose", action="store_true", help="Print debug messages") - parser.add_argument("-d", "--include-deprecated", action="store_true", help="Include deprecated fields") - # fmt: on - args = parser.parse_args() - - if args.verbose: +@click.command() +# fmt: off +@click.argument("proto", nargs=-1, type=ReadableFile, required=True) +@click.option("--python-outdir", type=WritableDirectory, help="Output directory for Python classes (contents will be deleted)") +@click.option("--python-extension", default="py", help="Use .pyi to generate type stubs") +@click.option("--outfile", type=WritableFile, help="Output file for single-file generated definitions") +@click.option("--template", type=ReadableFile, help="Template for single-file entry") +@click.option("--blob-outdir", type=WritableDirectory, help="Output directory for protobuf blob files") +@click.option("--qstr-defs", type=ReadableFile, help="Collected Qstr definitions") +@click.option("--qstr-out", type=WritableFile, help="Output Qstr header") +@click.option("-v", "--verbose", is_flag=True) +@click.option("-d", "--include-deprecated", is_flag=True, help="Include deprecated fields, messages and enums") +@click.option("-b", "--bitcoin-only", type=int, default=0, help="Exclude fields, messages and enums that do not belong to bitcoin_only builds") +# fmt: on +def main( + proto, + python_outdir, + python_extension, + outfile, + template, + blob_outdir, + qstr_defs, + qstr_out, + verbose, + include_deprecated, + bitcoin_only, +): + if verbose: logging.basicConfig(level=logging.DEBUG) - protoc_includes = args.protoc_include or (os.environ.get("PROTOC_INCLUDE"),) - descriptor_proto = protoc(args.proto, protoc_includes) - descriptor = Descriptor(descriptor_proto, args.message_type, args.protobuf_module) - - with tempfile.TemporaryDirectory() as tmpdir: - descriptor.write_classes(tmpdir, not args.no_init_py, args.include_deprecated) - - for filename in os.listdir(args.out_dir): - pathname = os.path.join(args.out_dir, filename) - try: - with open(pathname, "r") as f: - if next(f, None) == AUTO_HEADER: - os.unlink(pathname) - except Exception: - pass - - for filename in os.listdir(tmpdir): - src = os.path.join(tmpdir, filename) - shutil.copy(src, args.out_dir) + descriptor_proto = protoc(proto) + descriptor = Descriptor( + descriptor_proto, + include_deprecated=include_deprecated, + bitcoin_only=bitcoin_only, + ) + + if python_outdir: + outdir = Path(python_outdir) + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir_path = Path(tmpdir) + for file in outdir.glob("_proto*.mako"): + shutil.copy(file, tmpdir) + + renderer = PythonRenderer(descriptor, tmpdir_path, python_extension) + renderer.generate_python() + + for file in outdir.glob("*." + python_extension): + if file.name == "__init__." + python_extension: + continue + file.unlink() + + for file in tmpdir_path.iterdir(): + shutil.copy(file, outdir) + + if outfile: + if not template: + raise click.ClickException("Please specify --template") + + renderer = PythonRenderer(descriptor) + with open(outfile, "w") as f: + f.write(renderer.render_singlefile(template)) + + if qstr_out: + renderer = RustBlobRenderer(descriptor) + renderer.write_qstrs(qstr_out) + + if blob_outdir: + if not qstr_defs: + raise click.ClickException("Qstr defs not provided") + + renderer = RustBlobRenderer(descriptor, qstr_defs) + renderer.write_blobs(blob_outdir) + + +if __name__ == "__main__": + main() diff --git a/core/site_scons/site_tools/micropython/__init__.py b/core/site_scons/site_tools/micropython/__init__.py index ea3714645..057a896d6 100644 --- a/core/site_scons/site_tools/micropython/__init__.py +++ b/core/site_scons/site_tools/micropython/__init__.py @@ -25,9 +25,13 @@ def generate(env): source_name = source.replace(env['source_dir'], '') # replace "utils.BITCOIN_ONLY" with literal constant (True/False) # so the compiler can optimize out the things we don't want - btc_only = 'True' if env['bitcoin_only'] == '1' else 'False' + btc_only = env['bitcoin_only'] == '1' interim = "%s.i" % target[:-4] # replace .mpy with .i - return '$SED "s:utils\\.BITCOIN_ONLY:%s:g" %s > %s && $MPY_CROSS -o %s -s %s %s' % (btc_only, source, interim, target, source_name, interim) + sed_scripts = " ".join([ + f"-e 's/utils\.BITCOIN_ONLY/{btc_only}/g'", + "-e 's/if TYPE_CHECKING/if False/'", + ]) + return f'$SED {sed_scripts} {source} > {interim} && $MPY_CROSS -o {target} -s {source_name} {interim}' env['BUILDERS']['FrozenModule'] = SCons.Builder.Builder( generator=generate_frozen_module, diff --git a/core/tools/build_templates b/core/tools/build_templates index 714b17baa..41a54210c 100755 --- a/core/tools/build_templates +++ b/core/tools/build_templates @@ -4,13 +4,11 @@ set -e CWD=`dirname "$0"` RENDER="$CWD/../vendor/trezor-common/tools/cointool.py render" -FIND_TEMPLATES="find $CWD/../src -name *.mako" +FIND_TEMPLATES="find $CWD/../src -name *.mako -not -name _proto*" check_results() { CHECK_FAIL=0 for filename in $($FIND_TEMPLATES); do - # ignore resources.py - if echo $filename | grep -q "resources.py.mako$"; then continue; fi TMP=`mktemp` TARGET="${filename%%.mako}" $RENDER "$filename" -o $TMP diff --git a/tools/build_protobuf b/tools/build_protobuf index 293c49f25..2c6a270f8 100755 --- a/tools/build_protobuf +++ b/tools/build_protobuf @@ -51,55 +51,41 @@ PYTHON_MESSAGES_IGNORE="" RETURN=0 do_rebuild() { - # rebuild protobuf in specified directory - local DESTDIR="$1" + local FILE_OR_DIR="$1" shift - local SOURCES="$1" - shift - local IGNORE="$1" + local OUTPUT="$1" shift - local APPLY_BITCOIN_ONLY="$1" + local SOURCES="$1" shift - mkdir -p "$DESTDIR" - rm -f "$DESTDIR"/[A-Z]*.py - - # note $SOURCES is unquoted - we want wildcard expansion and multiple args - $PROTOB/pb2py "$@" -o "$DESTDIR" $SOURCES - - # TODO: make this less hackish - # maybe introduce attribute "altcoin" in protobuf? - if [ "$APPLY_BITCOIN_ONLY" == "TRUE" ]; then - sed -i "3ifrom trezor import utils\n" "$DESTDIR"/Capability.py - sed -i "3ifrom trezor import utils\n" "$DESTDIR"/MessageType.py - sed -i "/^EthereumGetPublicKey/iif not utils.BITCOIN_ONLY:" "$DESTDIR"/MessageType.py - for altcoin in Ethereum NEM Lisk Tezos Stellar Cardano Ripple Monero DebugMonero Eos Binance WebAuthn; do - sed -i "s:^$altcoin: $altcoin:" "$DESTDIR"/Capability.py - sed -i "s:^$altcoin: $altcoin:" "$DESTDIR"/MessageType.py - done - sed -i "/^Bitcoin_like/iif not utils.BITCOIN_ONLY:" "$DESTDIR"/Capability.py - sed -i "/^EOS/iif not utils.BITCOIN_ONLY:" "$DESTDIR"/Capability.py - for feature in Bitcoin_like EOS U2F; do - sed -i "s:^$feature: $feature:" "$DESTDIR"/Capability.py - done + if [ "$FILE_OR_DIR" == file ]; then + local param="--outfile" + else + local param="--python-outdir" fi - # ENDTODO - # delete unused messages - for F in $IGNORE; do - rm -f "$DESTDIR"/"$F".py - done + # note $SOURCES is unquoted - we want wildcard expansion and multiple args + $PROTOB/pb2py "$@" $param="$OUTPUT" $SOURCES } do_check() { # rebuild protobuf in tmpdir and check result against specified directory local TMPDIR=$(mktemp -d proto-check.XXXXXX) - local DESTDIR="$1" + + local FILE_OR_DIR="$1" + shift + local OUTPUT="$1" shift - cp -rT "$DESTDIR" "$TMPDIR" - do_rebuild "$TMPDIR" "$@" - DIFF=$(diff -ur --exclude __pycache__ "$DESTDIR" "$TMPDIR") + if [ "$FILE_OR_DIR" == file ]; then + local TMPDEST="$TMPDIR/testfile" + else + cp -rT "$OUTPUT" "$TMPDIR" + local TMPDEST="$TMPDIR" + fi + + do_rebuild "$FILE_OR_DIR" "$TMPDEST" "$@" + DIFF=$(diff -ur --exclude __pycache__ "$OUTPUT" "$TMPDEST") rm -r "$TMPDIR" if [ -n "$DIFF" ]; then echo "$DIFF" @@ -113,7 +99,12 @@ else func=do_rebuild fi -$func core/src/trezor/messages "$CORE_PROTOBUF_SOURCES" "$CORE_MESSAGES_IGNORE" TRUE --no-init-py -$func python/src/trezorlib/messages "$PYTHON_PROTOBUF_SOURCES" "$PYTHON_MESSAGES_IGNORE" FALSE --include-deprecated -P ..protobuf +$func dir core/src/trezor/enums "$CORE_PROTOBUF_SOURCES" +$func file core/src/trezor/enums/__init__.py "$CORE_PROTOBUF_SOURCES" --template=core/src/trezor/enums/_proto_init.mako +$func file core/src/trezor/messages.py "$CORE_PROTOBUF_SOURCES" --template=core/src/trezor/_proto_messages.mako + +$func file python/src/trezorlib/messages.py "$PYTHON_PROTOBUF_SOURCES" \ + --template=python/src/trezorlib/_proto_messages.mako \ + --include-deprecated exit $RETURN