From bb2556a22cd3ffdc612a06da7d40eeb36a38de10 Mon Sep 17 00:00:00 2001 From: Jan Pochyla Date: Tue, 20 Aug 2019 16:20:02 +0200 Subject: [PATCH] core: improve code documentation, simplify wire package * docs: improve loop.py, ui.init * docs: improve trezor.loop, rename spawn to race * docs: wire * core/wire: simplify and document the session handler * core/wire: improve documentation * core/wire: improve docs * core/docs: document ui.grid function * core: decouple ui and workflow, document both * core: improve docs Co-authored-by: Tomas Susanka --- core/src/apps/common/confirm.py | 4 +- core/src/apps/common/mnemonic.py | 5 +- core/src/apps/common/request_pin.py | 2 +- core/src/apps/debug/__init__.py | 10 +- core/src/apps/homescreen/__init__.py | 12 +- core/src/apps/management/common/layout.py | 2 +- .../recovery_device/bip39_keyboard.py | 14 +- .../apps/management/recovery_device/layout.py | 2 +- .../recovery_device/slip39_keyboard.py | 14 +- core/src/apps/monero/layout/common.py | 2 +- core/src/apps/monero/layout/confirms.py | 6 +- core/src/apps/webauthn/__init__.py | 6 +- core/src/main.py | 2 +- core/src/trezor/loop.py | 68 ++- core/src/trezor/ui/__init__.py | 176 +++++-- core/src/trezor/ui/button.py | 2 +- core/src/trezor/ui/checklist.py | 2 +- core/src/trezor/ui/confirm.py | 4 +- core/src/trezor/ui/container.py | 4 +- core/src/trezor/ui/loader.py | 2 +- core/src/trezor/ui/passphrase.py | 16 +- core/src/trezor/ui/pin.py | 2 +- core/src/trezor/ui/popup.py | 2 +- core/src/trezor/ui/qr.py | 2 +- core/src/trezor/ui/scroll.py | 10 +- core/src/trezor/ui/shamir.py | 2 +- core/src/trezor/ui/swipe.py | 2 +- core/src/trezor/ui/text.py | 4 +- core/src/trezor/ui/word_select.py | 2 +- core/src/trezor/wire/__init__.py | 444 +++++++++++------- core/src/trezor/wire/codec_v1.py | 10 +- core/src/trezor/workflow.py | 94 ++-- 32 files changed, 589 insertions(+), 340 deletions(-) diff --git a/core/src/apps/common/confirm.py b/core/src/apps/common/confirm.py index 85e4aa16e2..7605444af1 100644 --- a/core/src/apps/common/confirm.py +++ b/core/src/apps/common/confirm.py @@ -16,7 +16,7 @@ if False: async def confirm( ctx: wire.Context, - content: ui.Control, + content: ui.Component, code: int = ButtonRequestType.Other, confirm: ButtonContent = Confirm.DEFAULT_CONFIRM, confirm_style: ButtonStyleType = Confirm.DEFAULT_CONFIRM_STYLE, @@ -49,7 +49,7 @@ async def confirm( async def hold_to_confirm( ctx: wire.Context, - content: ui.Control, + content: ui.Component, code: int = ButtonRequestType.Other, confirm: ButtonContent = HoldToConfirm.DEFAULT_CONFIRM, confirm_style: ButtonStyleType = HoldToConfirm.DEFAULT_CONFIRM_STYLE, diff --git a/core/src/apps/common/mnemonic.py b/core/src/apps/common/mnemonic.py index 77a91bbe80..a86ecc8e9f 100644 --- a/core/src/apps/common/mnemonic.py +++ b/core/src/apps/common/mnemonic.py @@ -70,7 +70,10 @@ def type_from_word_count(count: int) -> int: def _start_progress() -> None: - workflow.closedefault() + # Because we are drawing to the screen manually, without a layout, we + # should make sure that no other layout is running. At this point, only + # the homescreen should be on, so shut it down. + workflow.close_default() ui.backlight_fade(ui.BACKLIGHT_DIM) ui.display.clear() ui.header("Please wait") diff --git a/core/src/apps/common/request_pin.py b/core/src/apps/common/request_pin.py index c00afcf222..b27acf9228 100644 --- a/core/src/apps/common/request_pin.py +++ b/core/src/apps/common/request_pin.py @@ -25,7 +25,7 @@ async def request_pin( while True: if __debug__: - result = await loop.spawn(dialog, input_signal) + result = await loop.race(dialog, input_signal) else: result = await dialog if result is CANCELLED: diff --git a/core/src/apps/debug/__init__.py b/core/src/apps/debug/__init__.py index 96ff2151d8..84bd2faaf1 100644 --- a/core/src/apps/debug/__init__.py +++ b/core/src/apps/debug/__init__.py @@ -6,7 +6,7 @@ if not __debug__: if __debug__: from trezor import config, log, loop, utils from trezor.messages import MessageType - from trezor.wire import register, protobuf_workflow + from trezor.wire import register if False: from typing import Optional @@ -79,9 +79,5 @@ if __debug__: if not utils.EMULATOR: config.wipe() - register( - MessageType.DebugLinkDecision, protobuf_workflow, dispatch_DebugLinkDecision - ) - register( - MessageType.DebugLinkGetState, protobuf_workflow, dispatch_DebugLinkGetState - ) + register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision) + register(MessageType.DebugLinkGetState, dispatch_DebugLinkGetState) diff --git a/core/src/apps/homescreen/__init__.py b/core/src/apps/homescreen/__init__.py index ae82826e79..f80e08fbfc 100644 --- a/core/src/apps/homescreen/__init__.py +++ b/core/src/apps/homescreen/__init__.py @@ -2,7 +2,7 @@ from trezor import config, utils, wire from trezor.messages import MessageType from trezor.messages.Features import Features from trezor.messages.Success import Success -from trezor.wire import protobuf_workflow, register +from trezor.wire import register from apps.common import cache, storage @@ -75,9 +75,9 @@ async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success: def boot(features_only: bool = False) -> None: - register(MessageType.Initialize, protobuf_workflow, handle_Initialize) - register(MessageType.GetFeatures, protobuf_workflow, handle_GetFeatures) + register(MessageType.Initialize, handle_Initialize) + register(MessageType.GetFeatures, handle_GetFeatures) if not features_only: - register(MessageType.Cancel, protobuf_workflow, handle_Cancel) - register(MessageType.ClearSession, protobuf_workflow, handle_ClearSession) - register(MessageType.Ping, protobuf_workflow, handle_Ping) + register(MessageType.Cancel, handle_Cancel) + register(MessageType.ClearSession, handle_ClearSession) + register(MessageType.Ping, handle_Ping) diff --git a/core/src/apps/management/common/layout.py b/core/src/apps/management/common/layout.py index 4668e01e16..cb386acb37 100644 --- a/core/src/apps/management/common/layout.py +++ b/core/src/apps/management/common/layout.py @@ -424,7 +424,7 @@ def _slip39_split_share_into_pages(share_words): return first, list(chunks), last -class ShamirNumInput(ui.Control): +class ShamirNumInput(ui.Component): SET_SHARES = object() SET_THRESHOLD = object() diff --git a/core/src/apps/management/recovery_device/bip39_keyboard.py b/core/src/apps/management/recovery_device/bip39_keyboard.py index 2f67c45cfd..0dcee6f351 100644 --- a/core/src/apps/management/recovery_device/bip39_keyboard.py +++ b/core/src/apps/management/recovery_device/bip39_keyboard.py @@ -85,7 +85,7 @@ class InputButton(Button): display.icon(ix, iy, self.icon, fg_color, bg_color) -class Prompt(ui.Control): +class Prompt(ui.Component): def __init__(self, prompt: str) -> None: self.prompt = prompt self.repaint = True @@ -192,17 +192,17 @@ class Bip39Keyboard(ui.Layout): async def handle_input(self) -> None: touch = loop.wait(io.TOUCH) timeout = loop.sleep(1000 * 1000 * 1) - spawn_touch = loop.spawn(touch) - spawn_timeout = loop.spawn(touch, timeout) + race_touch = loop.race(touch) + race_timeout = loop.race(touch, timeout) while True: if self.pending_button is not None: - spawn = spawn_timeout + race = race_timeout else: - spawn = spawn_touch - result = await spawn + race = race_touch + result = await race - if touch in spawn.finished: + if touch in race.finished: event, x, y = result self.dispatch(event, x, y) else: diff --git a/core/src/apps/management/recovery_device/layout.py b/core/src/apps/management/recovery_device/layout.py index ce32684ab5..384f2c34b2 100644 --- a/core/src/apps/management/recovery_device/layout.py +++ b/core/src/apps/management/recovery_device/layout.py @@ -169,7 +169,7 @@ async def show_identifier_mismatch(ctx: wire.Context) -> None: ) -class RecoveryHomescreen(ui.Control): +class RecoveryHomescreen(ui.Component): def __init__(self, text: str, subtext: str = None): self.text = text self.subtext = subtext diff --git a/core/src/apps/management/recovery_device/slip39_keyboard.py b/core/src/apps/management/recovery_device/slip39_keyboard.py index a3375e45ba..0d5a2d39c0 100644 --- a/core/src/apps/management/recovery_device/slip39_keyboard.py +++ b/core/src/apps/management/recovery_device/slip39_keyboard.py @@ -88,7 +88,7 @@ class InputButton(Button): display.icon(ix, iy, self.icon, fg_color, bg_color) -class Prompt(ui.Control): +class Prompt(ui.Component): def __init__(self, prompt: str) -> None: self.prompt = prompt self.repaint = True @@ -202,17 +202,17 @@ class Slip39Keyboard(ui.Layout): async def handle_input(self) -> None: touch = loop.wait(io.TOUCH) timeout = loop.sleep(1000 * 1000 * 1) - spawn_touch = loop.spawn(touch) - spawn_timeout = loop.spawn(touch, timeout) + race_touch = loop.race(touch) + race_timeout = loop.race(touch, timeout) while True: if self.pending_button is not None: - spawn = spawn_timeout + race = race_timeout else: - spawn = spawn_touch - result = await spawn + race = race_touch + result = await race - if touch in spawn.finished: + if touch in race.finished: event, x, y = result self.dispatch(event, x, y) else: diff --git a/core/src/apps/monero/layout/common.py b/core/src/apps/monero/layout/common.py index b3555030e2..f31b79600c 100644 --- a/core/src/apps/monero/layout/common.py +++ b/core/src/apps/monero/layout/common.py @@ -30,7 +30,7 @@ async def naive_pagination( while True: await ctx.call(ButtonRequest(code=ButtonRequestType.SignTx), ButtonAck) if __debug__: - result = await loop.spawn(paginated, confirm_signal) + result = await loop.race(paginated, confirm_signal) else: result = await paginated if result is CONFIRMED: diff --git a/core/src/apps/monero/layout/confirms.py b/core/src/apps/monero/layout/confirms.py index 41b905b808..88ec28ce8b 100644 --- a/core/src/apps/monero/layout/confirms.py +++ b/core/src/apps/monero/layout/confirms.py @@ -117,7 +117,7 @@ async def _require_confirm_fee(ctx, fee): await require_hold_to_confirm(ctx, content, ButtonRequestType.ConfirmOutput) -class TransactionStep(ui.Control): +class TransactionStep(ui.Component): def __init__(self, state, info): self.state = state self.info = info @@ -133,7 +133,7 @@ class TransactionStep(ui.Control): ui.display.text_center(ui.WIDTH // 2, 235, info[1], ui.NORMAL, ui.FG, ui.BG) -class KeyImageSyncStep(ui.Control): +class KeyImageSyncStep(ui.Component): def __init__(self, current, total_num): self.current = current self.total_num = total_num @@ -146,7 +146,7 @@ class KeyImageSyncStep(ui.Control): ui.display.loader(p, False, 18, ui.WHITE, ui.BG) -class LiveRefreshStep(ui.Control): +class LiveRefreshStep(ui.Component): def __init__(self, current): self.current = current diff --git a/core/src/apps/webauthn/__init__.py b/core/src/apps/webauthn/__init__.py index 50d1bce220..10510616a8 100644 --- a/core/src/apps/webauthn/__init__.py +++ b/core/src/apps/webauthn/__init__.py @@ -377,10 +377,10 @@ class ConfirmState: async def confirm_workflow(self) -> None: try: - workflow.onstart(self.workflow) + workflow.on_start(self.workflow) await self.confirm_layout() finally: - workflow.onclose(self.workflow) + workflow.on_close(self.workflow) self.workflow = None async def confirm_layout(self) -> None: @@ -402,7 +402,7 @@ class ConfirmState: self.confirmed = await dialog is CONFIRMED -class ConfirmContent(ui.Control): +class ConfirmContent(ui.Component): def __init__(self, action: int, app_id: bytes) -> None: self.action = action self.app_id = app_id diff --git a/core/src/main.py b/core/src/main.py index 737661825b..859153b611 100644 --- a/core/src/main.py +++ b/core/src/main.py @@ -70,7 +70,7 @@ def _boot_default() -> None: # run main event loop and specify which screen is the default from apps.homescreen.homescreen import homescreen - workflow.startdefault(homescreen) + workflow.start_default(homescreen) from trezor import loop, wire, workflow diff --git a/core/src/trezor/loop.py b/core/src/trezor/loop.py index db8ec00769..ddaa38b4eb 100644 --- a/core/src/trezor/loop.py +++ b/core/src/trezor/loop.py @@ -4,7 +4,7 @@ the form of python coroutines (either plain generators or `async` functions) are stepped through until completion, and can get asynchronously blocked by `yield`ing or `await`ing a syscall. -See `schedule`, `run`, and syscalls `sleep`, `wait`, `signal` and `spawn`. +See `schedule`, `run`, and syscalls `sleep`, `wait`, `signal` and `race`. """ import utime @@ -57,6 +57,8 @@ def schedule( """ Schedule task to be executed with `value` on given `deadline` (in microseconds). Does not start the event loop itself, see `run`. + Usually done in very low-level cases, see `race` for more user-friendly + and correct concept. """ if deadline is None: deadline = utime.ticks_us() @@ -66,6 +68,11 @@ def schedule( def pause(task: Task, iface: int) -> None: + """ + Block task on given message interface. Task is resumed when the interface + is activated. It is most probably wrong to call `pause` from user code, + see the `wait` syscall for the correct concept. + """ tasks = _paused.get(iface, None) if tasks is None: tasks = _paused[iface] = set() @@ -73,12 +80,17 @@ def pause(task: Task, iface: int) -> None: def finalize(task: Task, value: Any) -> None: + """Call and remove any finalization callbacks registered for given task.""" fn = _finalizers.pop(id(task), None) if fn is not None: fn(task, value) def close(task: Task) -> None: + """ + Deschedule and unblock a task, close it so it can release all resources, and + call its finalizer. + """ for iface in _paused: _paused[iface].discard(task) _queue.discard(task) @@ -137,6 +149,21 @@ def clear() -> None: def _step(task: Task, value: Any) -> None: + """ + Step through the task by sending `value` to `Task`. This can result in either: + 1. The task raises an exception: + a) StopIteration + - The Task is completed and we call finalize to finish it. + b) Exception + - An error occurred. We still need to call finalize. + 2. Task does not raise exception and returns either: + a) Syscall + - Syscall.handle is called. + b) None + - The Task is simply scheduled to continue. + c) Something else + - That should not happen - error. + """ try: if isinstance(value, BaseException): result = task.throw(value) # type: ignore @@ -144,7 +171,7 @@ def _step(task: Task, value: Any) -> None: # rationale: In micropython, generator.throw() accepts the exception object directly. else: result = task.send(value) - except StopIteration as e: # as e: + except StopIteration as e: if __debug__: log.debug(__name__, "finish: %s", task) finalize(task, e.value) @@ -205,7 +232,7 @@ class wait(Syscall): """ Pause current task, and resume only after a message on `msg_iface` is received. Messages are received either from an USB interface, or the - touch display. Result value a tuple of message values. + touch display. Result value is a tuple of message values. Example: @@ -223,29 +250,33 @@ class wait(Syscall): _type_gen = type((lambda: (yield))()) -class spawn(Syscall): +class race(Syscall): """ - Execute one or more children tasks and wait until one of them exits. - Return value of `spawn` is the return value of task that triggered the - completion. By default, `spawn` returns after the first child completes, and - other running children are killed (by cancelling any pending schedules and - calling `close()`). + Given a list of either children tasks or syscalls, `race` waits until one of + them completes (tasks are executed in parallel, syscalls are waited upon, + directly). Return value of `race` is the return value of the child that + triggered the completion. Other running children are killed (by cancelling + any pending schedules and raising a `GeneratorExit` by calling `close()`). + Child that caused the completion is present in `self.finished`. Example: >>> # async def wait_for_touch(): ... >>> # async def animate_logo(): ... + >>> some_signal = loop.signal() >>> touch_task = wait_for_touch() >>> animation_task = animate_logo() - >>> waiter = loop.spawn(touch_task, animation_task) - >>> result = await waiter - >>> if animation_task in waiter.finished: - >>> print('animation task returned', result) + >>> racer = loop.race(some_signal, touch_task, animation_task) + >>> result = await racer + >>> if animation_task in racer.finished: + >>> print('animation task returned value:', result) + >>> elif touch_task in racer.finished: + >>> print('touch task returned value:', result) >>> else: - >>> print('touch task returned', result) + >>> print('signal was triggered with value:', result) - Note: You should not directly `yield` a `spawn` instance, see logic in - `spawn.__iter__` for explanation. Always use `await`. + Note: You should not directly `yield` a `race` instance, see logic in + `race.__iter__` for explanation. Always use `await`. """ def __init__(self, *children: Awaitable, exit_others: bool = True) -> None: @@ -255,6 +286,9 @@ class spawn(Syscall): self.scheduled = [] # type: List[Task] # scheduled wrapper tasks def handle(self, task: Task) -> None: + """ + Schedule all children Tasks and set `task` as callback. + """ finalizer = self._finish scheduled = self.scheduled finished = self.finished @@ -279,6 +313,8 @@ class spawn(Syscall): def _finish(self, task: Task, result: Any) -> None: if not self.finished: + # because we create tasks for children that are not generators yet, + # we need to find the child value that the caller supplied for index, child_task in enumerate(self.scheduled): if child_task is task: child = self.children[index] diff --git a/core/src/trezor/ui/__init__.py b/core/src/trezor/ui/__init__.py index e53f0dbf46..cbe31315f6 100644 --- a/core/src/trezor/ui/__init__.py +++ b/core/src/trezor/ui/__init__.py @@ -3,7 +3,7 @@ import utime from micropython import const from trezorui import Display -from trezor import io, loop, res, utils, workflow +from trezor import io, loop, res, utils if False: from typing import Any, Generator, Iterable, Tuple, TypeVar @@ -12,9 +12,25 @@ if False: Area = Tuple[int, int, int, int] ResultValue = TypeVar("ResultValue") - +# all rendering is done through a singleton of `Display` display = Display() +# re-export constants from modtrezorui +NORMAL = Display.FONT_NORMAL +BOLD = Display.FONT_BOLD +MONO = Display.FONT_MONO +MONO_BOLD = Display.FONT_MONO_BOLD +SIZE = Display.FONT_SIZE +WIDTH = Display.WIDTH +HEIGHT = Display.HEIGHT + +# viewport margins +VIEWX = const(6) +VIEWY = const(9) + +# channel used to cancel layouts, see `Cancelled` exception +layout_chan = loop.chan() + # in debug mode, display an indicator in top right corner if __debug__: @@ -30,19 +46,6 @@ if __debug__: elif utils.EMULATOR: loop.after_step_hook = display.refresh -# re-export constants from modtrezorui -NORMAL = Display.FONT_NORMAL -BOLD = Display.FONT_BOLD -MONO = Display.FONT_MONO -MONO_BOLD = Display.FONT_MONO_BOLD -SIZE = Display.FONT_SIZE -WIDTH = Display.WIDTH -HEIGHT = Display.HEIGHT - -# viewport margins -VIEWX = const(6) -VIEWY = const(9) - def lerpi(a: int, b: int, t: float) -> int: return int(a + t * (b - a)) @@ -67,9 +70,9 @@ from trezor.ui import style # isort:skip from trezor.ui.style import * # isort:skip # noqa: F401,F403 -def pulse(delay: int) -> float: +def pulse(coef: int) -> float: # normalize sin from interval -1:1 to 0:1 - return 0.5 + 0.5 * math.sin(utime.ticks_us() / delay) + return 0.5 + 0.5 * math.sin(utime.ticks_us() / coef) async def click() -> Pos: @@ -111,7 +114,6 @@ def header( def header_warning(message: str, clear: bool = True) -> None: - # TODO: review: is the clear=True really needed? display.bar(0, 0, WIDTH, 30, style.YELLOW) display.text_center(WIDTH // 2, 22, message, BOLD, style.BLACK, style.YELLOW) if clear: @@ -119,7 +121,6 @@ def header_warning(message: str, clear: bool = True) -> None: def header_error(message: str, clear: bool = True) -> None: - # TODO: review: as above display.bar(0, 0, WIDTH, 30, style.RED) display.text_center(WIDTH // 2, 22, message, BOLD, style.WHITE, style.RED) if clear: @@ -127,17 +128,31 @@ def header_error(message: str, clear: bool = True) -> None: def grid( - i: int, - n_x: int = 3, - n_y: int = 5, - start_x: int = VIEWX, - start_y: int = VIEWY, - end_x: int = (WIDTH - VIEWX), - end_y: int = (HEIGHT - VIEWY), - cells_x: int = 1, - cells_y: int = 1, - spacing: int = 0, + i: int, # i-th cell of the table of which we wish to return Area (snake-like starting with 0) + n_x: int = 3, # number of rows in the table + n_y: int = 5, # number of columns in the table + start_x: int = VIEWX, # where the table starts on x-axis + start_y: int = VIEWY, # where the table starts on y-axis + end_x: int = (WIDTH - VIEWX), # where the table ends on x-axis + end_y: int = (HEIGHT - VIEWY), # where the table ends on y-axis + cells_x: int = 1, # number of cells to be merged into one in the direction of x-axis + cells_y: int = 1, # number of cells to be merged into one in the direction of y-axis + spacing: int = 0, # spacing size between cells ) -> Area: + """ + Returns area (tuple of four integers, in pixels) of a cell on i-th possition + in a table you define yourself. Example: + + >>> ui.grid(4, n_x=2, n_y=3, start_x=20, start_y=20) + (20, 160, 107, 70) + + Returns 5th cell from the following table. It has two columns, three rows + and starts on coordinates 20-20. + + |____|____| + |____|____| + |XXXX|____| + """ w = (end_x - start_x) // n_x h = (end_y - start_y) // n_y x = (i % n_x) * w @@ -150,12 +165,30 @@ def in_area(area: Area, x: int, y: int) -> bool: return ax <= x <= ax + aw and ay <= y <= ay + ah -# render events +# Component events. Should be different from `io.TOUCH_*` events. +# Event dispatched when components should draw to the display, if they are +# marked for re-paint. RENDER = const(-255) +# Event dispatched when components should mark themselves for re-painting. REPAINT = const(-256) +# How long, in microseconds, should the layout rendering task sleep betweeen +# the render calls. +_RENDER_DELAY_US = const(10000) # 10 msec + + +class Component: + """ + Abstract class. + + Components are GUI classes that inherit `Component` and form a tree, with a + `Layout` at the root, and other components underneath. Components that + have children, and therefore need to dispatch events to them, usually + override the `dispatch` method. Leaf components usually override the event + methods (`on_*`). Components signal a completion to the layout by raising + an instance of `Result`. + """ -class Control: def dispatch(self, event: int, x: int, y: int) -> None: if event is RENDER: self.on_render() @@ -181,58 +214,107 @@ class Control: pass -_RENDER_DELAY_US = const(10000) # 10 msec - - -class LayoutCancelled(Exception): - pass - - class Result(Exception): + """ + When components want to trigger layout completion, they do so through + raising an instance of `Result`. + + See `Layout.__iter__` for details. + """ + def __init__(self, value: ResultValue) -> None: self.value = value -class Layout(Control): +class Cancelled(Exception): """ + Layouts can be explicitly cancelled. This usually happens when another + layout starts, because only one layout can be running at the same time, + and is done by raising `Cancelled` on the cancelled layout. Layouts + should always re-raise such exceptions. + + See `Layout.__iter__` for details. + """ + + pass + + +class Layout(Component): + """ + Abstract class. + + Layouts are top-level components. Only one layout can be running at the + same time. Layouts provide asynchronous interface, so a running task can + wait for the layout to complete. Layouts complete when a `Result` is + raised, usually from some of the child components. """ async def __iter__(self) -> ResultValue: + """ + Run the layout and wait until it completes. Returns the result value. + Usually not overriden. + """ value = None try: - if workflow.layout_signal.takers: - await workflow.layout_signal.put(LayoutCancelled()) - workflow.onlayoutstart(self) + # If any other layout is running (waiting on the layout channel), + # we close it with the Cancelled exception, and wait until it is + # closed, just to be sure. + if layout_chan.takers: + await layout_chan.put(Cancelled()) + # Now, no other layout should be running. In a loop, we create new + # layout tasks and execute them in parallel, while waiting on the + # layout channel. This allows other layouts to cancel us, and the + # layout tasks to trigger restart by exiting (new tasks are created + # and we continue, because we are in a loop). while True: - layout_tasks = self.create_tasks() - await loop.spawn(workflow.layout_signal.take, *layout_tasks) + await loop.race(layout_chan.take, *self.create_tasks()) except Result as result: + # Result exception was raised, this means this layout is complete. value = result.value - finally: - workflow.onlayoutclose(self) return value def __await__(self) -> Generator[Any, Any, ResultValue]: return self.__iter__() # type: ignore def create_tasks(self) -> Iterable[loop.Task]: + """ + Called from `__iter__`. Creates and returns a sequence of tasks that + run this layout. Tasks are executed in parallel. When one of them + returns, the others are closed and `create_tasks` is called again. + + Usually overriden to add another task to the list.""" return self.handle_input(), self.handle_rendering() def handle_input(self) -> loop.Task: # type: ignore + """Task that is waiting for the user input.""" touch = loop.wait(io.TOUCH) while True: event, x, y = yield touch self.dispatch(event, x, y) + # We dispatch a render event right after the touch. Quick and dirty + # way to get the lowest input-to-render latency. self.dispatch(RENDER, 0, 0) def handle_rendering(self) -> loop.Task: # type: ignore + """Task that is rendering the layout in a busy loop.""" + # Before the first render, we dim the display. backlight_fade(style.BACKLIGHT_DIM) + # Clear the screen of any leftovers, make sure everything is marked for + # repaint (we can be running the same layout instance multiple times) + # and paint it. display.clear() self.dispatch(REPAINT, 0, 0) self.dispatch(RENDER, 0, 0) + # Display is usually refreshed after every loop step, but here we are + # rendering everything synchronously, so refresh it manually and turn + # the brightness on again. display.refresh() backlight_fade(style.BACKLIGHT_NORMAL) sleep = loop.sleep(_RENDER_DELAY_US) while True: - self.dispatch(RENDER, 0, 0) + # Wait for a couple of ms and render the layout again. Because + # components use re-paint marking, they do not really draw on the + # display needlessly. + # TODO: remove the busy loop yield sleep + self.dispatch(RENDER, 0, 0) diff --git a/core/src/trezor/ui/button.py b/core/src/trezor/ui/button.py index c433185b34..411c2c4a7b 100644 --- a/core/src/trezor/ui/button.py +++ b/core/src/trezor/ui/button.py @@ -118,7 +118,7 @@ _ICON = const(16) # icon size in pixels _BORDER = const(4) # border size in pixels -class Button(ui.Control): +class Button(ui.Component): def __init__( self, area: ui.Area, diff --git a/core/src/trezor/ui/checklist.py b/core/src/trezor/ui/checklist.py index 2c957058d7..8853cfe5e5 100644 --- a/core/src/trezor/ui/checklist.py +++ b/core/src/trezor/ui/checklist.py @@ -13,7 +13,7 @@ _CHECKLIST_OFFSET_X = const(24) _CHECKLIST_OFFSET_X_ICON = const(0) -class Checklist(ui.Control): +class Checklist(ui.Component): def __init__(self, title: str, icon: str) -> None: self.title = title self.icon = icon diff --git a/core/src/trezor/ui/confirm.py b/core/src/trezor/ui/confirm.py index 772c0183d8..4b4814cad4 100644 --- a/core/src/trezor/ui/confirm.py +++ b/core/src/trezor/ui/confirm.py @@ -19,7 +19,7 @@ class Confirm(ui.Layout): def __init__( self, - content: ui.Control, + content: ui.Component, confirm: Optional[ButtonContent] = DEFAULT_CONFIRM, confirm_style: ButtonStyleType = DEFAULT_CONFIRM_STYLE, cancel: Optional[ButtonContent] = DEFAULT_CANCEL, @@ -75,7 +75,7 @@ class HoldToConfirm(ui.Layout): def __init__( self, - content: ui.Control, + content: ui.Component, confirm: str = DEFAULT_CONFIRM, confirm_style: ButtonStyleType = DEFAULT_CONFIRM_STYLE, loader_style: LoaderStyleType = DEFAULT_LOADER_STYLE, diff --git a/core/src/trezor/ui/container.py b/core/src/trezor/ui/container.py index 476a7f5c9b..2c97bc8616 100644 --- a/core/src/trezor/ui/container.py +++ b/core/src/trezor/ui/container.py @@ -1,8 +1,8 @@ from trezor import ui -class Container(ui.Control): - def __init__(self, *children: ui.Control): +class Container(ui.Component): + def __init__(self, *children: ui.Component): self.children = children def dispatch(self, event: int, x: int, y: int) -> None: diff --git a/core/src/trezor/ui/loader.py b/core/src/trezor/ui/loader.py index bb21955d75..85b24e911b 100644 --- a/core/src/trezor/ui/loader.py +++ b/core/src/trezor/ui/loader.py @@ -37,7 +37,7 @@ if False: _TARGET_MS = const(1000) -class Loader(ui.Control): +class Loader(ui.Component): def __init__(self, style: LoaderStyleType = LoaderDefault) -> None: self.normal_style = style.normal self.active_style = style.active diff --git a/core/src/trezor/ui/passphrase.py b/core/src/trezor/ui/passphrase.py index 0544f973f8..cc344e737f 100644 --- a/core/src/trezor/ui/passphrase.py +++ b/core/src/trezor/ui/passphrase.py @@ -114,7 +114,7 @@ class Input(Button): pass -class Prompt(ui.Control): +class Prompt(ui.Component): def __init__(self, text: str) -> None: self.text = text self.repaint = True @@ -210,17 +210,17 @@ class PassphraseKeyboard(ui.Layout): async def handle_input(self) -> None: touch = loop.wait(io.TOUCH) timeout = loop.sleep(1000 * 1000 * 1) - spawn_touch = loop.spawn(touch) - spawn_timeout = loop.spawn(touch, timeout) + race_touch = loop.race(touch) + race_timeout = loop.race(touch, timeout) while True: if self.pending_button is not None: - spawn = spawn_timeout + race = race_timeout else: - spawn = spawn_touch - result = await spawn + race = race_touch + result = await race - if touch in spawn.finished: + if touch in race.finished: event, x, y = result self.dispatch(event, x, y) else: @@ -249,7 +249,7 @@ class PassphraseKeyboard(ui.Layout): class PassphraseSource(ui.Layout): - def __init__(self, content: ui.Control) -> None: + def __init__(self, content: ui.Component) -> None: self.content = content self.device = Button(ui.grid(8, n_y=4, n_x=4, cells_x=4), "Device") diff --git a/core/src/trezor/ui/pin.py b/core/src/trezor/ui/pin.py index c66c2e2d81..a69cc0526a 100644 --- a/core/src/trezor/ui/pin.py +++ b/core/src/trezor/ui/pin.py @@ -29,7 +29,7 @@ def generate_digits() -> Iterable[int]: return digits[6:] + digits[3:6] + digits[:3] -class PinInput(ui.Control): +class PinInput(ui.Component): def __init__(self, prompt: str, subprompt: str, pin: str) -> None: self.prompt = prompt self.subprompt = subprompt diff --git a/core/src/trezor/ui/popup.py b/core/src/trezor/ui/popup.py index f1ed196a6d..b9ce61e31f 100644 --- a/core/src/trezor/ui/popup.py +++ b/core/src/trezor/ui/popup.py @@ -5,7 +5,7 @@ if False: class Popup(ui.Layout): - def __init__(self, content: ui.Control, time_ms: int = 0) -> None: + def __init__(self, content: ui.Component, time_ms: int = 0) -> None: self.content = content self.time_ms = time_ms diff --git a/core/src/trezor/ui/qr.py b/core/src/trezor/ui/qr.py index bde3de18e9..3389172ae0 100644 --- a/core/src/trezor/ui/qr.py +++ b/core/src/trezor/ui/qr.py @@ -1,7 +1,7 @@ from trezor import ui -class Qr(ui.Control): +class Qr(ui.Component): def __init__(self, data: bytes, x: int, y: int, scale: int): self.data = data self.x = x diff --git a/core/src/trezor/ui/scroll.py b/core/src/trezor/ui/scroll.py index e370211f55..9a3f0a8c5e 100644 --- a/core/src/trezor/ui/scroll.py +++ b/core/src/trezor/ui/scroll.py @@ -46,7 +46,7 @@ def render_swipe_text() -> None: class Paginated(ui.Layout): def __init__( - self, pages: Sequence[ui.Control], page: int = 0, one_by_one: bool = False + self, pages: Sequence[ui.Component], page: int = 0, one_by_one: bool = False ): self.pages = pages self.page = page @@ -77,7 +77,7 @@ class Paginated(ui.Layout): directions = SWIPE_VERTICAL if __debug__: - swipe = await loop.spawn(Swipe(directions), swipe_signal) + swipe = await loop.race(Swipe(directions), swipe_signal) else: swipe = await Swipe(directions) @@ -99,10 +99,10 @@ class Paginated(ui.Layout): raise ui.Result(self.page) -class PageWithButtons(ui.Control): +class PageWithButtons(ui.Component): def __init__( self, - content: ui.Control, + content: ui.Component, paginated: "PaginatedWithButtons", index: int, count: int, @@ -157,7 +157,7 @@ class PageWithButtons(ui.Control): class PaginatedWithButtons(ui.Layout): def __init__( - self, pages: Sequence[ui.Control], page: int = 0, one_by_one: bool = False + self, pages: Sequence[ui.Component], page: int = 0, one_by_one: bool = False ) -> None: self.pages = [ PageWithButtons(p, self, i, len(pages)) for i, p in enumerate(pages) diff --git a/core/src/trezor/ui/shamir.py b/core/src/trezor/ui/shamir.py index 5f77639df8..cbd3bfe87b 100644 --- a/core/src/trezor/ui/shamir.py +++ b/core/src/trezor/ui/shamir.py @@ -3,7 +3,7 @@ from trezor.ui.button import Button from trezor.ui.text import LABEL_CENTER, Label -class NumInput(ui.Control): +class NumInput(ui.Component): def __init__(self, count: int = 5, max_count: int = 16, min_count: int = 1) -> None: self.count = count self.max_count = max_count diff --git a/core/src/trezor/ui/swipe.py b/core/src/trezor/ui/swipe.py index b5ecc25143..ffeae2c3a8 100644 --- a/core/src/trezor/ui/swipe.py +++ b/core/src/trezor/ui/swipe.py @@ -17,7 +17,7 @@ _SWIPE_DISTANCE = const(120) _SWIPE_TRESHOLD = const(30) -class Swipe(ui.Control): +class Swipe(ui.Component): def __init__(self, directions: int = SWIPE_ALL, area: ui.Area = None) -> None: if area is None: area = (0, 0, ui.WIDTH, ui.HEIGHT) diff --git a/core/src/trezor/ui/text.py b/core/src/trezor/ui/text.py index 86eeacac5d..9b059f5c9c 100644 --- a/core/src/trezor/ui/text.py +++ b/core/src/trezor/ui/text.py @@ -120,7 +120,7 @@ def render_text( offset_x += SPACE -class Text(ui.Control): +class Text(ui.Component): def __init__( self, header_text: str, @@ -177,7 +177,7 @@ LABEL_CENTER = const(1) LABEL_RIGHT = const(2) -class Label(ui.Control): +class Label(ui.Component): def __init__( self, area: ui.Area, diff --git a/core/src/trezor/ui/word_select.py b/core/src/trezor/ui/word_select.py index 38fb3ede18..b9b1c33ccd 100644 --- a/core/src/trezor/ui/word_select.py +++ b/core/src/trezor/ui/word_select.py @@ -5,7 +5,7 @@ from trezor.ui.button import Button class WordSelector(ui.Layout): - def __init__(self, content: ui.Control) -> None: + def __init__(self, content: ui.Component) -> None: self.content = content self.w12 = Button(ui.grid(6, n_y=4), "12") self.w12.on_click = self.on_w12 # type: ignore diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 74aeb19128..99b442dc23 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -1,18 +1,57 @@ +""" +# Wire + +Handles on-the-wire communication with a host computer. The communication is: + +- Request / response. +- Protobuf-encoded, see `protobuf.py`. +- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`. +- Transferred over USB interface, or UDP in case of Unix emulation. + +This module: + +1. Provides API for registering messages. In other words binds what functions are invoked + when some particular message is received. See the `add` function. +2. Runs workflows, also called `handlers`, to process the message. +3. Creates and passes the `Context` object to the handlers. This provides an interface to + wait, read, write etc. on the wire. + +## `add` function + +The `add` function registers what function is invoked when some particular `message_type` +is received. The following example binds the `apps.wallet.get_address` function with +the `GetAddress` message: + +```python +wire.add(MessageType.GetAddress, "apps.wallet", "get_address") +``` + +## Session handler + +When the `wire.setup` is called the `handle_session` coroutine is scheduled. The +`handle_session` waits for some messages to be received on some particular interface and +reads the message's header. When the message type is known the first handler is called. This way the +`handle_session` goes through all the workflows. + +""" + import protobuf from trezor import log, loop, messages, utils, workflow from trezor.messages import FailureType +from trezor.messages.Failure import Failure from trezor.wire import codec_v1 from trezor.wire.errors import Error -# import all errors into namespace, so that `wire.Error` is available elsewhere +# Import all errors into namespace, so that `wire.Error` is available from +# other packages. from trezor.wire.errors import * # isort:skip # noqa: F401,F403 if False: from typing import ( Any, Awaitable, - Dict, Callable, + Dict, Iterable, List, Optional, @@ -20,61 +59,58 @@ if False: Type, ) from trezorio import WireInterface - from protobuf import LoadedMessageType, MessageType Handler = Callable[..., loop.Task] -workflow_handlers = {} # type: Dict[int, Tuple[Handler, Iterable]] +# Maps a wire type directly to a handler. +workflow_handlers = {} # type: Dict[int, Handler] + +# Maps a wire type to a tuple of package and module. This allows handlers +# to be dynamically imported when such message arrives. +workflow_packages = {} # type: Dict[int, Tuple[str, str]] + +# Maps a wire type to a "keychain namespace". Such workflows are created +# with an instance of `seed.Keychain` with correctly derived keys. +workflow_namespaces = {} # type: Dict[int, List] -def add(mtype: int, pkgname: str, modname: str, namespace: List = None) -> None: +def add(wire_type: int, pkgname: str, modname: str, namespace: List = None) -> None: """Shortcut for registering a dynamically-imported Protobuf workflow.""" if namespace is not None: - register( - mtype, - protobuf_workflow, - keychain_workflow, - namespace, - import_workflow, - pkgname, - modname, - ) - else: - register(mtype, protobuf_workflow, import_workflow, pkgname, modname) + workflow_namespaces[wire_type] = namespace + workflow_packages[wire_type] = (pkgname, modname) -def register(mtype: int, handler: Handler, *args: Any) -> None: - """Register `handler` to get scheduled after `mtype` message is received.""" - if isinstance(mtype, type) and issubclass(mtype, protobuf.MessageType): - mtype = mtype.MESSAGE_WIRE_TYPE - if mtype in workflow_handlers: - raise KeyError - workflow_handlers[mtype] = (handler, args) +def register(wire_type: int, handler: Handler) -> None: + """Register `handler` to get scheduled after `wire_type` message is received.""" + workflow_handlers[wire_type] = handler def setup(iface: WireInterface) -> None: """Initialize the wire stack on passed USB interface.""" - loop.schedule(session_handler(iface, codec_v1.SESSION_ID)) + loop.schedule(handle_session(iface, codec_v1.SESSION_ID)) def clear() -> None: """Remove all registered handlers.""" workflow_handlers.clear() + workflow_packages.clear() + workflow_namespaces.clear() class DummyContext: - async def call(*argv): + async def call(*argv: Any) -> None: pass - async def read(*argv): + async def read(*argv: Any) -> None: pass - async def write(*argv): + async def write(*argv: Any) -> None: pass async def wait(self, *tasks: Awaitable) -> Any: - return await loop.spawn(*tasks) + return await loop.race(*tasks) class Context: @@ -83,43 +119,22 @@ class Context: self.sid = sid async def call( - self, msg: MessageType, exptype: Type[LoadedMessageType] - ) -> LoadedMessageType: + self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType] + ) -> protobuf.LoadedMessageType: await self.write(msg) del msg - return await self.read(exptype) + return await self.read(expected_type) - async def call_any(self, msg: MessageType, *allowed_types: int) -> MessageType: + async def call_any( + self, msg: protobuf.MessageType, *expected_wire_types: int + ) -> protobuf.MessageType: await self.write(msg) del msg - return await self.read_any(allowed_types) + return await self.read_any(expected_wire_types) async def read( - self, exptype: Optional[Type[LoadedMessageType]] - ) -> LoadedMessageType: - reader = self.make_reader() - - if __debug__: - log.debug( - __name__, "%s:%x expect: %s", self.iface.iface_num(), self.sid, exptype - ) - - await reader.aopen() # wait for the message header - - # if we got a message with unexpected type, raise the reader via - # `UnexpectedMessageError` and let the session handler deal with it - if exptype is None or reader.type != exptype.MESSAGE_WIRE_TYPE: - raise UnexpectedMessageError(reader) - - if __debug__: - log.debug( - __name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, exptype - ) - - # parse the message and return it - return await protobuf.load_message(reader, exptype) - - async def read_any(self, allowed_types: Iterable[int]) -> MessageType: + self, expected_type: Type[protobuf.LoadedMessageType] + ) -> protobuf.LoadedMessageType: reader = self.make_reader() if __debug__: @@ -128,14 +143,51 @@ class Context: "%s:%x expect: %s", self.iface.iface_num(), self.sid, - allowed_types, + expected_type, ) - await reader.aopen() # wait for the message header + # Wait for the message header, contained in the first report. After + # we receive it, we have a message type to match on. + await reader.aopen() - # if we got a message with unexpected type, raise the reader via - # `UnexpectedMessageError` and let the session handler deal with it - if reader.type not in allowed_types: + # If we got a message with unexpected type, raise the reader via + # `UnexpectedMessageError` and let the session handler deal with it. + if reader.type != expected_type.MESSAGE_WIRE_TYPE: + raise UnexpectedMessageError(reader) + + if __debug__: + log.debug( + __name__, + "%s:%x read: %s", + self.iface.iface_num(), + self.sid, + expected_type, + ) + + # parse the message and return it + return await protobuf.load_message(reader, expected_type) + + async def read_any( + self, expected_wire_types: Iterable[int] + ) -> protobuf.MessageType: + reader = self.make_reader() + + if __debug__: + log.debug( + __name__, + "%s:%x expect: %s", + self.iface.iface_num(), + self.sid, + expected_wire_types, + ) + + # Wait for the message header, contained in the first report. After + # we receive it, we have a message type to match on. + await reader.aopen() + + # If we got a message with unexpected type, raise the reader via + # `UnexpectedMessageError` and let the session handler deal with it. + if reader.type not in expected_wire_types: raise UnexpectedMessageError(reader) # find the protobuf type @@ -172,7 +224,7 @@ class Context: while servicing the wire context. If a message comes until one of the tasks ends, `UnexpectedMessageError` is raised. """ - return loop.spawn(self.read(None), *tasks) + return loop.race(self.read_any(()), *tasks) def make_reader(self) -> codec_v1.Reader: return codec_v1.Reader(self.iface) @@ -183,120 +235,198 @@ class Context: class UnexpectedMessageError(Exception): def __init__(self, reader: codec_v1.Reader) -> None: - super().__init__() self.reader = reader -async def session_handler(iface: WireInterface, sid: int) -> None: - reader = None - ctx = Context(iface, sid) +async def handle_session(iface: WireInterface, session_id: int) -> None: + ctx = Context(iface, session_id) + next_reader = None # type: Optional[codec_v1.Reader] while True: try: - # wait for new message, if needed, and find handler - if not reader: - reader = ctx.make_reader() - await reader.aopen() - try: - handler, args = workflow_handlers[reader.type] - except KeyError: - handler, args = unexpected_msg, () + if next_reader is None: + # We are not currently reading a message, so let's wait for one. + # If the decoding fails, exception is raised and we try again + # (with the same `Reader` instance, it's OK). Even in case of + # de-synchronized wire communication, report with a message + # header is eventually received, after a couple of tries. + req_reader = ctx.make_reader() + await req_reader.aopen() + else: + # We have a reader left over from earlier. We should process + # this message instead of waiting for new one. + req_reader = next_reader + next_reader = None - m = utils.unimport_begin() - w = handler(ctx, reader, *args) - try: - workflow.onstart(w) - await w - finally: - workflow.onclose(w) - utils.unimport_end(m) + # Now we are in a middle of reading a message and we need to decide + # what to do with it, based on its type from the message header. + # From this point on, we should take care to read it in full and + # send a response. - except UnexpectedMessageError as exc: - # retry with opened reader from the exception - reader = exc.reader - continue - except Error as exc: - # we log wire.Error as warning, not as exception - if __debug__: - log.warning(__name__, "failure: %s", exc.message) - except Exception as exc: - # sessions are never closed by raised exceptions + # Take a mark of modules that are imported at this point, so we can + # roll back and un-import any others. Should not raise. + modules = utils.unimport_begin() + + # We need to find a handler for this message type. Should not + # raise. + handler = get_workflow_handler(req_reader) + + if handler is None: + # If no handler is found, we can skip decoding and directly + # respond with failure, but first, we should read the rest of + # the message reports. Should not raise. + await read_and_throw_away(req_reader) + res_msg = unexpected_message() + + else: + # We found a valid handler for this message type. + + # Workflow task, declared for the `workflow.on_close` call later. + wf_task = None # type: Optional[loop.Task] + + # Here we make sure we always respond with a Failure response + # in case of any errors. + try: + # Find a protobuf.MessageType subclass that describes this + # message. Raises if the type is not found. + req_type = messages.get_type(req_reader.type) + + # Try to decode the message according to schema from + # `req_type`. Raises if the message is malformed. + req_msg = await protobuf.load_message(req_reader, req_type) + + # At this point, message reports are all processed and + # correctly parsed into `req_msg`. + + # Create the workflow task. + wf_task = handler(ctx, req_msg) + + # Register the task into the workflow management system. + workflow.on_start(wf_task) + + # Run the workflow task. Workflow can do more on-the-wire + # communication inside, but it should eventually return a + # response message, or raise an exception (a rather common + # thing to do). Exceptions are handled in the code below. + res_msg = await wf_task + + except UnexpectedMessageError as exc: + # Workflow was trying to read a message from the wire, and + # something unexpected came in. See Context.read() for + # example, which expects some particular message and raises + # UnexpectedMessageError if another one comes in. + # In order not to lose the message, we pass on the reader + # to get picked up by the workflow logic in the beginning of + # the cycle, which processes it in the usual manner. + # TODO: + # We might handle only the few common cases here, like + # Initialize and Cancel. + next_reader = exc.reader + res_msg = None + + except Exception as exc: + # Either: + # - the first workflow message had a type that has a + # registered handler, but does not have a protobuf class + # - the first workflow message was not a valid protobuf + # - workflow raised some kind of an exception while running + if __debug__: + log.exception(__name__, exc) + res_msg = failure(exc) + + finally: + # De-register the task from the workflow system, if we + # registered it before. + if wf_task is not None: + workflow.on_close(wf_task) + + if res_msg is not None: + # Either the workflow returned a response, or we created one. + # Write it on the wire. Possibly, the incoming message haven't + # been read in full. We ignore this case here and let the rest + # of the reports get processed while waiting for the message + # header. + # TODO: if the write fails, we do not unimport the loaded modules + await ctx.write(res_msg) + + # Cleanup, so garbage collection triggered after un-importing can + # pick up the trash. + req_reader = None + req_type = None + req_msg = None + res_msg = None + handler = None + wf_task = None + + # Unload modules imported by the workflow. Should not raise. + utils.unimport_end(modules) + + except BaseException as exc: + # The session handling should never exit, just log and continue. if __debug__: log.exception(__name__, exc) - # read new message in next iteration - reader = None + +def get_workflow_handler(reader: codec_v1.Reader) -> Optional[Handler]: + msg_type = reader.type + + if msg_type in workflow_handlers: + # Message has a handler available, return it directly. + handler = workflow_handlers[msg_type] + + elif msg_type in workflow_packages: + # Message needs a dynamically imported handler, import it. + pkgname, modname = workflow_packages[msg_type] + handler = import_workflow(pkgname, modname) + + else: + # Message does not have any registered handler. + return None + + if msg_type in workflow_namespaces: + # Workflow needs a keychain, wrap it with a keychain provider. + namespace = workflow_namespaces[msg_type] + handler = wrap_keychain_workflow(handler, namespace) + + return handler -async def protobuf_workflow( - ctx: Context, reader: codec_v1.Reader, handler: Handler, *args: Any -) -> None: - from trezor.messages.Failure import Failure - - req = await protobuf.load_message(reader, messages.get_type(reader.type)) - - if __debug__: - log.debug(__name__, "%s:%x request: %s", ctx.iface.iface_num(), ctx.sid, req) - - try: - res = await handler(ctx, req, *args) - except UnexpectedMessageError: - # session handler takes care of this one - raise - except Error as exc: - # respond with specific code and message - await ctx.write(Failure(code=exc.code, message=exc.message)) - raise - except Exception as e: - # respond with a generic code and message - message = "Firmware error" - if __debug__: - message = "{}: {}".format(type(e), e) - await ctx.write(Failure(code=FailureType.FirmwareError, message=message)) - raise - if res: - # respond with a specific response - await ctx.write(res) - - -async def keychain_workflow( - ctx: Context, - req: protobuf.MessageType, - namespace: List, - handler: Handler, - *args: Any -) -> Any: - from apps.common import seed - - keychain = await seed.get_keychain(ctx, namespace) - args += (keychain,) - try: - return await handler(ctx, req, *args) - finally: - keychain.__del__() - - -def import_workflow( - ctx: Context, req: protobuf.MessageType, pkgname: str, modname: str, *args: Any -) -> Any: +def import_workflow(pkgname: str, modname: str) -> Handler: modpath = "%s.%s" % (pkgname, modname) module = __import__(modpath, None, None, (modname,), 0) # type: ignore handler = getattr(module, modname) - return handler(ctx, req, *args) + return handler -async def unexpected_msg(ctx: Context, reader: codec_v1.Reader) -> None: - from trezor.messages.Failure import Failure +def wrap_keychain_workflow(handler: Handler, namespace: List) -> Handler: + async def keychain_workflow(ctx: Context, req: protobuf.MessageType) -> Any: + from apps.common import seed - # receive the message and throw it away - await read_full_msg(reader) + # Workflow that is hiding behind `handler` expects a keychain + # instance, in addition to the request message. Acquire it from + # the seed module. More on-the-wire communication, and also UI + # interaction, might happen here. + keychain = await seed.get_keychain(ctx, namespace) + try: + return await handler(ctx, req, keychain) + finally: + # Be hygienic and wipe the keys from memory. + keychain.__del__() - # respond with an unknown message error - await ctx.write( - Failure(code=FailureType.UnexpectedMessage, message="Unexpected message") - ) + return keychain_workflow -async def read_full_msg(reader: codec_v1.Reader) -> None: +def failure(exc: BaseException) -> Failure: + if isinstance(exc, Error): + return Failure(code=exc.code, message=exc.message) + else: + return Failure(code=FailureType.FirmwareError, message="Firmware error") + + +def unexpected_message() -> Failure: + return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message") + + +async def read_and_throw_away(reader: codec_v1.Reader) -> None: while reader.size > 0: buf = bytearray(reader.size) await reader.areadinto(buf) diff --git a/core/src/trezor/wire/codec_v1.py b/core/src/trezor/wire/codec_v1.py index 4cda632306..bcc8891426 100644 --- a/core/src/trezor/wire/codec_v1.py +++ b/core/src/trezor/wire/codec_v1.py @@ -20,7 +20,7 @@ INVALID_TYPE = const(-1) class Reader: """ - Decoder for legacy codec over the HID layer. Provides readable + Decoder for a wire codec over the HID (or UDP) layer. Provides readable async-file-like interface. """ @@ -33,9 +33,9 @@ class Reader: async def aopen(self) -> None: """ - Begin the message transmission by waiting for initial V2 message report - on this session. `self.type` and `self.size` are initialized and - available after `aopen()` returns. + Start reading a message by waiting for initial message report. Because + the first report contains the message header, `self.type` and + `self.size` are initialized and available after `aopen()` returns. """ read = loop.wait(self.iface.iface_num() | io.POLL_READ) while True: @@ -88,7 +88,7 @@ class Reader: class Writer: """ - Encoder for legacy codec over the HID layer. Provides writable + Encoder for a wire codec over the HID (or UDP) layer. Provides writable async-file-like interface. """ diff --git a/core/src/trezor/workflow.py b/core/src/trezor/workflow.py index b5ab85719a..7d2f6aa5a8 100644 --- a/core/src/trezor/workflow.py +++ b/core/src/trezor/workflow.py @@ -1,64 +1,66 @@ from trezor import loop if False: - from trezor import ui - from typing import List, Callable, Optional + from typing import Callable, Optional, Set -workflows = [] # type: List[loop.Task] -layouts = [] # type: List[ui.Layout] -layout_signal = loop.chan() -default = None # type: Optional[loop.Task] -default_layout = None # type: Optional[Callable[[], loop.Task]] +if __debug__: + # Used in `on_close` bellow for memory statistics. + + import micropython + + from trezor import utils -def onstart(w: loop.Task) -> None: - workflows.append(w) +# Set of workflow tasks. Multiple workflows can be running at the same time. +tasks = set() # type: Set[loop.Task] + +# Default workflow task, if a default workflow is running. Default workflow +# is not contained in the `tasks` set above. +default_task = None # type: Optional[loop.Task] + +# Constructor for the default workflow. Returns a workflow task. +default_constructor = None # type: Optional[Callable[[], loop.Task]] -def onclose(w: loop.Task) -> None: - workflows.remove(w) - if not layouts and default_layout: - startdefault(default_layout) +def on_start(workflow: loop.Task) -> None: + """ + Call after creating a workflow task, but before running it. You should + make sure to always call `on_close` when the task is finished. + """ + # Take note that this workflow task is running. + tasks.add(workflow) + +def on_close(workflow: loop.Task) -> None: + """Call when a workflow task has finished running.""" + # Remove task from the running set. + tasks.remove(workflow) + if not tasks and default_constructor: + # If no workflows are running, we should create a new default workflow + # and run it. + start_default(default_constructor) if __debug__: - import micropython - from trezor import utils - + # In debug builds, we dump a memory info right after a workflow is + # finished. if utils.LOG_MEMORY: micropython.mem_info() -def closedefault() -> None: - global default +def start_default(constructor: Callable[[], loop.Task]) -> None: + """Start a default workflow, created from `constructor`.""" + global default_task + global default_constructor - if default: - loop.close(default) - default = None + if not default_task: + default_constructor = constructor + default_task = constructor() + loop.schedule(default_task) -def startdefault(layout: Callable[[], loop.Task]) -> None: - global default - global default_layout +def close_default() -> None: + """Explicitly close the default workflow task.""" + global default_task - if not default: - default_layout = layout - default = layout() - loop.schedule(default) - - -def restartdefault() -> None: - global default_layout - - closedefault() - if default_layout: - startdefault(default_layout) - - -def onlayoutstart(l: ui.Layout) -> None: - closedefault() - layouts.append(l) - - -def onlayoutclose(l: ui.Layout) -> None: - if l in layouts: - layouts.remove(l) + if default_task: + loop.close(default_task) + default_task = None