1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-10 15:30:55 +00:00
This commit is contained in:
Jan Pochyla 2016-08-05 12:35:45 +02:00 committed by Pavol Rusnak
parent a4d1b27541
commit 455a436123
No known key found for this signature in database
GPG Key ID: 91F3B339B9A02A3D
4 changed files with 65 additions and 40 deletions

View File

@ -1,13 +1,19 @@
from trezor import ui, loop, wire from trezor import ui, wire
from trezor.utils import unimport_gen from trezor.utils import unimport_gen
def ord(n):
return str(n)+("th" if 4<=n%100<=20 else {1:"st",2:"nd",3:"rd"}.get(n%10, "th")) def nth(n):
if 4 <= n % 100 <= 20:
sfx = 'th'
else:
sfx = {1: 'st', 2: 'nd', 3: 'rd'}.get(n % 10, 'th')
return str(n) + sfx
@unimport_gen @unimport_gen
def layout_recovery_device(message): def layout_recovery_device(message):
msg = 'Please enter ' + ord(message.word_count) + ' word' msg = 'Please enter ' + nth(message.word_count) + ' word'
ui.clear() ui.clear()
ui.display.text(10, 30, 'Recovering device', ui.BOLD, ui.LIGHT_GREEN, ui.BLACK) ui.display.text(10, 30, 'Recovering device', ui.BOLD, ui.LIGHT_GREEN, ui.BLACK)

View File

