diff --git a/src/apps/common/confirm.py b/src/apps/common/confirm.py index 4560faa952..c276ca7317 100644 --- a/src/apps/common/confirm.py +++ b/src/apps/common/confirm.py @@ -19,9 +19,9 @@ async def confirm(ctx, content, code=None, *args, **kwargs): dialog = ConfirmDialog(content, *args, **kwargs) if __debug__: - waiter = loop.wait(signal, dialog) + waiter = ctx.wait(signal, dialog) else: - waiter = dialog + waiter = ctx.wait(dialog) return await waiter == CONFIRMED @@ -34,9 +34,9 @@ async def hold_to_confirm(ctx, content, code=None, *args, **kwargs): dialog = HoldToConfirmDialog(content, 'Hold to confirm', *args, **kwargs) if __debug__: - waiter = loop.wait(signal, dialog) + waiter = ctx.wait(signal, dialog) else: - waiter = dialog + waiter = ctx.wait(dialog) return await waiter == CONFIRMED diff --git a/src/apps/common/request_passphrase.py b/src/apps/common/request_passphrase.py index 73aba9f686..0f42f98a6d 100644 --- a/src/apps/common/request_passphrase.py +++ b/src/apps/common/request_passphrase.py @@ -1,4 +1,4 @@ -from trezor import ui, wire +from trezor import loop, ui, wire from trezor.messages import ButtonRequestType, wire_types from trezor.messages.ButtonRequest import ButtonRequest from trezor.messages.FailureType import ActionCancelled, ProcessError @@ -24,7 +24,8 @@ async def request_passphrase_entry(ctx): if ack.MESSAGE_WIRE_TYPE == wire_types.Cancel: raise wire.FailureError(ActionCancelled, 'Passphrase cancelled') - return await EntrySelector(text) + selector = EntrySelector(text) + return await ctx.wait(selector) @ui.layout @@ -43,7 +44,8 @@ async def request_passphrase_ack(ctx, on_device): if on_device: if ack.passphrase is not None: raise wire.FailureError(ProcessError, 'Passphrase provided when it should not be') - passphrase = await PassphraseKeyboard('Enter passphrase') + keyboard = PassphraseKeyboard('Enter passphrase') + passphrase = await ctx.wait(keyboard) if passphrase == CANCELLED: raise wire.FailureError(ActionCancelled, 'Passphrase cancelled') else: diff --git a/src/apps/management/change_pin.py b/src/apps/management/change_pin.py index 84b978e6b3..bf6bc4192c 100644 --- a/src/apps/management/change_pin.py +++ b/src/apps/management/change_pin.py @@ -66,7 +66,7 @@ async def request_pin_ack(ctx, *args, **kwargs): # TODO: send PinMatrixRequest here, with specific code? await ctx.call(ButtonRequest(code=Other), wire_types.ButtonAck) try: - return await request_pin(*args, **kwargs) + return await ctx.wait(request_pin(*args, **kwargs)) except PinCancelled: raise wire.FailureError(FailureType.ActionCancelled, 'Cancelled') diff --git a/src/apps/management/recovery_device.py b/src/apps/management/recovery_device.py index 34913f8870..41bfdde40f 100644 --- a/src/apps/management/recovery_device.py +++ b/src/apps/management/recovery_device.py @@ -60,7 +60,7 @@ async def request_wordcount(ctx): content = Text('Device recovery', ui.ICON_RECOVERY, 'Number of words?') select = WordSelector(content) - count = await select + count = await ctx.wait(select) return count @@ -73,7 +73,7 @@ async def request_mnemonic(ctx, count: int) -> str: board = MnemonicKeyboard() for i in range(count): board.prompt = 'Type the %s word:' % format_ordinal(i + 1) - word = await board + word = await ctx.wait(board) words.append(word) return ' '.join(words) diff --git a/src/apps/management/reset_device.py b/src/apps/management/reset_device.py index 2bebb71155..b978ecdeb9 100644 --- a/src/apps/management/reset_device.py +++ b/src/apps/management/reset_device.py @@ -155,7 +155,8 @@ async def show_mnemonic(ctx, mnemonic: str): words_per_page = const(4) words = list(enumerate(mnemonic.split())) pages = list(chunks(words, words_per_page)) - await paginate(show_mnemonic_page, len(pages), first_page, pages) + paginator = paginate(show_mnemonic_page, len(pages), first_page, pages) + await ctx.wait(paginator) @ui.layout @@ -171,13 +172,22 @@ async def show_mnemonic_page(page: int, page_count: int, pages: list): await animate_swipe() -@ui.layout async def check_mnemonic(ctx, mnemonic: str) -> bool: words = mnemonic.split() - index = random.uniform(len(words) // 2) # first half - result = await MnemonicKeyboard('Type the %s word:' % format_ordinal(index + 1)) - if result != words[index]: + + # check a word from the first half + index = random.uniform(len(words) // 2) + if not await check_word(ctx, words, index): return False - index = len(words) // 2 + random.uniform(len(words) // 2) # second half - result = await MnemonicKeyboard('Type the %s word:' % format_ordinal(index + 1)) + + # check a word from the second half + index = random.uniform(len(words) // 2) + len(words) // 2 + if not await check_word(ctx, words, index): + return False + + +@ui.layout +async def check_word(ctx, words: list, index: int): + keyboard = MnemonicKeyboard('Type the %s word:' % format_ordinal(index + 1)) + result = await ctx.wait(keyboard) return result == words[index] diff --git a/src/trezor/wire/__init__.py b/src/trezor/wire/__init__.py index 7e7d816950..fef1ac53bc 100644 --- a/src/trezor/wire/__init__.py +++ b/src/trezor/wire/__init__.py @@ -19,11 +19,7 @@ def register(mtype, handler, *args): def setup(iface): '''Initialize the wire stack on passed USB interface.''' - # session_supervisor = codec_v2.SesssionSupervisor(iface, session_handler) - # session_supervisor.open(codec_v1.SESSION_ID) - # loop.schedule(session_supervisor.listen()) - handler = session_handler(iface, codec_v1.SESSION_ID) - loop.schedule(handler) + loop.schedule(session_handler(iface, codec_v1.SESSION_ID)) class Context: @@ -81,6 +77,14 @@ class Context: await protobuf.dump_message(writer, msg) await writer.aclose() + def wait(self, *tasks): + ''' + Wait until one of the passed tasks finishes, and return the result, + while servicing the wire context. If a message comes until one of the + tasks ends, `UnexpectedMessageError` is raised. + ''' + return loop.wait(self.read(()), *tasks) + def getreader(self): return codec_v1.Reader(self.iface)