diff --git a/src/apps/management/layout_reset_device.py b/src/apps/management/layout_reset_device.py index e25c9a510..ea852ee3a 100644 --- a/src/apps/management/layout_reset_device.py +++ b/src/apps/management/layout_reset_device.py @@ -1,4 +1,5 @@ from trezor import wire, ui +from trezor.workflows.request_pin import request_new_pin from trezor.messages.wire_types import EntropyAck from trezor.ui.button import Button, CONFIRM_BUTTON, CONFIRM_BUTTON_ACTIVE from trezor.ui.scroll import paginate, render_scrollbar, animate_swipe @@ -6,75 +7,72 @@ from trezor.crypto import hashlib, random, bip39 from trezor.utils import unimport, chunks +@unimport async def generate_mnemonic(strength, display_random, session_id): from trezor.messages.EntropyRequest import EntropyRequest + from trezor.messages.FailureType import Other - await wire.write_message(session_id, EntropyRequest()) - ack = await wire.read_message(session_id, EntropyAck) + if strength not in (128, 192, 256): + raise wire.FailureError(Other, 'Invalid seed strength') + # if display_random: + # raise wire.FailureError(Other, 'Entropy display not implemented') + + ack = await wire.reply_message(session_id, + EntropyRequest(), + EntropyAck) + + strength_bytes = strength // 8 ctx = hashlib.sha256() - ctx.update(random.bytes(32)) - ctx.update(ack.entropy) + ctx.update(random.bytes(strength_bytes)) + ctx.update(ack.entropy[:strength_bytes]) entropy = ctx.digest() - # TODO: handle strength - # TODO: handle display_random - return bip39.from_data(entropy) -async def request_new_pin(): - from trezor.workflows.request_pin import request_pin - - pin = await request_pin() - pin_again = await request_pin('Enter PIN again') - - if pin == pin_again: - return pin +async def show_mnemonic_page(page, page_count, mnemonic): + ui.clear() + ui.display.text(10, 30, 'Write down your seed', + ui.BOLD, ui.LIGHT_GREEN, ui.BLACK) + render_scrollbar(page, page_count) + + for pi, (wi, word) in enumerate(mnemonic[page]): + top = pi * 30 + 74 + pos = wi + 1 + ui.display.text_right(40, top, '%d.' % pos, + ui.BOLD, ui.LIGHT_GREEN, ui.BLACK) + ui.display.text(45, top, '%s' % word, + ui.BOLD, ui.WHITE, ui.BLACK) + + if page + 1 == page_count: + await Button((0, 240 - 48, 240, 48), 'Finish', + normal_style=CONFIRM_BUTTON, + active_style=CONFIRM_BUTTON_ACTIVE) else: - raise Exception() # TODO: wrong PIN should be handled in unified way + await animate_swipe() async def show_mnemonic(mnemonic): + first_page = const(0) words_per_page = const(4) - mnemonic_words = list(enumerate(mnemonic.split())) - mnemonic_pages = list(chunks(mnemonic_words, words_per_page)) - - async def render(page, page_count): - - # render header & scrollbar - ui.clear() - ui.display.text(10, 30, 'Write down your seed', - ui.BOLD, ui.LIGHT_GREEN, ui.BLACK) - render_scrollbar(page, page_count) - - # render mnemonic page - for pi, (wi, word) in enumerate(mnemonic_pages[page]): - top = pi * 30 + 74 - pos = wi + 1 - ui.display.text_right(40, top, '%d.' % - pos, ui.BOLD, ui.LIGHT_GREEN, ui.BLACK) - ui.display.text(45, top, '%s' % word, ui.BOLD, ui.WHITE, ui.BLACK) - - if page + 1 == page_count: - await Button((0, 240 - 48, 240, 48), 'Finish', - normal_style=CONFIRM_BUTTON, - active_style=CONFIRM_BUTTON_ACTIVE) - else: - await animate_swipe() - - await paginate(render, len(mnemonic_pages)) + words = list(enumerate(mnemonic.split())) + pages = list(chunks(words, words_per_page)) + await paginate(show_mnemonic_page, len(pages), first_page, pages) +@unimport async def layout_reset_device(message, session_id): - # TODO: Failure if not empty + from trezor.messages.Success import Success - mnemonic = await generate_mnemonic( - message.strength, message.display_random, session_id) + mnemonic = await generate_mnemonic(message.strength, + message.display_random, + session_id) + # await show_mnemonic(mnemonic) - # if m.pin_protection: - # pin = yield from request_new_pin() - # else: - # pin = None + if message.pin_protection: + pin = await request_new_pin(session_id) + else: + pin = None - await show_mnemonic(mnemonic) + return Success() diff --git a/src/trezor/ui/scroll.py b/src/trezor/ui/scroll.py index c3fdb52bf..b67f8b6e4 100644 --- a/src/trezor/ui/scroll.py +++ b/src/trezor/ui/scroll.py @@ -11,10 +11,10 @@ async def change_page(page, page_count): return page - 1 # scroll up -async def paginate(render_page, page_count, page=0): +async def paginate(render_page, page_count, page=0, *args): while True: changer = change_page(page, page_count) - renderer = render_page(page, page_count) + renderer = render_page(page, page_count, *args) waiter = loop.Wait([changer, renderer]) result = await waiter if changer in waiter.finished: @@ -41,7 +41,7 @@ def render_scrollbar(page, page_count): if page_count * padding > screen_height: padding = screen_height // page_count - x = 225 + x = const(225) y = (screen_height // 2) - (page_count // 2) * padding for i in range(0, page_count): diff --git a/src/trezor/wire/__init__.py b/src/trezor/wire/__init__.py index 61d5270ff..66896b9dd 100644 --- a/src/trezor/wire/__init__.py +++ b/src/trezor/wire/__init__.py @@ -40,8 +40,7 @@ def close_session(session_id): def register_type(wire_type, genfunc, *args): if wire_type in _workflow_genfuncs: raise KeyError('message of type %d already registered' % wire_type) - log.info(__name__, 'registering %s for type %d', - (genfunc, args), wire_type) + log.info(__name__, 'registering message type %d', wire_type) _workflow_genfuncs[wire_type] = (genfunc, args) @@ -49,8 +48,8 @@ def register_session(session_id, handler): if session_id not in _opened_sessions: raise KeyError('session %d is unknown' % session_id) if session_id in _session_handlers: - raise KeyError('session %d is already registered' % session_id) - log.info(__name__, 'registering %s for session %d', handler, session_id) + raise KeyError('session %d is already being listened on' % session_id) + log.info(__name__, 'listening on session %d', session_id) _session_handlers[session_id] = handler @@ -78,6 +77,7 @@ def setup(): async def read_message(session_id, *exp_types): + log.info(__name__, 'reading message, one of %s', exp_types) future = Future() wire_decoder = decode_wire_stream( _dispatch_and_build_protobuf, session_id, exp_types, future) @@ -87,6 +87,7 @@ async def read_message(session_id, *exp_types): async def write_message(session_id, pbuf_message): + log.info(__name__, 'writing message %s', pbuf_message) msg_data = await pbuf_message.dumps() msg_type = pbuf_message.message_type.wire_type writer = write_report_stream() @@ -94,9 +95,56 @@ async def write_message(session_id, pbuf_message): encode_wire_message(msg_type, msg_data, session_id, writer) +async def reply_message(session_id, pbuf_message, *exp_types): + await write_message(session_id, pbuf_message) + return await read_message(session_id, *exp_types) + + +class FailureError(Exception): + + def __init__(self, code, message): + super(FailureError, self).__init__(code, message) + + def to_protobuf(self): + from trezor.messages.Failure import Failure + return Failure(code=self.args[0], + message=self.args[1]) + + +async def monitor_workflow(workflow, session_id): + try: + result = await workflow + + except FailureError as e: + await write_message(session_id, e.to_protobuf()) + raise + + except Exception as e: + from trezor.messages.Failure import Failure + from trezor.messages.FailureType import FirmwareError + await write_message(session_id, + Failure(code=FirmwareError, + message='Firmware Error')) + raise + + else: + if result is not None: + await write_message(session_id, result) + return result + + finally: + if session_id in _opened_sessions: + wire_decoder = decode_wire_stream( + _handle_registered_type, session_id) + wire_decoder.send(None) + register_session(session_id, wire_decoder) + + def protobuf_handler(msg_type, data_len, session_id, callback, *args): def finalizer(message): - start_workflow(callback(message, session_id, *args)) + workflow = callback(message, session_id, *args) + monitored = monitor_workflow(workflow, session_id) + start_workflow(monitored) pbuf_type = get_protobuf_type(msg_type) builder = build_protobuf_message(pbuf_type, finalizer) builder.send(None) @@ -125,10 +173,6 @@ def _handle_unknown_session(): yield # TODO -class UnexpectedMessageError(Exception): - pass - - def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, future): if msg_type in exp_types: pbuf_type = get_protobuf_type(msg_type) @@ -136,18 +180,20 @@ def _dispatch_and_build_protobuf(msg_type, data_len, session_id, exp_types, futu builder.send(None) return pbuf_type.load(builder) else: - future.resolve(UnexpectedMessageError(msg_type)) + from trezor.messages.FailureType import UnexpectedMessage + future.resolve(FailureError(UnexpectedMessage, 'Unexpected message')) return _handle_registered_type(msg_type, data_len, session_id) def _handle_registered_type(msg_type, data_len, session_id): - genfunc, args = _workflow_genfuncs.get( - msg_type, (_handle_unexpected_type, ())) + fallback = (_handle_unexpected_type, ()) + genfunc, args = _workflow_genfuncs.get(msg_type, fallback) return genfunc(msg_type, data_len, session_id, *args) def _handle_unexpected_type(msg_type, data_len, session_id): - log.info(__name__, 'skipping message %d of len %d' % (msg_type, data_len)) + log.info(__name__, 'skipping message %d of len %d on session %d' % + (msg_type, data_len, session_id)) try: while True: yield diff --git a/src/trezor/workflows/request_pin.py b/src/trezor/workflows/request_pin.py index 1dd42d219..e057aef61 100644 --- a/src/trezor/workflows/request_pin.py +++ b/src/trezor/workflows/request_pin.py @@ -1,58 +1,37 @@ from trezor import ui from trezor import wire -from trezor import config from trezor.utils import unimport -MANAGEMENT_APP = const(1) -PASSPHRASE_PROTECT = (1) # 0 | 1 -PIN_PROTECT = const(2) # 0 | 1 -PIN = const(4) # str - - -def prompt_pin(*args, **kwargs): - from trezor.ui.pin import PinMatrix +@unimport +async def request_pin(session_id, *args, **kwargs): + from trezor.messages.ButtonRequest import ButtonRequest + from trezor.messages.ButtonRequestType import ProtectCall + from trezor.messages.FailureType import PinCancelled + from trezor.messages.wire_types import ButtonAck from trezor.ui.confirm import ConfirmDialog, CONFIRMED + from trezor.ui.pin import PinMatrix - ui.clear() + await wire.reply_message(session_id, + ButtonRequest(code=ProtectCall), + ButtonAck) + ui.clear() matrix = PinMatrix(*args, **kwargs) dialog = ConfirmDialog(matrix) - result = yield from dialog.wait() + if await dialog != CONFIRMED: + raise wire.FailureError(PinCancelled, 'PIN cancelled') - return matrix.pin if result == CONFIRMED else None + return matrix.pin -def request_pin(*args, **kwargs): - from trezor.messages.ButtonRequest import ButtonRequest - from trezor.messages.ButtonRequestType import ProtectCall - from trezor.messages.ButtonAck import ButtonAck - - ack = yield from wire.call(ButtonRequest(code=ProtectCall), ButtonAck) - pin = yield from prompt_pin(*args, **kwargs) - - return pin - - -def change_pin(): - pass - - -def protect_with_pin(): - from trezor.messages.Failure import Failure +@unimport +async def request_new_pin(session_id): from trezor.messages.FailureType import PinInvalid - from trezor.messages.FailureType import ActionCancelled - - pin_protect = config.get(MANAGEMENT_APP, PIN_PROTECT) - if not pin_protect: - return - entered_pin = yield from request_pin() - if entered_pin is None: - yield from wire.write(Failure(code=ActionCancelled, message='Cancelled')) - raise Exception('Cancelled') + pin_first = await request_pin(session_id) + pin_again = await request_pin(session_id, 'Enter PIN again') + if pin_first != pin_again: + raise wire.FailureError(PinInvalid, 'PIN invalid') - stored_pin = config.get(MANAGEMENT_APP, PIN) - if stored_pin != entered_pin: - yield from wire.write(Failure(code=PinInvalid, message='PIN invalid')) - raise Exception('PIN invalid') + return pin_first