@ -3,9 +3,9 @@
# http://eigenein.me/protobuf/ # http://eigenein.me/protobuf/
from uio import BytesIO from uio import BytesIO
import ustruct
# Types. ----------------------------------------------------------------------- # Types. -----------------------------------------------------------------
class UVarintType: class UVarintType:
# Represents an unsigned Varint type. # Represents an unsigned Varint type.
@ -27,10 +27,10 @@ class UVarintType:
value, shift = value + ((quantum & 0x7F) << shift), shift + 7 value, shift = value + ((quantum & 0x7F) << shift), shift + 7
return value return value
# class UInt32Type(UVarintType): pass
class BoolType: class BoolType:
# Represents a boolean type. Encodes True as UVarint 1, and False as UVarint 0. # Represents a boolean type.
# Encodes True as UVarint 1, and False as UVarint 0.
WIRE_TYPE = 0 WIRE_TYPE = 0
@staticmethod @staticmethod
@ -41,9 +41,9 @@ class BoolType:
def load(fp): def load(fp):
return UVarintType.load(fp) != 0 return UVarintType.load(fp) != 0
class BytesType: class BytesType:
# Represents a raw bytes type. # Represents a raw bytes type.
WIRE_TYPE = 2 WIRE_TYPE = 2
@staticmethod @staticmethod
@ -55,7 +55,9 @@ class BytesType:
def load(fp): def load(fp):
return fp.read(UVarintType.load(fp)) return fp.read(UVarintType.load(fp))
class UnicodeType: class UnicodeType:
# Represents an unicode string type.
WIRE_TYPE = 2 WIRE_TYPE = 2
@staticmethod @staticmethod
@ -66,17 +68,20 @@ class UnicodeType:
def load(fp): def load(fp):
return BytesType.load(fp).decode('utf-8', 'strict') return BytesType.load(fp).decode('utf-8', 'strict')
# Messages. --------------------------------------------------------------------
FLAG_SIMPLE = 0 # Messages. --------------------------------------------------------------
FLAG_REQUIRED = 1
FLAG_REQUIRED_MASK = 1 FLAG_SIMPLE = const(0)
FLAG_SINGLE = 0 FLAG_REQUIRED = const(1)
FLAG_REPEATED = 2 FLAG_REQUIRED_MASK = const(1)
FLAG_REPEATED_MASK = 6 FLAG_SINGLE = const(0)
FLAG_REPEATED = const(2)
FLAG_REPEATED_MASK = const(6)
class EofWrapper: class EofWrapper:
# Wraps a stream to raise EOFError instead of just returning of ''. # Wraps a stream to raise EOFError instead of just returning of ''.
def __init__(self, fp, limit=None): def __init__(self, fp, limit=None):
self.__fp = fp self.__fp = fp
self.__limit = limit self.__limit = limit
@ -91,20 +96,22 @@ class EofWrapper:
raise EOFError() raise EOFError()
return s return s
# Packs a tag and a wire_type into single int according to the protobuf spec. # Packs a tag and a wire_type into single int according to the protobuf spec.
_pack_key = lambda tag, wire_type: (tag << 3) | wire_type _pack_key = lambda tag, wire_type: (tag << 3) | wire_type
# Unpacks a key into a tag and a wire_type according to the protobuf spec. # Unpacks a key into a tag and a wire_type according to the protobuf spec.
_unpack_key = lambda key: (key >> 3, key & 7) _unpack_key = lambda key: (key >> 3, key & 7)
class MessageType: class MessageType:
# Represents a message type. # Represents a message type.
def __init__(self, name=None): def __init__(self, name=None):
# Creates a new message type. # Creates a new message type.
self.__tags_to_types = dict() # Maps a tag to a type instance. self.__tags_to_types = {} # Maps a tag to a type instance.
self.__tags_to_names = dict() # Maps a tag to a given field name. self.__tags_to_names = {} # Maps a tag to a given field name.
self.__defaults = dict() # Maps a tag to its default value. self.__defaults = {} # Maps a tag to its default value.
self.__flags = dict() # Maps a tag to FLAG_ self.__flags = {} # Maps a tag to FLAG_
self.__name = name self.__name = name
def add_field(self, tag, name, field_type, flags=FLAG_SIMPLE, default=None): def add_field(self, tag, name, field_type, flags=FLAG_SIMPLE, default=None):
@ -123,7 +130,8 @@ class MessageType:
return Message(self, **fields) return Message(self, **fields)
def __has_flag(self, tag, flag, mask): def __has_flag(self, tag, flag, mask):
# Checks whether the field with the specified tag has the specified flag. # Checks whether the field with the specified tag has the specified
# flag.
return (self.__flags[tag] & mask) == flag return (self.__flags[tag] & mask) == flag
def dump(self, fp, value): def dump(self, fp, value):
@ -134,7 +142,8 @@ class MessageType:
if self.__has_flag(tag, FLAG_SINGLE, FLAG_REPEATED_MASK): if self.__has_flag(tag, FLAG_SINGLE, FLAG_REPEATED_MASK):
# Single value. # Single value.
UVarintType.dump(fp, _pack_key(tag, field_type.WIRE_TYPE)) UVarintType.dump(fp, _pack_key(tag, field_type.WIRE_TYPE))
field_type.dump(fp, getattr(value, self.__tags_to_names[tag])) field_type.dump(fp, getattr(
value, self.__tags_to_names[tag]))
elif self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK): elif self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK):
# Repeated value. # Repeated value.
key = _pack_key(tag, field_type.WIRE_TYPE) key = _pack_key(tag, field_type.WIRE_TYPE)
@ -143,7 +152,8 @@ class MessageType:
UVarintType.dump(fp, key) UVarintType.dump(fp, key)
field_type.dump(fp, single_value) field_type.dump(fp, single_value)
elif self.__has_flag(tag, FLAG_REQUIRED, FLAG_REQUIRED_MASK): elif self.__has_flag(tag, FLAG_REQUIRED, FLAG_REQUIRED_MASK):
raise ValueError('The field with the tag %s is required but a value is missing.' % tag) raise ValueError(
'The field with the tag %s is required but a value is missing.' % tag)
def load(self, fp): def load(self, fp):
fp = EofWrapper(fp) fp = EofWrapper(fp)
@ -154,22 +164,25 @@ class MessageType:
if tag in self.__tags_to_types: if tag in self.__tags_to_types:
field_type = self.__tags_to_types[tag] field_type = self.__tags_to_types[tag]
field_name = self.__tags_to_names[tag]
if wire_type != field_type.WIRE_TYPE: if wire_type != field_type.WIRE_TYPE:
raise TypeError( raise TypeError(
'Value of tag %s has incorrect wiretype %s, %s expected.' % \ 'Value of tag %s has incorrect wiretype %s, %s expected.' %
(tag, wire_type, field_type.WIRE_TYPE)) (tag, wire_type, field_type.WIRE_TYPE))
if self.__has_flag(tag, FLAG_SINGLE, FLAG_REPEATED_MASK): if self.__has_flag(tag, FLAG_SINGLE, FLAG_REPEATED_MASK):
# Single value. # Single value.
setattr(message, self.__tags_to_names[tag], field_type.load(fp)) setattr(message, field_name, field_type.load(fp))
elif self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK): elif self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK):
# Repeated value. # Repeated value.
if not self.__tags_to_names[tag] in message.__dict__: if not field_name in message.__dict__:
setattr(message, self.__tags_to_names[tag], list()) setattr(message, field_name, [])
getattr(message, self.__tags_to_names[tag]).append(field_type.load(fp)) getattr(message, field_name).append(
field_type.load(fp))
else: else:
# Skip this field. # Skip this field.
# This used to correctly determine the length of unknown tags when loading a message. # This used to correctly determine the length of unknown
# tags when loading a message.
{0: UVarintType, 2: BytesType}[wire_type].load(fp) {0: UVarintType, 2: BytesType}[wire_type].load(fp)
except EOFError: except EOFError:
@ -181,9 +194,12 @@ class MessageType:
# Check if all required fields are present. # Check if all required fields are present.
if self.__has_flag(tag, FLAG_REQUIRED, FLAG_REQUIRED_MASK) and not name in message.__dict__: if self.__has_flag(tag, FLAG_REQUIRED, FLAG_REQUIRED_MASK) and not name in message.__dict__:
if self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK): if self.__has_flag(tag, FLAG_REPEATED, FLAG_REPEATED_MASK):
setattr(message, name, list()) # Empty list (no values was in input stream). But required field. # Empty list (no values was in input stream). But
# required field.
setattr(message, name, [])
else: else:
raise ValueError('The field %s (\'%s\') is required but missing.' % (tag, name)) raise ValueError(
'The field %s (\'%s\') is required but missing.' % (tag, name))
return message return message
def dumps(self, value): def dumps(self, value):
@ -198,6 +214,7 @@ class MessageType:
def __repr__(self): def __repr__(self):
return '<MessageType: %s>' % self.__name return '<MessageType: %s>' % self.__name
class Message: class Message:
# Represents a message instance. # Represents a message instance.
@ -219,10 +236,11 @@ class Message:
def __repr__(self): def __repr__(self):
values = self.__dict__ values = self.__dict__
values = {k:values[k] for k in values if k != 'message_type'} values = {k: values[k] for k in values if k != 'message_type'}
return '<%s: %s>' % (self.message_type.__name, values) return '<%s: %s>' % (self.message_type.__name, values)
# Embedded message. ------------------------------------------------------------
# Embedded message. ------------------------------------------------------
class EmbeddedMessage: class EmbeddedMessage:
# Represents an embedded message type. # Represents an embedded message type.
@ -230,7 +248,8 @@ class EmbeddedMessage:
WIRE_TYPE = 2 WIRE_TYPE = 2
def __init__(self, message_type): def __init__(self, message_type):
# Initializes a new instance. The argument is an underlying message type. # Initializes a new instance. The argument is an underlying message
# type.
self.message_type = message_type self.message_type = message_type
def __call__(self): def __call__(self):
@ -241,4 +260,4 @@ class EmbeddedMessage:
BytesType.dump(fp, self.message_type.dumps(value)) BytesType.dump(fp, self.message_type.dumps(value))
def load(self, fp): def load(self, fp):
return self.message_type.load(EofWrapper(fp, UVarintType.load(fp))) # Limit with embedded message length. return self.message_type.load(EofWrapper(fp, UVarintType.load(fp)))

