diff --git a/src/lib/__init__.py b/src/lib/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/lib/protobuf/protobuf.py b/src/lib/protobuf/protobuf.py index 88fa614469..d027a0c34f 100644 --- a/src/lib/protobuf/protobuf.py +++ b/src/lib/protobuf/protobuf.py @@ -76,7 +76,6 @@ FLAG_REQUIRED = 1 FLAG_REQUIRED_MASK = 1 FLAG_SINGLE = 0 FLAG_REPEATED = 2 -FLAG_PACKED_REPEATED = 6 FLAG_REPEATED_MASK = 6 FLAG_PRIMITIVE = 0 FLAG_EMBEDDED = 8 @@ -110,12 +109,15 @@ class MessageType: # Creates a new message type. self.__tags_to_types = dict() # Maps a tag to a type instance. self.__tags_to_names = dict() # Maps a tag to a given field name. + self.__defaults = dict() # Maps a tag to its default value. self.__flags = dict() # Maps a tag to FLAG_ - def add_field(self, tag, name, field_type, flags=FLAG_SIMPLE): + def add_field(self, tag, name, field_type, flags=FLAG_SIMPLE, default=None): # Adds a field to the message type. if tag in self.__tags_to_names or tag in self.__tags_to_types: raise ValueError('The tag %s is already used.' % tag) + if default != None: + self.__defaults[tag] = default self.__tags_to_names[tag] = name self.__tags_to_types[tag] = field_type self.__flags[tag] = flags @@ -137,19 +139,12 @@ class MessageType: if self.__has_flag(tag, FLAG_SINGLE, FLAG_REPEATED_MASK): # Single value. UVarintType.dump(fp, _pack_key(tag, field_type.WIRE_TYPE)) - field_type.dump(fp, value.__dict__[self.__tags_to_names[tag]]) - elif self.__has_flag(tag, FLAG_PACKED_REPEATED, FLAG_REPEATED_MASK): - # Repeated packed value. - UVarintType.dump(fp, _pack_key(tag, BytesType.WIRE_TYPE)) - internal_fp = BytesIO() - for single_value in value[self.__tags_to_names[tag]]: - field_type.dump(internal_fp, single_value) - BytesType.dump(fp, internal_fp.getvalue()) + field_type.dump(fp, getattr(value, self.__tags_to_names[tag])) elif self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK): # Repeated value. key = _pack_key(tag, field_type.WIRE_TYPE) # Put it together sequently. - for single_value in value[self.__tags_to_names[tag]]: + for single_value in getattr(value, self.__tags_to_names[tag]): UVarintType.dump(fp, key) field_type.dump(fp, single_value) elif self.__has_flag(tag, FLAG_REQUIRED, FLAG_REQUIRED_MASK): @@ -163,30 +158,18 @@ class MessageType: if tag in self.__tags_to_types: field_type = self.__tags_to_types[tag] - if not self.__has_flag(tag, FLAG_PACKED_REPEATED, FLAG_REPEATED_MASK): - if wire_type != field_type.WIRE_TYPE: - raise TypeError( - 'Value of tag %s has incorrect wiretype %s, %s expected.' % \ - (tag, wire_type, field_type.WIRE_TYPE)) - elif wire_type != BytesType.WIRE_TYPE: - raise TypeError('Tag %s has wiretype %s while the field is packed repeated.' % (tag, wire_type)) + if wire_type != field_type.WIRE_TYPE: + raise TypeError( + 'Value of tag %s has incorrect wiretype %s, %s expected.' % \ + (tag, wire_type, field_type.WIRE_TYPE)) if self.__has_flag(tag, FLAG_SINGLE, FLAG_REPEATED_MASK): # Single value. setattr(message, self.__tags_to_names[tag], field_type.load(fp)) - elif self.__has_flag(tag, FLAG_PACKED_REPEATED, FLAG_REPEATED_MASK): - # Repeated packed value. - repeated_value = message[self.__tags_to_names[tag]] = list() - internal_fp = EofWrapper(fp, UVarintType.load(fp)) # Limit with value length. - while True: - try: - repeated_value.append(field_type.load(internal_fp)) - except EOFError: - break elif self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK): # Repeated value. - if not self.__tags_to_names[tag] in message: - repeated_value = message[self.__tags_to_names[tag]] = list() - repeated_value.append(field_type.load(fp)) + if not self.__tags_to_names[tag] in message.__dict__: + setattr(message, self.__tags_to_names[tag], list()) + getattr(message, self.__tags_to_names[tag]).append(field_type.load(fp)) else: # Skip this field. @@ -194,11 +177,15 @@ class MessageType: {0: UVarintType, 2: BytesType}[wire_type].load(fp) except EOFError: - # Check if all required fields are present. for tag, name in iter(self.__tags_to_names.items()): - if self.__has_flag(tag, FLAG_REQUIRED, FLAG_REQUIRED_MASK) and not name in message: + # Fill in default value if value not set + if name not in message.__dict__ and tag in self.__defaults: + setattr(message, name, self.__defaults[tag]) + + # Check if all required fields are present. + if self.__has_flag(tag, FLAG_REQUIRED, FLAG_REQUIRED_MASK) and not name in message.__dict__: if self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK): - message[name] = list() # Empty list (no values was in input stream). But required field. + setattr(message, name, list()) # Empty list (no values was in input stream). But required field. else: raise ValueError('The field %s (\'%s\') is required but missing.' % (tag, name)) return message diff --git a/src/playground/__init__.py b/src/playground/__init__.py index 41a1d98a06..8da9b82f30 100644 --- a/src/playground/__init__.py +++ b/src/playground/__init__.py @@ -15,15 +15,16 @@ from uasyncio import core from trezor import ui from trezor import msg +from trezor.utils import unimport logging.basicConfig(level=logging.INFO) loop = core.get_event_loop() def perf_info(): - mem_free = gc.mem_free() + mem_alloc = gc.mem_alloc() gc.collect() - print("free_mem: %s/%s, last_sleep: %.06f" % \ - (mem_free, gc.mem_free(), loop.last_sleep)) + print("mem_alloc: %s/%s, last_sleep: %.06f" % \ + (mem_alloc, gc.mem_alloc(), loop.last_sleep)) loop.call_later(1, perf_info) def animate(): @@ -100,10 +101,31 @@ def on_read(): print("READY TO READ") print(msg.read()) +@unimport +def zprava(): + from _io import BytesIO + + from trezor.messages.GetAddress import GetAddress + + m = GetAddress() + m.address_n = [1, 2, 3] + m.show_display = True + + print(m.__dict__) + f = BytesIO() + m.dump(f) + data = f.getvalue() + f.close() + print(data) + # m2 = GetAddress.load(BytesIO(data)) + # print(m2.__dict__) + def run(): # pipe.init('../pipe', on_read) # msg.set_notify(on_read) + zprava() + loop.call_soon(perf_info) loop.call_soon(tap_to_confirm()) # loop.call_soon(animate()) diff --git a/src/trezor/messages/__init__.py b/src/trezor/messages/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/trezor/utils.py b/src/trezor/utils.py index 99b828962e..115087a6cc 100644 --- a/src/trezor/utils.py +++ b/src/trezor/utils.py @@ -1,5 +1,12 @@ -def hexlify(data: bytes) -> str: - return ''.join(['%02x' % b for b in data]) +import sys +import gc -def unhexlify(data: str) -> bytes: - return bytes([int(data[i:i+2], 16) for i in range(0, len(data), 2)]) +def unimport(func): + def inner(*args, **kwargs): + mods = set(sys.modules) + ret = func(*args, **kwargs) + for to_remove in set(sys.modules) - mods: + print(to_remove) + del sys.modules[to_remove] + return ret + return inner diff --git a/tools/build_pb.sh b/tools/build_pb.sh new file mode 100755 index 0000000000..73962a1d67 --- /dev/null +++ b/tools/build_pb.sh @@ -0,0 +1,15 @@ +#!/bin/bash +CURDIR=$(pwd) + + +for i in messages types storage ; do + + # Compile .proto files to python2 modules using google protobuf library + cd $CURDIR/../../trezor-common/protob + protoc --python_out=$CURDIR/pb2/ -I/usr/include -I. $i.proto + + # Convert google protobuf library to trezor's internal format + cd $CURDIR + ./pb2py $i ../src/trezor/messages/ +done + diff --git a/tools/pb2/__init__.py b/tools/pb2/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tools/pb2py b/tools/pb2py new file mode 100755 index 0000000000..fc9573b268 --- /dev/null +++ b/tools/pb2py @@ -0,0 +1,140 @@ +#!/usr/bin/env python2 + +import sys +import os + +from google.protobuf.internal.enum_type_wrapper import EnumTypeWrapper + +def process_type(t, cls): + imports = ["from protobuf import protobuf as p",] + + out = ["t = p.MessageType()", ] + + print("Processing type %s" % t) + + TYPE_STRING = 9 + TYPE_BYTES = 12 + + TYPE_MESSAGE = 11 + + for k, v in cls.DESCRIPTOR.fields_by_name.items(): + + #print k + + number = v.number + fieldname = k + type = None + repeated = v.label == 3 + required = v.label == 2 + + #print v.has_default_value, v.default_value + + if v.type in (4, 13, 14): + # TYPE_UINT64 = 4 + # TYPE_UINT32 = 13 + # TYPE_ENUM = 14 + type = 'p.UVarintType' + + elif v.type == 9: + # TYPE_STRING = 9 + type = 'p.UnicodeType' + + elif v.type == 8: + # TYPE_BOOL = 8 + type = 'p.BoolType' + + elif v.type == 12: + # TYPE_BYTES = 12 + type = 'p.BytesType' + + elif v.type == 11: + # TYPE_MESSAGE = 1 + type = "p.EmbeddedMessage(%s)" % v.message_type.name + imports.append("from .%s import %s" % (v.message_type.name, v.message_type.name)) + + else: + raise Exception("Unknown field type %s for field %s" % (v.type, k)) + + if repeated: + flags = ', flags=p.FLAG_REPEATED' + elif required: + flags = ', flags=p.FLAG_REQUIRED' + else: + flags = '' + + if v.has_default_value: + default = ', default=%s' % repr(v.default_value) + else: + default = '' + + out.append("t.add_field(%d, '%s', %s%s%s)" % \ + (number, fieldname, type, flags, default)) + + #print fieldname, number, type, repeated, default + #print v.__dict__ + #print v.CPPTYPE_STRING + #print v.LABEL_REPEATED + #print v.enum_type + # v.has_default_value, v.default_value + # v.label == 3 # repeated + #print v.number + + out.append("%s = t" % t) + return imports + out + +def process_enum(t, cls): + out = [] + + print("Processing enum %s" % t) + + for k, v in cls.items(): + # Remove type name from the beginning of the constant + # For example "PinMatrixRequestType_Current" -> "Current" + if k.startswith("%s_" % t): + k = k.replace("%s_" % t, '') + + # If type ends with *Type, but constant use type name without *Type, remove it too :) + # For example "ButtonRequestType & ButtonRequest_Other" => "Other" + if t.endswith("Type") and k.startswith("%s_" % t.replace("Type", '')): + k = k.replace("%s_" % t.replace("Type", ''), '') + + out.append("%s = %s" % (k, v)) + + return out + +def process_module(mod, genpath): + types = dict([(name, cls) for name, cls in mod.__dict__.items() if isinstance(cls, type)]) + + for t, cls in types.iteritems(): + out = process_type(t, cls) + write_to_file(genpath, t, out) + + enums = dict([(name, cls) for name, cls in mod.__dict__.items() if isinstance(cls, EnumTypeWrapper)]) + + for t, cls in enums.iteritems(): + out = process_enum(t, cls) + write_to_file(genpath, t, out) + +def write_to_file(genpath, t, out): + # Write generated sourcecode to given file + f = open(os.path.join(genpath, "%s.py" % t), 'w') + out = ["# Automatically generated by ./pb2py"] + out + + data = "\n".join(out) + + f.write(data) + f.close() + +if __name__ == '__main__': + if len(sys.argv) < 2: + print("Usage: ./pb2py modulename genpath") + sys.exit() + + modulename = sys.argv[1] + genpath = sys.argv[2] + + # Dynamically load module from argv[1] + tmp = __import__('pb2', globals(), locals(), ['%s_pb2' % modulename]) + mod = getattr(tmp, "%s_pb2" % modulename) + + process_module(mod, genpath)