1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-06-26 09:52:34 +00:00

core: improve code documentation, simplify wire package

* docs: improve loop.py, ui.init

* docs: improve trezor.loop, rename spawn to race

* docs: wire

* core/wire: simplify and document the session handler

* core/wire: improve documentation

* core/wire: improve docs

* core/docs: document ui.grid function

* core: decouple ui and workflow, document both

* core: improve docs


Co-authored-by: Tomas Susanka <tsusanka@gmail.com>
This commit is contained in:
Jan Pochyla 2019-08-20 16:20:02 +02:00 committed by GitHub
parent a25c9fd307
commit bb2556a22c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 589 additions and 340 deletions

View File

@ -16,7 +16,7 @@ if False:
async def confirm( async def confirm(
ctx: wire.Context, ctx: wire.Context,
content: ui.Control, content: ui.Component,
code: int = ButtonRequestType.Other, code: int = ButtonRequestType.Other,
confirm: ButtonContent = Confirm.DEFAULT_CONFIRM, confirm: ButtonContent = Confirm.DEFAULT_CONFIRM,
confirm_style: ButtonStyleType = Confirm.DEFAULT_CONFIRM_STYLE, confirm_style: ButtonStyleType = Confirm.DEFAULT_CONFIRM_STYLE,
@ -49,7 +49,7 @@ async def confirm(
async def hold_to_confirm( async def hold_to_confirm(
ctx: wire.Context, ctx: wire.Context,
content: ui.Control, content: ui.Component,
code: int = ButtonRequestType.Other, code: int = ButtonRequestType.Other,
confirm: ButtonContent = HoldToConfirm.DEFAULT_CONFIRM, confirm: ButtonContent = HoldToConfirm.DEFAULT_CONFIRM,
confirm_style: ButtonStyleType = HoldToConfirm.DEFAULT_CONFIRM_STYLE, confirm_style: ButtonStyleType = HoldToConfirm.DEFAULT_CONFIRM_STYLE,

View File

@ -70,7 +70,10 @@ def type_from_word_count(count: int) -> int:
def _start_progress() -> None: def _start_progress() -> None:
workflow.closedefault() # Because we are drawing to the screen manually, without a layout, we
# should make sure that no other layout is running. At this point, only
# the homescreen should be on, so shut it down.
workflow.close_default()
ui.backlight_fade(ui.BACKLIGHT_DIM) ui.backlight_fade(ui.BACKLIGHT_DIM)
ui.display.clear() ui.display.clear()
ui.header("Please wait") ui.header("Please wait")

View File

@ -25,7 +25,7 @@ async def request_pin(
while True: while True:
if __debug__: if __debug__:
result = await loop.spawn(dialog, input_signal) result = await loop.race(dialog, input_signal)
else: else:
result = await dialog result = await dialog
if result is CANCELLED: if result is CANCELLED:

View File

@ -6,7 +6,7 @@ if not __debug__:
if __debug__: if __debug__:
from trezor import config, log, loop, utils from trezor import config, log, loop, utils
from trezor.messages import MessageType from trezor.messages import MessageType
from trezor.wire import register, protobuf_workflow from trezor.wire import register
if False: if False:
from typing import Optional from typing import Optional
@ -79,9 +79,5 @@ if __debug__:
if not utils.EMULATOR: if not utils.EMULATOR:
config.wipe() config.wipe()
register( register(MessageType.DebugLinkDecision, dispatch_DebugLinkDecision)
MessageType.DebugLinkDecision, protobuf_workflow, dispatch_DebugLinkDecision register(MessageType.DebugLinkGetState, dispatch_DebugLinkGetState)
)
register(
MessageType.DebugLinkGetState, protobuf_workflow, dispatch_DebugLinkGetState
)

View File

@ -2,7 +2,7 @@ from trezor import config, utils, wire
from trezor.messages import MessageType from trezor.messages import MessageType
from trezor.messages.Features import Features from trezor.messages.Features import Features
from trezor.messages.Success import Success from trezor.messages.Success import Success
from trezor.wire import protobuf_workflow, register from trezor.wire import register
from apps.common import cache, storage from apps.common import cache, storage
@ -75,9 +75,9 @@ async def handle_Ping(ctx: wire.Context, msg: Ping) -> Success:
def boot(features_only: bool = False) -> None: def boot(features_only: bool = False) -> None:
register(MessageType.Initialize, protobuf_workflow, handle_Initialize) register(MessageType.Initialize, handle_Initialize)
register(MessageType.GetFeatures, protobuf_workflow, handle_GetFeatures) register(MessageType.GetFeatures, handle_GetFeatures)
if not features_only: if not features_only:
register(MessageType.Cancel, protobuf_workflow, handle_Cancel) register(MessageType.Cancel, handle_Cancel)
register(MessageType.ClearSession, protobuf_workflow, handle_ClearSession) register(MessageType.ClearSession, handle_ClearSession)
register(MessageType.Ping, protobuf_workflow, handle_Ping) register(MessageType.Ping, handle_Ping)

View File

@ -424,7 +424,7 @@ def _slip39_split_share_into_pages(share_words):
return first, list(chunks), last return first, list(chunks), last
class ShamirNumInput(ui.Control): class ShamirNumInput(ui.Component):
SET_SHARES = object() SET_SHARES = object()
SET_THRESHOLD = object() SET_THRESHOLD = object()

View File

@ -85,7 +85,7 @@ class InputButton(Button):
display.icon(ix, iy, self.icon, fg_color, bg_color) display.icon(ix, iy, self.icon, fg_color, bg_color)
class Prompt(ui.Control): class Prompt(ui.Component):
def __init__(self, prompt: str) -> None: def __init__(self, prompt: str) -> None:
self.prompt = prompt self.prompt = prompt
self.repaint = True self.repaint = True
@ -192,17 +192,17 @@ class Bip39Keyboard(ui.Layout):
async def handle_input(self) -> None: async def handle_input(self) -> None:
touch = loop.wait(io.TOUCH) touch = loop.wait(io.TOUCH)
timeout = loop.sleep(1000 * 1000 * 1) timeout = loop.sleep(1000 * 1000 * 1)
spawn_touch = loop.spawn(touch) race_touch = loop.race(touch)
spawn_timeout = loop.spawn(touch, timeout) race_timeout = loop.race(touch, timeout)
while True: while True:
if self.pending_button is not None: if self.pending_button is not None:
spawn = spawn_timeout race = race_timeout
else: else:
spawn = spawn_touch race = race_touch
result = await spawn result = await race
if touch in spawn.finished: if touch in race.finished:
event, x, y = result event, x, y = result
self.dispatch(event, x, y) self.dispatch(event, x, y)
else: else:

View File

@ -169,7 +169,7 @@ async def show_identifier_mismatch(ctx: wire.Context) -> None:
) )
class RecoveryHomescreen(ui.Control): class RecoveryHomescreen(ui.Component):
def __init__(self, text: str, subtext: str = None): def __init__(self, text: str, subtext: str = None):
self.text = text self.text = text
self.subtext = subtext self.subtext = subtext

View File

@ -88,7 +88,7 @@ class InputButton(Button):
display.icon(ix, iy, self.icon, fg_color, bg_color) display.icon(ix, iy, self.icon, fg_color, bg_color)
class Prompt(ui.Control): class Prompt(ui.Component):
def __init__(self, prompt: str) -> None: def __init__(self, prompt: str) -> None:
self.prompt = prompt self.prompt = prompt
self.repaint = True self.repaint = True
@ -202,17 +202,17 @@ class Slip39Keyboard(ui.Layout):
async def handle_input(self) -> None: async def handle_input(self) -> None:
touch = loop.wait(io.TOUCH) touch = loop.wait(io.TOUCH)
timeout = loop.sleep(1000 * 1000 * 1) timeout = loop.sleep(1000 * 1000 * 1)
spawn_touch = loop.spawn(touch) race_touch = loop.race(touch)
spawn_timeout = loop.spawn(touch, timeout) race_timeout = loop.race(touch, timeout)
while True: while True:
if self.pending_button is not None: if self.pending_button is not None:
spawn = spawn_timeout race = race_timeout
else: else:
spawn = spawn_touch race = race_touch
result = await spawn result = await race
if touch in spawn.finished: if touch in race.finished:
event, x, y = result event, x, y = result
self.dispatch(event, x, y) self.dispatch(event, x, y)
else: else:

View File

@ -30,7 +30,7 @@ async def naive_pagination(
while True: while True:
await ctx.call(ButtonRequest(code=ButtonRequestType.SignTx), ButtonAck) await ctx.call(ButtonRequest(code=ButtonRequestType.SignTx), ButtonAck)
if __debug__: if __debug__:
result = await loop.spawn(paginated, confirm_signal) result = await loop.race(paginated, confirm_signal)
else: else:
result = await paginated result = await paginated
if result is CONFIRMED: if result is CONFIRMED:

View File

@ -117,7 +117,7 @@ async def _require_confirm_fee(ctx, fee):
await require_hold_to_confirm(ctx, content, ButtonRequestType.ConfirmOutput) await require_hold_to_confirm(ctx, content, ButtonRequestType.ConfirmOutput)
class TransactionStep(ui.Control): class TransactionStep(ui.Component):
def __init__(self, state, info): def __init__(self, state, info):
self.state = state self.state = state
self.info = info self.info = info
@ -133,7 +133,7 @@ class TransactionStep(ui.Control):
ui.display.text_center(ui.WIDTH // 2, 235, info[1], ui.NORMAL, ui.FG, ui.BG) ui.display.text_center(ui.WIDTH // 2, 235, info[1], ui.NORMAL, ui.FG, ui.BG)
class KeyImageSyncStep(ui.Control): class KeyImageSyncStep(ui.Component):
def __init__(self, current, total_num): def __init__(self, current, total_num):
self.current = current self.current = current
self.total_num = total_num self.total_num = total_num
@ -146,7 +146,7 @@ class KeyImageSyncStep(ui.Control):
ui.display.loader(p, False, 18, ui.WHITE, ui.BG) ui.display.loader(p, False, 18, ui.WHITE, ui.BG)
class LiveRefreshStep(ui.Control): class LiveRefreshStep(ui.Component):
def __init__(self, current): def __init__(self, current):
self.current = current self.current = current

View File

@ -377,10 +377,10 @@ class ConfirmState:
async def confirm_workflow(self) -> None: async def confirm_workflow(self) -> None:
try: try:
workflow.onstart(self.workflow) workflow.on_start(self.workflow)
await self.confirm_layout() await self.confirm_layout()
finally: finally:
workflow.onclose(self.workflow) workflow.on_close(self.workflow)
self.workflow = None self.workflow = None
async def confirm_layout(self) -> None: async def confirm_layout(self) -> None:
@ -402,7 +402,7 @@ class ConfirmState:
self.confirmed = await dialog is CONFIRMED self.confirmed = await dialog is CONFIRMED
class ConfirmContent(ui.Control): class ConfirmContent(ui.Component):
def __init__(self, action: int, app_id: bytes) -> None: def __init__(self, action: int, app_id: bytes) -> None:
self.action = action self.action = action
self.app_id = app_id self.app_id = app_id

View File

@ -70,7 +70,7 @@ def _boot_default() -> None:
# run main event loop and specify which screen is the default # run main event loop and specify which screen is the default
from apps.homescreen.homescreen import homescreen from apps.homescreen.homescreen import homescreen
workflow.startdefault(homescreen) workflow.start_default(homescreen)
from trezor import loop, wire, workflow from trezor import loop, wire, workflow

View File

@ -4,7 +4,7 @@ the form of python coroutines (either plain generators or `async` functions) are
stepped through until completion, and can get asynchronously blocked by stepped through until completion, and can get asynchronously blocked by
`yield`ing or `await`ing a syscall. `yield`ing or `await`ing a syscall.
See `schedule`, `run`, and syscalls `sleep`, `wait`, `signal` and `spawn`. See `schedule`, `run`, and syscalls `sleep`, `wait`, `signal` and `race`.
""" """
import utime import utime
@ -57,6 +57,8 @@ def schedule(
""" """
Schedule task to be executed with `value` on given `deadline` (in Schedule task to be executed with `value` on given `deadline` (in
microseconds). Does not start the event loop itself, see `run`. microseconds). Does not start the event loop itself, see `run`.
Usually done in very low-level cases, see `race` for more user-friendly
and correct concept.
""" """
if deadline is None: if deadline is None:
deadline = utime.ticks_us() deadline = utime.ticks_us()
@ -66,6 +68,11 @@ def schedule(
def pause(task: Task, iface: int) -> None: def pause(task: Task, iface: int) -> None:
"""
Block task on given message interface. Task is resumed when the interface
is activated. It is most probably wrong to call `pause` from user code,
see the `wait` syscall for the correct concept.
"""
tasks = _paused.get(iface, None) tasks = _paused.get(iface, None)
if tasks is None: if tasks is None:
tasks = _paused[iface] = set() tasks = _paused[iface] = set()
@ -73,12 +80,17 @@ def pause(task: Task, iface: int) -> None:
def finalize(task: Task, value: Any) -> None: def finalize(task: Task, value: Any) -> None:
"""Call and remove any finalization callbacks registered for given task."""
fn = _finalizers.pop(id(task), None) fn = _finalizers.pop(id(task), None)
if fn is not None: if fn is not None:
fn(task, value) fn(task, value)
def close(task: Task) -> None: def close(task: Task) -> None:
"""
Deschedule and unblock a task, close it so it can release all resources, and
call its finalizer.
"""
for iface in _paused: for iface in _paused:
_paused[iface].discard(task) _paused[iface].discard(task)
_queue.discard(task) _queue.discard(task)
@ -137,6 +149,21 @@ def clear() -> None:
def _step(task: Task, value: Any) -> None: def _step(task: Task, value: Any) -> None:
"""
Step through the task by sending `value` to `Task`. This can result in either:
1. The task raises an exception:
a) StopIteration
- The Task is completed and we call finalize to finish it.
b) Exception
- An error occurred. We still need to call finalize.
2. Task does not raise exception and returns either:
a) Syscall
- Syscall.handle is called.
b) None
- The Task is simply scheduled to continue.
c) Something else
- That should not happen - error.
"""
try: try:
if isinstance(value, BaseException): if isinstance(value, BaseException):
result = task.throw(value) # type: ignore result = task.throw(value) # type: ignore
@ -144,7 +171,7 @@ def _step(task: Task, value: Any) -> None:
# rationale: In micropython, generator.throw() accepts the exception object directly. # rationale: In micropython, generator.throw() accepts the exception object directly.
else: else:
result = task.send(value) result = task.send(value)
except StopIteration as e: # as e: except StopIteration as e:
if __debug__: if __debug__:
log.debug(__name__, "finish: %s", task) log.debug(__name__, "finish: %s", task)
finalize(task, e.value) finalize(task, e.value)
@ -205,7 +232,7 @@ class wait(Syscall):
""" """
Pause current task, and resume only after a message on `msg_iface` is Pause current task, and resume only after a message on `msg_iface` is
received. Messages are received either from an USB interface, or the received. Messages are received either from an USB interface, or the
touch display. Result value a tuple of message values. touch display. Result value is a tuple of message values.
Example: Example:
@ -223,29 +250,33 @@ class wait(Syscall):
_type_gen = type((lambda: (yield))()) _type_gen = type((lambda: (yield))())
class spawn(Syscall): class race(Syscall):
""" """
Execute one or more children tasks and wait until one of them exits. Given a list of either children tasks or syscalls, `race` waits until one of
Return value of `spawn` is the return value of task that triggered the them completes (tasks are executed in parallel, syscalls are waited upon,
completion. By default, `spawn` returns after the first child completes, and directly). Return value of `race` is the return value of the child that
other running children are killed (by cancelling any pending schedules and triggered the completion. Other running children are killed (by cancelling
calling `close()`). any pending schedules and raising a `GeneratorExit` by calling `close()`).
Child that caused the completion is present in `self.finished`.
Example: Example:
>>> # async def wait_for_touch(): ... >>> # async def wait_for_touch(): ...
>>> # async def animate_logo(): ... >>> # async def animate_logo(): ...
>>> some_signal = loop.signal()
>>> touch_task = wait_for_touch() >>> touch_task = wait_for_touch()
>>> animation_task = animate_logo() >>> animation_task = animate_logo()
>>> waiter = loop.spawn(touch_task, animation_task) >>> racer = loop.race(some_signal, touch_task, animation_task)
>>> result = await waiter >>> result = await racer
>>> if animation_task in waiter.finished: >>> if animation_task in racer.finished:
>>> print('animation task returned', result) >>> print('animation task returned value:', result)
>>> elif touch_task in racer.finished:
>>> print('touch task returned value:', result)
>>> else: >>> else:
>>> print('touch task returned', result) >>> print('signal was triggered with value:', result)
Note: You should not directly `yield` a `spawn` instance, see logic in Note: You should not directly `yield` a `race` instance, see logic in
`spawn.__iter__` for explanation. Always use `await`. `race.__iter__` for explanation. Always use `await`.
""" """
def __init__(self, *children: Awaitable, exit_others: bool = True) -> None: def __init__(self, *children: Awaitable, exit_others: bool = True) -> None:
@ -255,6 +286,9 @@ class spawn(Syscall):
self.scheduled = [] # type: List[Task] # scheduled wrapper tasks self.scheduled = [] # type: List[Task] # scheduled wrapper tasks
def handle(self, task: Task) -> None: def handle(self, task: Task) -> None:
"""
Schedule all children Tasks and set `task` as callback.
"""
finalizer = self._finish finalizer = self._finish
scheduled = self.scheduled scheduled = self.scheduled
finished = self.finished finished = self.finished
@ -279,6 +313,8 @@ class spawn(Syscall):
def _finish(self, task: Task, result: Any) -> None: def _finish(self, task: Task, result: Any) -> None:
if not self.finished: if not self.finished:
# because we create tasks for children that are not generators yet,
# we need to find the child value that the caller supplied
for index, child_task in enumerate(self.scheduled): for index, child_task in enumerate(self.scheduled):
if child_task is task: if child_task is task:
child = self.children[index] child = self.children[index]

View File

@ -3,7 +3,7 @@ import utime
from micropython import const from micropython import const
from trezorui import Display from trezorui import Display
from trezor import io, loop, res, utils, workflow from trezor import io, loop, res, utils
if False: if False:
from typing import Any, Generator, Iterable, Tuple, TypeVar from typing import Any, Generator, Iterable, Tuple, TypeVar
@ -12,9 +12,25 @@ if False:
Area = Tuple[int, int, int, int] Area = Tuple[int, int, int, int]
ResultValue = TypeVar("ResultValue") ResultValue = TypeVar("ResultValue")
# all rendering is done through a singleton of `Display`
display = Display() display = Display()
# re-export constants from modtrezorui
NORMAL = Display.FONT_NORMAL
BOLD = Display.FONT_BOLD
MONO = Display.FONT_MONO
MONO_BOLD = Display.FONT_MONO_BOLD
SIZE = Display.FONT_SIZE
WIDTH = Display.WIDTH
HEIGHT = Display.HEIGHT
# viewport margins
VIEWX = const(6)
VIEWY = const(9)
# channel used to cancel layouts, see `Cancelled` exception
layout_chan = loop.chan()
# in debug mode, display an indicator in top right corner # in debug mode, display an indicator in top right corner
if __debug__: if __debug__:
@ -30,19 +46,6 @@ if __debug__:
elif utils.EMULATOR: elif utils.EMULATOR:
loop.after_step_hook = display.refresh loop.after_step_hook = display.refresh
# re-export constants from modtrezorui
NORMAL = Display.FONT_NORMAL
BOLD = Display.FONT_BOLD
MONO = Display.FONT_MONO
MONO_BOLD = Display.FONT_MONO_BOLD
SIZE = Display.FONT_SIZE
WIDTH = Display.WIDTH
HEIGHT = Display.HEIGHT
# viewport margins
VIEWX = const(6)
VIEWY = const(9)
def lerpi(a: int, b: int, t: float) -> int: def lerpi(a: int, b: int, t: float) -> int:
return int(a + t * (b - a)) return int(a + t * (b - a))
@ -67,9 +70,9 @@ from trezor.ui import style # isort:skip
from trezor.ui.style import * # isort:skip # noqa: F401,F403 from trezor.ui.style import * # isort:skip # noqa: F401,F403
def pulse(delay: int) -> float: def pulse(coef: int) -> float:
# normalize sin from interval -1:1 to 0:1 # normalize sin from interval -1:1 to 0:1
return 0.5 + 0.5 * math.sin(utime.ticks_us() / delay) return 0.5 + 0.5 * math.sin(utime.ticks_us() / coef)
async def click() -> Pos: async def click() -> Pos:
@ -111,7 +114,6 @@ def header(
def header_warning(message: str, clear: bool = True) -> None: def header_warning(message: str, clear: bool = True) -> None:
# TODO: review: is the clear=True really needed?
display.bar(0, 0, WIDTH, 30, style.YELLOW) display.bar(0, 0, WIDTH, 30, style.YELLOW)
display.text_center(WIDTH // 2, 22, message, BOLD, style.BLACK, style.YELLOW) display.text_center(WIDTH // 2, 22, message, BOLD, style.BLACK, style.YELLOW)
if clear: if clear:
@ -119,7 +121,6 @@ def header_warning(message: str, clear: bool = True) -> None:
def header_error(message: str, clear: bool = True) -> None: def header_error(message: str, clear: bool = True) -> None:
# TODO: review: as above
display.bar(0, 0, WIDTH, 30, style.RED) display.bar(0, 0, WIDTH, 30, style.RED)
display.text_center(WIDTH // 2, 22, message, BOLD, style.WHITE, style.RED) display.text_center(WIDTH // 2, 22, message, BOLD, style.WHITE, style.RED)
if clear: if clear:
@ -127,17 +128,31 @@ def header_error(message: str, clear: bool = True) -> None:
def grid( def grid(
i: int, i: int, # i-th cell of the table of which we wish to return Area (snake-like starting with 0)
n_x: int = 3, n_x: int = 3, # number of rows in the table
n_y: int = 5, n_y: int = 5, # number of columns in the table
start_x: int = VIEWX, start_x: int = VIEWX, # where the table starts on x-axis
start_y: int = VIEWY, start_y: int = VIEWY, # where the table starts on y-axis
end_x: int = (WIDTH - VIEWX), end_x: int = (WIDTH - VIEWX), # where the table ends on x-axis
end_y: int = (HEIGHT - VIEWY), end_y: int = (HEIGHT - VIEWY), # where the table ends on y-axis
cells_x: int = 1, cells_x: int = 1, # number of cells to be merged into one in the direction of x-axis
cells_y: int = 1, cells_y: int = 1, # number of cells to be merged into one in the direction of y-axis
spacing: int = 0, spacing: int = 0, # spacing size between cells
) -> Area: ) -> Area:
"""
Returns area (tuple of four integers, in pixels) of a cell on i-th possition
in a table you define yourself. Example:
>>> ui.grid(4, n_x=2, n_y=3, start_x=20, start_y=20)
(20, 160, 107, 70)
Returns 5th cell from the following table. It has two columns, three rows
and starts on coordinates 20-20.
|____|____|
|____|____|
|XXXX|____|
"""
w = (end_x - start_x) // n_x w = (end_x - start_x) // n_x
h = (end_y - start_y) // n_y h = (end_y - start_y) // n_y
x = (i % n_x) * w x = (i % n_x) * w
@ -150,12 +165,30 @@ def in_area(area: Area, x: int, y: int) -> bool:
return ax <= x <= ax + aw and ay <= y <= ay + ah return ax <= x <= ax + aw and ay <= y <= ay + ah
# render events # Component events. Should be different from `io.TOUCH_*` events.
# Event dispatched when components should draw to the display, if they are
# marked for re-paint.
RENDER = const(-255) RENDER = const(-255)
# Event dispatched when components should mark themselves for re-painting.
REPAINT = const(-256) REPAINT = const(-256)
# How long, in microseconds, should the layout rendering task sleep betweeen
# the render calls.
_RENDER_DELAY_US = const(10000) # 10 msec
class Component:
"""
Abstract class.
Components are GUI classes that inherit `Component` and form a tree, with a
`Layout` at the root, and other components underneath. Components that
have children, and therefore need to dispatch events to them, usually
override the `dispatch` method. Leaf components usually override the event
methods (`on_*`). Components signal a completion to the layout by raising
an instance of `Result`.
"""
class Control:
def dispatch(self, event: int, x: int, y: int) -> None: def dispatch(self, event: int, x: int, y: int) -> None:
if event is RENDER: if event is RENDER:
self.on_render() self.on_render()
@ -181,58 +214,107 @@ class Control:
pass pass
_RENDER_DELAY_US = const(10000) # 10 msec
class LayoutCancelled(Exception):
pass
class Result(Exception): class Result(Exception):
"""
When components want to trigger layout completion, they do so through
raising an instance of `Result`.
See `Layout.__iter__` for details.
"""
def __init__(self, value: ResultValue) -> None: def __init__(self, value: ResultValue) -> None:
self.value = value self.value = value
class Layout(Control): class Cancelled(Exception):
""" """
Layouts can be explicitly cancelled. This usually happens when another
layout starts, because only one layout can be running at the same time,
and is done by raising `Cancelled` on the cancelled layout. Layouts
should always re-raise such exceptions.
See `Layout.__iter__` for details.
"""
pass
class Layout(Component):
"""
Abstract class.
Layouts are top-level components. Only one layout can be running at the
same time. Layouts provide asynchronous interface, so a running task can
wait for the layout to complete. Layouts complete when a `Result` is
raised, usually from some of the child components.
""" """
async def __iter__(self) -> ResultValue: async def __iter__(self) -> ResultValue:
"""
Run the layout and wait until it completes. Returns the result value.
Usually not overriden.
"""
value = None value = None
try: try:
if workflow.layout_signal.takers: # If any other layout is running (waiting on the layout channel),
await workflow.layout_signal.put(LayoutCancelled()) # we close it with the Cancelled exception, and wait until it is
workflow.onlayoutstart(self) # closed, just to be sure.
if layout_chan.takers:
await layout_chan.put(Cancelled())
# Now, no other layout should be running. In a loop, we create new
# layout tasks and execute them in parallel, while waiting on the
# layout channel. This allows other layouts to cancel us, and the
# layout tasks to trigger restart by exiting (new tasks are created
# and we continue, because we are in a loop).
while True: while True:
layout_tasks = self.create_tasks() await loop.race(layout_chan.take, *self.create_tasks())
await loop.spawn(workflow.layout_signal.take, *layout_tasks)
except Result as result: except Result as result:
# Result exception was raised, this means this layout is complete.
value = result.value value = result.value
finally:
workflow.onlayoutclose(self)
return value return value
def __await__(self) -> Generator[Any, Any, ResultValue]: def __await__(self) -> Generator[Any, Any, ResultValue]:
return self.__iter__() # type: ignore return self.__iter__() # type: ignore
def create_tasks(self) -> Iterable[loop.Task]: def create_tasks(self) -> Iterable[loop.Task]:
"""
Called from `__iter__`. Creates and returns a sequence of tasks that
run this layout. Tasks are executed in parallel. When one of them
returns, the others are closed and `create_tasks` is called again.
Usually overriden to add another task to the list."""
return self.handle_input(), self.handle_rendering() return self.handle_input(), self.handle_rendering()
def handle_input(self) -> loop.Task: # type: ignore def handle_input(self) -> loop.Task: # type: ignore
"""Task that is waiting for the user input."""
touch = loop.wait(io.TOUCH) touch = loop.wait(io.TOUCH)
while True: while True:
event, x, y = yield touch event, x, y = yield touch
self.dispatch(event, x, y) self.dispatch(event, x, y)
# We dispatch a render event right after the touch. Quick and dirty
# way to get the lowest input-to-render latency.
self.dispatch(RENDER, 0, 0) self.dispatch(RENDER, 0, 0)
def handle_rendering(self) -> loop.Task: # type: ignore def handle_rendering(self) -> loop.Task: # type: ignore
"""Task that is rendering the layout in a busy loop."""
# Before the first render, we dim the display.
backlight_fade(style.BACKLIGHT_DIM) backlight_fade(style.BACKLIGHT_DIM)
# Clear the screen of any leftovers, make sure everything is marked for
# repaint (we can be running the same layout instance multiple times)
# and paint it.
display.clear() display.clear()
self.dispatch(REPAINT, 0, 0) self.dispatch(REPAINT, 0, 0)
self.dispatch(RENDER, 0, 0) self.dispatch(RENDER, 0, 0)
# Display is usually refreshed after every loop step, but here we are
# rendering everything synchronously, so refresh it manually and turn
# the brightness on again.
display.refresh() display.refresh()
backlight_fade(style.BACKLIGHT_NORMAL) backlight_fade(style.BACKLIGHT_NORMAL)
sleep = loop.sleep(_RENDER_DELAY_US) sleep = loop.sleep(_RENDER_DELAY_US)
while True: while True:
self.dispatch(RENDER, 0, 0) # Wait for a couple of ms and render the layout again. Because
# components use re-paint marking, they do not really draw on the
# display needlessly.
# TODO: remove the busy loop
yield sleep yield sleep
self.dispatch(RENDER, 0, 0)

View File

@ -118,7 +118,7 @@ _ICON = const(16) # icon size in pixels
_BORDER = const(4) # border size in pixels _BORDER = const(4) # border size in pixels
class Button(ui.Control): class Button(ui.Component):
def __init__( def __init__(
self, self,
area: ui.Area, area: ui.Area,

View File

@ -13,7 +13,7 @@ _CHECKLIST_OFFSET_X = const(24)
_CHECKLIST_OFFSET_X_ICON = const(0) _CHECKLIST_OFFSET_X_ICON = const(0)
class Checklist(ui.Control): class Checklist(ui.Component):
def __init__(self, title: str, icon: str) -> None: def __init__(self, title: str, icon: str) -> None:
self.title = title self.title = title
self.icon = icon self.icon = icon

View File

@ -19,7 +19,7 @@ class Confirm(ui.Layout):
def __init__( def __init__(
self, self,
content: ui.Control, content: ui.Component,
confirm: Optional[ButtonContent] = DEFAULT_CONFIRM, confirm: Optional[ButtonContent] = DEFAULT_CONFIRM,
confirm_style: ButtonStyleType = DEFAULT_CONFIRM_STYLE, confirm_style: ButtonStyleType = DEFAULT_CONFIRM_STYLE,
cancel: Optional[ButtonContent] = DEFAULT_CANCEL, cancel: Optional[ButtonContent] = DEFAULT_CANCEL,
@ -75,7 +75,7 @@ class HoldToConfirm(ui.Layout):
def __init__( def __init__(
self, self,
content: ui.Control, content: ui.Component,
confirm: str = DEFAULT_CONFIRM, confirm: str = DEFAULT_CONFIRM,
confirm_style: ButtonStyleType = DEFAULT_CONFIRM_STYLE, confirm_style: ButtonStyleType = DEFAULT_CONFIRM_STYLE,
loader_style: LoaderStyleType = DEFAULT_LOADER_STYLE, loader_style: LoaderStyleType = DEFAULT_LOADER_STYLE,

View File

@ -1,8 +1,8 @@
from trezor import ui from trezor import ui
class Container(ui.Control): class Container(ui.Component):
def __init__(self, *children: ui.Control): def __init__(self, *children: ui.Component):
self.children = children self.children = children
def dispatch(self, event: int, x: int, y: int) -> None: def dispatch(self, event: int, x: int, y: int) -> None:

View File

@ -37,7 +37,7 @@ if False:
_TARGET_MS = const(1000) _TARGET_MS = const(1000)
class Loader(ui.Control): class Loader(ui.Component):
def __init__(self, style: LoaderStyleType = LoaderDefault) -> None: def __init__(self, style: LoaderStyleType = LoaderDefault) -> None:
self.normal_style = style.normal self.normal_style = style.normal
self.active_style = style.active self.active_style = style.active

View File

@ -114,7 +114,7 @@ class Input(Button):
pass pass
class Prompt(ui.Control): class Prompt(ui.Component):
def __init__(self, text: str) -> None: def __init__(self, text: str) -> None:
self.text = text self.text = text
self.repaint = True self.repaint = True
@ -210,17 +210,17 @@ class PassphraseKeyboard(ui.Layout):
async def handle_input(self) -> None: async def handle_input(self) -> None:
touch = loop.wait(io.TOUCH) touch = loop.wait(io.TOUCH)
timeout = loop.sleep(1000 * 1000 * 1) timeout = loop.sleep(1000 * 1000 * 1)
spawn_touch = loop.spawn(touch) race_touch = loop.race(touch)
spawn_timeout = loop.spawn(touch, timeout) race_timeout = loop.race(touch, timeout)
while True: while True:
if self.pending_button is not None: if self.pending_button is not None:
spawn = spawn_timeout race = race_timeout
else: else:
spawn = spawn_touch race = race_touch
result = await spawn result = await race
if touch in spawn.finished: if touch in race.finished:
event, x, y = result event, x, y = result
self.dispatch(event, x, y) self.dispatch(event, x, y)
else: else:
@ -249,7 +249,7 @@ class PassphraseKeyboard(ui.Layout):
class PassphraseSource(ui.Layout): class PassphraseSource(ui.Layout):
def __init__(self, content: ui.Control) -> None: def __init__(self, content: ui.Component) -> None:
self.content = content self.content = content
self.device = Button(ui.grid(8, n_y=4, n_x=4, cells_x=4), "Device") self.device = Button(ui.grid(8, n_y=4, n_x=4, cells_x=4), "Device")

View File

@ -29,7 +29,7 @@ def generate_digits() -> Iterable[int]:
return digits[6:] + digits[3:6] + digits[:3] return digits[6:] + digits[3:6] + digits[:3]
class PinInput(ui.Control): class PinInput(ui.Component):
def __init__(self, prompt: str, subprompt: str, pin: str) -> None: def __init__(self, prompt: str, subprompt: str, pin: str) -> None:
self.prompt = prompt self.prompt = prompt
self.subprompt = subprompt self.subprompt = subprompt

View File

@ -5,7 +5,7 @@ if False:
class Popup(ui.Layout): class Popup(ui.Layout):
def __init__(self, content: ui.Control, time_ms: int = 0) -> None: def __init__(self, content: ui.Component, time_ms: int = 0) -> None:
self.content = content self.content = content
self.time_ms = time_ms self.time_ms = time_ms

View File

@ -1,7 +1,7 @@
from trezor import ui from trezor import ui
class Qr(ui.Control): class Qr(ui.Component):
def __init__(self, data: bytes, x: int, y: int, scale: int): def __init__(self, data: bytes, x: int, y: int, scale: int):
self.data = data self.data = data
self.x = x self.x = x

View File

@ -46,7 +46,7 @@ def render_swipe_text() -> None:
class Paginated(ui.Layout): class Paginated(ui.Layout):
def __init__( def __init__(
self, pages: Sequence[ui.Control], page: int = 0, one_by_one: bool = False self, pages: Sequence[ui.Component], page: int = 0, one_by_one: bool = False
): ):
self.pages = pages self.pages = pages
self.page = page self.page = page
@ -77,7 +77,7 @@ class Paginated(ui.Layout):
directions = SWIPE_VERTICAL directions = SWIPE_VERTICAL
if __debug__: if __debug__:
swipe = await loop.spawn(Swipe(directions), swipe_signal) swipe = await loop.race(Swipe(directions), swipe_signal)
else: else:
swipe = await Swipe(directions) swipe = await Swipe(directions)
@ -99,10 +99,10 @@ class Paginated(ui.Layout):
raise ui.Result(self.page) raise ui.Result(self.page)
class PageWithButtons(ui.Control): class PageWithButtons(ui.Component):
def __init__( def __init__(
self, self,
content: ui.Control, content: ui.Component,
paginated: "PaginatedWithButtons", paginated: "PaginatedWithButtons",
index: int, index: int,
count: int, count: int,
@ -157,7 +157,7 @@ class PageWithButtons(ui.Control):
class PaginatedWithButtons(ui.Layout): class PaginatedWithButtons(ui.Layout):
def __init__( def __init__(
self, pages: Sequence[ui.Control], page: int = 0, one_by_one: bool = False self, pages: Sequence[ui.Component], page: int = 0, one_by_one: bool = False
) -> None: ) -> None:
self.pages = [ self.pages = [
PageWithButtons(p, self, i, len(pages)) for i, p in enumerate(pages) PageWithButtons(p, self, i, len(pages)) for i, p in enumerate(pages)

View File

@ -3,7 +3,7 @@ from trezor.ui.button import Button
from trezor.ui.text import LABEL_CENTER, Label from trezor.ui.text import LABEL_CENTER, Label
class NumInput(ui.Control): class NumInput(ui.Component):
def __init__(self, count: int = 5, max_count: int = 16, min_count: int = 1) -> None: def __init__(self, count: int = 5, max_count: int = 16, min_count: int = 1) -> None:
self.count = count self.count = count
self.max_count = max_count self.max_count = max_count

View File

@ -17,7 +17,7 @@ _SWIPE_DISTANCE = const(120)
_SWIPE_TRESHOLD = const(30) _SWIPE_TRESHOLD = const(30)
class Swipe(ui.Control): class Swipe(ui.Component):
def __init__(self, directions: int = SWIPE_ALL, area: ui.Area = None) -> None: def __init__(self, directions: int = SWIPE_ALL, area: ui.Area = None) -> None:
if area is None: if area is None:
area = (0, 0, ui.WIDTH, ui.HEIGHT) area = (0, 0, ui.WIDTH, ui.HEIGHT)

View File

@ -120,7 +120,7 @@ def render_text(
offset_x += SPACE offset_x += SPACE
class Text(ui.Control): class Text(ui.Component):
def __init__( def __init__(
self, self,
header_text: str, header_text: str,
@ -177,7 +177,7 @@ LABEL_CENTER = const(1)
LABEL_RIGHT = const(2) LABEL_RIGHT = const(2)
class Label(ui.Control): class Label(ui.Component):
def __init__( def __init__(
self, self,
area: ui.Area, area: ui.Area,

View File

@ -5,7 +5,7 @@ from trezor.ui.button import Button
class WordSelector(ui.Layout): class WordSelector(ui.Layout):
def __init__(self, content: ui.Control) -> None: def __init__(self, content: ui.Component) -> None:
self.content = content self.content = content
self.w12 = Button(ui.grid(6, n_y=4), "12") self.w12 = Button(ui.grid(6, n_y=4), "12")
self.w12.on_click = self.on_w12 # type: ignore self.w12.on_click = self.on_w12 # type: ignore

View File

@ -1,18 +1,57 @@
"""
# Wire
Handles on-the-wire communication with a host computer. The communication is:
- Request / response.
- Protobuf-encoded, see `protobuf.py`.
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`.
- Transferred over USB interface, or UDP in case of Unix emulation.
This module:
1. Provides API for registering messages. In other words binds what functions are invoked
when some particular message is received. See the `add` function.
2. Runs workflows, also called `handlers`, to process the message.
3. Creates and passes the `Context` object to the handlers. This provides an interface to
wait, read, write etc. on the wire.
## `add` function
The `add` function registers what function is invoked when some particular `message_type`
is received. The following example binds the `apps.wallet.get_address` function with
the `GetAddress` message:
```python
wire.add(MessageType.GetAddress, "apps.wallet", "get_address")
```
## Session handler
When the `wire.setup` is called the `handle_session` coroutine is scheduled. The
`handle_session` waits for some messages to be received on some particular interface and
reads the message's header. When the message type is known the first handler is called. This way the
`handle_session` goes through all the workflows.
"""
import protobuf import protobuf
from trezor import log, loop, messages, utils, workflow from trezor import log, loop, messages, utils, workflow
from trezor.messages import FailureType from trezor.messages import FailureType
from trezor.messages.Failure import Failure
from trezor.wire import codec_v1 from trezor.wire import codec_v1
from trezor.wire.errors import Error from trezor.wire.errors import Error
# import all errors into namespace, so that `wire.Error` is available elsewhere # Import all errors into namespace, so that `wire.Error` is available from
# other packages.
from trezor.wire.errors import * # isort:skip # noqa: F401,F403 from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if False: if False:
from typing import ( from typing import (
Any, Any,
Awaitable, Awaitable,
Dict,
Callable, Callable,
Dict,
Iterable, Iterable,
List, List,
Optional, Optional,
@ -20,61 +59,58 @@ if False:
Type, Type,
) )
from trezorio import WireInterface from trezorio import WireInterface
from protobuf import LoadedMessageType, MessageType
Handler = Callable[..., loop.Task] Handler = Callable[..., loop.Task]
workflow_handlers = {} # type: Dict[int, Tuple[Handler, Iterable]] # Maps a wire type directly to a handler.
workflow_handlers = {} # type: Dict[int, Handler]
# Maps a wire type to a tuple of package and module. This allows handlers
# to be dynamically imported when such message arrives.
workflow_packages = {} # type: Dict[int, Tuple[str, str]]
# Maps a wire type to a "keychain namespace". Such workflows are created
# with an instance of `seed.Keychain` with correctly derived keys.
workflow_namespaces = {} # type: Dict[int, List]
def add(mtype: int, pkgname: str, modname: str, namespace: List = None) -> None: def add(wire_type: int, pkgname: str, modname: str, namespace: List = None) -> None:
"""Shortcut for registering a dynamically-imported Protobuf workflow.""" """Shortcut for registering a dynamically-imported Protobuf workflow."""
if namespace is not None: if namespace is not None:
register( workflow_namespaces[wire_type] = namespace
mtype, workflow_packages[wire_type] = (pkgname, modname)
protobuf_workflow,
keychain_workflow,
namespace,
import_workflow,
pkgname,
modname,
)
else:
register(mtype, protobuf_workflow, import_workflow, pkgname, modname)
def register(mtype: int, handler: Handler, *args: Any) -> None: def register(wire_type: int, handler: Handler) -> None:
"""Register `handler` to get scheduled after `mtype` message is received.""" """Register `handler` to get scheduled after `wire_type` message is received."""
if isinstance(mtype, type) and issubclass(mtype, protobuf.MessageType): workflow_handlers[wire_type] = handler
mtype = mtype.MESSAGE_WIRE_TYPE
if mtype in workflow_handlers:
raise KeyError
workflow_handlers[mtype] = (handler, args)
def setup(iface: WireInterface) -> None: def setup(iface: WireInterface) -> None:
"""Initialize the wire stack on passed USB interface.""" """Initialize the wire stack on passed USB interface."""
loop.schedule(session_handler(iface, codec_v1.SESSION_ID)) loop.schedule(handle_session(iface, codec_v1.SESSION_ID))
def clear() -> None: def clear() -> None:
"""Remove all registered handlers.""" """Remove all registered handlers."""
workflow_handlers.clear() workflow_handlers.clear()
workflow_packages.clear()
workflow_namespaces.clear()
class DummyContext: class DummyContext:
async def call(*argv): async def call(*argv: Any) -> None:
pass pass
async def read(*argv): async def read(*argv: Any) -> None:
pass pass
async def write(*argv): async def write(*argv: Any) -> None:
pass pass
async def wait(self, *tasks: Awaitable) -> Any: async def wait(self, *tasks: Awaitable) -> Any:
return await loop.spawn(*tasks) return await loop.race(*tasks)
class Context: class Context:
@ -83,43 +119,22 @@ class Context:
self.sid = sid self.sid = sid
async def call( async def call(
self, msg: MessageType, exptype: Type[LoadedMessageType] self, msg: protobuf.MessageType, expected_type: Type[protobuf.LoadedMessageType]
) -> LoadedMessageType: ) -> protobuf.LoadedMessageType:
await self.write(msg) await self.write(msg)
del msg del msg
return await self.read(exptype) return await self.read(expected_type)
async def call_any(self, msg: MessageType, *allowed_types: int) -> MessageType: async def call_any(
self, msg: protobuf.MessageType, *expected_wire_types: int
) -> protobuf.MessageType:
await self.write(msg) await self.write(msg)
del msg del msg
return await self.read_any(allowed_types) return await self.read_any(expected_wire_types)
async def read( async def read(
self, exptype: Optional[Type[LoadedMessageType]] self, expected_type: Type[protobuf.LoadedMessageType]
) -> LoadedMessageType: ) -> protobuf.LoadedMessageType:
reader = self.make_reader()
if __debug__:
log.debug(
__name__, "%s:%x expect: %s", self.iface.iface_num(), self.sid, exptype
)
await reader.aopen() # wait for the message header
# if we got a message with unexpected type, raise the reader via
# `UnexpectedMessageError` and let the session handler deal with it
if exptype is None or reader.type != exptype.MESSAGE_WIRE_TYPE:
raise UnexpectedMessageError(reader)
if __debug__:
log.debug(
__name__, "%s:%x read: %s", self.iface.iface_num(), self.sid, exptype
)
# parse the message and return it
return await protobuf.load_message(reader, exptype)
async def read_any(self, allowed_types: Iterable[int]) -> MessageType:
reader = self.make_reader() reader = self.make_reader()
if __debug__: if __debug__:
@ -128,14 +143,51 @@ class Context:
"%s:%x expect: %s", "%s:%x expect: %s",
self.iface.iface_num(), self.iface.iface_num(),
self.sid, self.sid,
allowed_types, expected_type,
) )
await reader.aopen() # wait for the message header # Wait for the message header, contained in the first report. After
# we receive it, we have a message type to match on.
await reader.aopen()
# if we got a message with unexpected type, raise the reader via # If we got a message with unexpected type, raise the reader via
# `UnexpectedMessageError` and let the session handler deal with it # `UnexpectedMessageError` and let the session handler deal with it.
if reader.type not in allowed_types: if reader.type != expected_type.MESSAGE_WIRE_TYPE:
raise UnexpectedMessageError(reader)
if __debug__:
log.debug(
__name__,
"%s:%x read: %s",
self.iface.iface_num(),
self.sid,
expected_type,
)
# parse the message and return it
return await protobuf.load_message(reader, expected_type)
async def read_any(
self, expected_wire_types: Iterable[int]
) -> protobuf.MessageType:
reader = self.make_reader()
if __debug__:
log.debug(
__name__,
"%s:%x expect: %s",
self.iface.iface_num(),
self.sid,
expected_wire_types,
)
# Wait for the message header, contained in the first report. After
# we receive it, we have a message type to match on.
await reader.aopen()
# If we got a message with unexpected type, raise the reader via
# `UnexpectedMessageError` and let the session handler deal with it.
if reader.type not in expected_wire_types:
raise UnexpectedMessageError(reader) raise UnexpectedMessageError(reader)
# find the protobuf type # find the protobuf type
@ -172,7 +224,7 @@ class Context:
while servicing the wire context. If a message comes until one of the while servicing the wire context. If a message comes until one of the
tasks ends, `UnexpectedMessageError` is raised. tasks ends, `UnexpectedMessageError` is raised.
""" """
return loop.spawn(self.read(None), *tasks) return loop.race(self.read_any(()), *tasks)
def make_reader(self) -> codec_v1.Reader: def make_reader(self) -> codec_v1.Reader:
return codec_v1.Reader(self.iface) return codec_v1.Reader(self.iface)
@ -183,120 +235,198 @@ class Context:
class UnexpectedMessageError(Exception): class UnexpectedMessageError(Exception):
def __init__(self, reader: codec_v1.Reader) -> None: def __init__(self, reader: codec_v1.Reader) -> None:
super().__init__()
self.reader = reader self.reader = reader
async def session_handler(iface: WireInterface, sid: int) -> None: async def handle_session(iface: WireInterface, session_id: int) -> None:
reader = None ctx = Context(iface, session_id)
ctx = Context(iface, sid) next_reader = None # type: Optional[codec_v1.Reader]
while True: while True:
try: try:
# wait for new message, if needed, and find handler if next_reader is None:
if not reader: # We are not currently reading a message, so let's wait for one.
reader = ctx.make_reader() # If the decoding fails, exception is raised and we try again
await reader.aopen() # (with the same `Reader` instance, it's OK). Even in case of
try: # de-synchronized wire communication, report with a message
handler, args = workflow_handlers[reader.type] # header is eventually received, after a couple of tries.
except KeyError: req_reader = ctx.make_reader()
handler, args = unexpected_msg, () await req_reader.aopen()
else:
# We have a reader left over from earlier. We should process
# this message instead of waiting for new one.
req_reader = next_reader
next_reader = None
m = utils.unimport_begin() # Now we are in a middle of reading a message and we need to decide
w = handler(ctx, reader, *args) # what to do with it, based on its type from the message header.
# From this point on, we should take care to read it in full and
# send a response.
# Take a mark of modules that are imported at this point, so we can
# roll back and un-import any others. Should not raise.
modules = utils.unimport_begin()
# We need to find a handler for this message type. Should not
# raise.
handler = get_workflow_handler(req_reader)
if handler is None:
# If no handler is found, we can skip decoding and directly
# respond with failure, but first, we should read the rest of
# the message reports. Should not raise.
await read_and_throw_away(req_reader)
res_msg = unexpected_message()
else:
# We found a valid handler for this message type.
# Workflow task, declared for the `workflow.on_close` call later.
wf_task = None # type: Optional[loop.Task]
# Here we make sure we always respond with a Failure response
# in case of any errors.
try: try:
workflow.onstart(w) # Find a protobuf.MessageType subclass that describes this
await w # message. Raises if the type is not found.
finally: req_type = messages.get_type(req_reader.type)
workflow.onclose(w)
utils.unimport_end(m) # Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = await protobuf.load_message(req_reader, req_type)
# At this point, message reports are all processed and
# correctly parsed into `req_msg`.
# Create the workflow task.
wf_task = handler(ctx, req_msg)
# Register the task into the workflow management system.
workflow.on_start(wf_task)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
# response message, or raise an exception (a rather common
# thing to do). Exceptions are handled in the code below.
res_msg = await wf_task
except UnexpectedMessageError as exc: except UnexpectedMessageError as exc:
# retry with opened reader from the exception # Workflow was trying to read a message from the wire, and
reader = exc.reader # something unexpected came in. See Context.read() for
continue # example, which expects some particular message and raises
except Error as exc: # UnexpectedMessageError if another one comes in.
# we log wire.Error as warning, not as exception # In order not to lose the message, we pass on the reader
if __debug__: # to get picked up by the workflow logic in the beginning of
log.warning(__name__, "failure: %s", exc.message) # the cycle, which processes it in the usual manner.
# TODO:
# We might handle only the few common cases here, like
# Initialize and Cancel.
next_reader = exc.reader
res_msg = None
except Exception as exc: except Exception as exc:
# sessions are never closed by raised exceptions # Either:
# - the first workflow message had a type that has a
# registered handler, but does not have a protobuf class
# - the first workflow message was not a valid protobuf
# - workflow raised some kind of an exception while running
if __debug__:
log.exception(__name__, exc)
res_msg = failure(exc)
finally:
# De-register the task from the workflow system, if we
# registered it before.
if wf_task is not None:
workflow.on_close(wf_task)
if res_msg is not None:
# Either the workflow returned a response, or we created one.
# Write it on the wire. Possibly, the incoming message haven't
# been read in full. We ignore this case here and let the rest
# of the reports get processed while waiting for the message
# header.
# TODO: if the write fails, we do not unimport the loaded modules
await ctx.write(res_msg)
# Cleanup, so garbage collection triggered after un-importing can
# pick up the trash.
req_reader = None
req_type = None
req_msg = None
res_msg = None
handler = None
wf_task = None
# Unload modules imported by the workflow. Should not raise.
utils.unimport_end(modules)
except BaseException as exc:
# The session handling should never exit, just log and continue.
if __debug__: if __debug__:
log.exception(__name__, exc) log.exception(__name__, exc)
# read new message in next iteration
reader = None def get_workflow_handler(reader: codec_v1.Reader) -> Optional[Handler]:
msg_type = reader.type
if msg_type in workflow_handlers:
# Message has a handler available, return it directly.
handler = workflow_handlers[msg_type]
elif msg_type in workflow_packages:
# Message needs a dynamically imported handler, import it.
pkgname, modname = workflow_packages[msg_type]
handler = import_workflow(pkgname, modname)
else:
# Message does not have any registered handler.
return None
if msg_type in workflow_namespaces:
# Workflow needs a keychain, wrap it with a keychain provider.
namespace = workflow_namespaces[msg_type]
handler = wrap_keychain_workflow(handler, namespace)
return handler
async def protobuf_workflow( def import_workflow(pkgname: str, modname: str) -> Handler:
ctx: Context, reader: codec_v1.Reader, handler: Handler, *args: Any
) -> None:
from trezor.messages.Failure import Failure
req = await protobuf.load_message(reader, messages.get_type(reader.type))
if __debug__:
log.debug(__name__, "%s:%x request: %s", ctx.iface.iface_num(), ctx.sid, req)
try:
res = await handler(ctx, req, *args)
except UnexpectedMessageError:
# session handler takes care of this one
raise
except Error as exc:
# respond with specific code and message
await ctx.write(Failure(code=exc.code, message=exc.message))
raise
except Exception as e:
# respond with a generic code and message
message = "Firmware error"
if __debug__:
message = "{}: {}".format(type(e), e)
await ctx.write(Failure(code=FailureType.FirmwareError, message=message))
raise
if res:
# respond with a specific response
await ctx.write(res)
async def keychain_workflow(
ctx: Context,
req: protobuf.MessageType,
namespace: List,
handler: Handler,
*args: Any
) -> Any:
from apps.common import seed
keychain = await seed.get_keychain(ctx, namespace)
args += (keychain,)
try:
return await handler(ctx, req, *args)
finally:
keychain.__del__()
def import_workflow(
ctx: Context, req: protobuf.MessageType, pkgname: str, modname: str, *args: Any
) -> Any:
modpath = "%s.%s" % (pkgname, modname) modpath = "%s.%s" % (pkgname, modname)
module = __import__(modpath, None, None, (modname,), 0) # type: ignore module = __import__(modpath, None, None, (modname,), 0) # type: ignore
handler = getattr(module, modname) handler = getattr(module, modname)
return handler(ctx, req, *args) return handler
async def unexpected_msg(ctx: Context, reader: codec_v1.Reader) -> None: def wrap_keychain_workflow(handler: Handler, namespace: List) -> Handler:
from trezor.messages.Failure import Failure async def keychain_workflow(ctx: Context, req: protobuf.MessageType) -> Any:
from apps.common import seed
# receive the message and throw it away # Workflow that is hiding behind `handler` expects a keychain
await read_full_msg(reader) # instance, in addition to the request message. Acquire it from
# the seed module. More on-the-wire communication, and also UI
# interaction, might happen here.
keychain = await seed.get_keychain(ctx, namespace)
try:
return await handler(ctx, req, keychain)
finally:
# Be hygienic and wipe the keys from memory.
keychain.__del__()
# respond with an unknown message error return keychain_workflow
await ctx.write(
Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
)
async def read_full_msg(reader: codec_v1.Reader) -> None: def failure(exc: BaseException) -> Failure:
if isinstance(exc, Error):
return Failure(code=exc.code, message=exc.message)
else:
return Failure(code=FailureType.FirmwareError, message="Firmware error")
def unexpected_message() -> Failure:
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
async def read_and_throw_away(reader: codec_v1.Reader) -> None:
while reader.size > 0: while reader.size > 0:
buf = bytearray(reader.size) buf = bytearray(reader.size)
await reader.areadinto(buf) await reader.areadinto(buf)

View File

@ -20,7 +20,7 @@ INVALID_TYPE = const(-1)
class Reader: class Reader:
""" """
Decoder for legacy codec over the HID layer. Provides readable Decoder for a wire codec over the HID (or UDP) layer. Provides readable
async-file-like interface. async-file-like interface.
""" """
@ -33,9 +33,9 @@ class Reader:
async def aopen(self) -> None: async def aopen(self) -> None:
""" """
Begin the message transmission by waiting for initial V2 message report Start reading a message by waiting for initial message report. Because
on this session. `self.type` and `self.size` are initialized and the first report contains the message header, `self.type` and
available after `aopen()` returns. `self.size` are initialized and available after `aopen()` returns.
""" """
read = loop.wait(self.iface.iface_num() | io.POLL_READ) read = loop.wait(self.iface.iface_num() | io.POLL_READ)
while True: while True:
@ -88,7 +88,7 @@ class Reader:
class Writer: class Writer:
""" """
Encoder for legacy codec over the HID layer. Provides writable Encoder for a wire codec over the HID (or UDP) layer. Provides writable
async-file-like interface. async-file-like interface.
""" """

View File

@ -1,64 +1,66 @@
from trezor import loop from trezor import loop
if False: if False:
from trezor import ui from typing import Callable, Optional, Set
from typing import List, Callable, Optional
workflows = [] # type: List[loop.Task]
layouts = [] # type: List[ui.Layout]
layout_signal = loop.chan()
default = None # type: Optional[loop.Task]
default_layout = None # type: Optional[Callable[[], loop.Task]]
def onstart(w: loop.Task) -> None:
workflows.append(w)
def onclose(w: loop.Task) -> None:
workflows.remove(w)
if not layouts and default_layout:
startdefault(default_layout)
if __debug__: if __debug__:
# Used in `on_close` bellow for memory statistics.
import micropython import micropython
from trezor import utils from trezor import utils
# Set of workflow tasks. Multiple workflows can be running at the same time.
tasks = set() # type: Set[loop.Task]
# Default workflow task, if a default workflow is running. Default workflow
# is not contained in the `tasks` set above.
default_task = None # type: Optional[loop.Task]
# Constructor for the default workflow. Returns a workflow task.
default_constructor = None # type: Optional[Callable[[], loop.Task]]
def on_start(workflow: loop.Task) -> None:
"""
Call after creating a workflow task, but before running it. You should
make sure to always call `on_close` when the task is finished.
"""
# Take note that this workflow task is running.
tasks.add(workflow)
def on_close(workflow: loop.Task) -> None:
"""Call when a workflow task has finished running."""
# Remove task from the running set.
tasks.remove(workflow)
if not tasks and default_constructor:
# If no workflows are running, we should create a new default workflow
# and run it.
start_default(default_constructor)
if __debug__:
# In debug builds, we dump a memory info right after a workflow is
# finished.
if utils.LOG_MEMORY: if utils.LOG_MEMORY:
micropython.mem_info() micropython.mem_info()
def closedefault() -> None: def start_default(constructor: Callable[[], loop.Task]) -> None:
global default """Start a default workflow, created from `constructor`."""
global default_task
global default_constructor
if default: if not default_task:
loop.close(default) default_constructor = constructor
default = None default_task = constructor()
loop.schedule(default_task)
def startdefault(layout: Callable[[], loop.Task]) -> None: def close_default() -> None:
global default """Explicitly close the default workflow task."""
global default_layout global default_task
if not default: if default_task:
default_layout = layout loop.close(default_task)
default = layout() default_task = None
loop.schedule(default)
def restartdefault() -> None:
global default_layout
closedefault()
if default_layout:
startdefault(default_layout)
def onlayoutstart(l: ui.Layout) -> None:
closedefault()
layouts.append(l)
def onlayoutclose(l: ui.Layout) -> None:
if l in layouts:
layouts.remove(l)