View File

@ -1,5 +1,5 @@
from . import display, in_area, rotate_coords from . import display, in_area, rotate_coords
from trezor import ui, loop, res from trezor import ui, loop
DEFAULT_BUTTON = { DEFAULT_BUTTON = {

View File

@ -46,7 +46,8 @@ def read_wire_msg():
mlen += 4 # Account for the checksum mlen += 4 # Account for the checksum
data = rep[13:][:mlen] # Skip magic and header, trim to data len data = rep[13:][:mlen] # Skip magic and header, trim to data len
remaining = mlen - len(data) remaining = mlen - len(data)
buffered = bytearray(data) if remaining > 0 else data # Avoid the copy if we don't append # Avoid the copy if we don't append
buffered = bytearray(data) if remaining > 0 else data
while remaining > 0: while remaining > 0:
rep = yield from _read_report() rep = yield from _read_report()
@ -71,8 +72,7 @@ def read_wire_msg():
def write_wire_msg(sid, mtype, mbuf): def write_wire_msg(sid, mtype, mbuf):
rep = bytearray(_REPORT_LEN) rep = bytearray(_REPORT_LEN)
rep[0] = _HEADER_MAGIC ustruct.pack_into('>BLLL', rep, 0, _HEADER_MAGIC, sid, mtype, len(mbuf))
ustruct.pack_into('>LLL', rep, 1, sid, mtype, len(mbuf))
rep = memoryview(rep) rep = memoryview(rep)
mbuf = memoryview(mbuf) mbuf = memoryview(mbuf)