From fe80793b475efc0ca34698261efa1cd2143e0ff9 Mon Sep 17 00:00:00 2001 From: matejcik Date: Mon, 26 Jun 2023 11:36:41 +0200 Subject: [PATCH] feat(core): keep track of current context for workflow --- core/src/all_modules.py | 2 + core/src/trezor/wire/__init__.py | 204 ++---------------------- core/src/trezor/wire/context.py | 266 +++++++++++++++++++++++++++++++ 3 files changed, 284 insertions(+), 188 deletions(-) create mode 100644 core/src/trezor/wire/context.py diff --git a/core/src/all_modules.py b/core/src/all_modules.py index 8e68d76754..60ac66b011 100644 --- a/core/src/all_modules.py +++ b/core/src/all_modules.py @@ -193,6 +193,8 @@ trezor.wire import trezor.wire trezor.wire.codec_v1 import trezor.wire.codec_v1 +trezor.wire.context +import trezor.wire.context trezor.wire.errors import trezor.wire.errors trezor.workflow diff --git a/core/src/trezor/wire/__init__.py b/core/src/trezor/wire/__init__.py index 1b7047ca25..cabba16de0 100644 --- a/core/src/trezor/wire/__init__.py +++ b/core/src/trezor/wire/__init__.py @@ -42,7 +42,7 @@ from storage.cache import InvalidSessionError from trezor import log, loop, protobuf, utils, workflow from trezor.enums import FailureType from trezor.messages import Failure -from trezor.wire import codec_v1 +from trezor.wire import codec_v1, context from trezor.wire.errors import ActionCancelled, DataError, Error # Import all errors into namespace, so that `wire.Error` is available from @@ -53,43 +53,22 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403 if TYPE_CHECKING: from typing import ( Any, - Awaitable, Callable, Container, Coroutine, - Iterable, - Protocol, TypeVar, ) from trezorio import WireInterface Msg = TypeVar("Msg", bound=protobuf.MessageType) HandlerTask = Coroutine[Any, Any, protobuf.MessageType] - Handler = Callable[["Context", Msg], HandlerTask] + Handler = Callable[[Msg], HandlerTask] LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) - class GenericContext(Protocol): - async def call( - self, - msg: protobuf.MessageType, - expected_type: type[protobuf.MessageType], - ) -> Any: - ... - - async def read(self, expected_type: type[protobuf.MessageType]) -> Any: - ... - - async def write(self, msg: protobuf.MessageType) -> None: - ... - - # XXX modify type signature so that the return value must be of the same type? - async def wait(self, *tasks: Awaitable) -> Any: - ... - # If set to False protobuf messages marked with "experimental_message" option are rejected. -experimental_enabled = False +EXPERIMENTAL_ENABLED = False def setup(iface: WireInterface, is_debug_session: bool = False) -> None: @@ -97,12 +76,12 @@ def setup(iface: WireInterface, is_debug_session: bool = False) -> None: loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session)) -def _wrap_protobuf_load( +def wrap_protobuf_load( buffer: bytes, expected_type: type[LoadedMessageType], ) -> LoadedMessageType: try: - msg = protobuf.decode(buffer, expected_type, experimental_enabled) + msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED) if __debug__ and utils.EMULATOR: log.debug( __name__, "received message contents:\n%s", utils.dump_protobuf(msg) @@ -117,22 +96,6 @@ def _wrap_protobuf_load( raise DataError("Failed to decode message") -class DummyContext: - async def call(self, *argv: Any) -> None: - pass - - async def read(self, *argv: Any) -> None: - pass - - async def write(self, *argv: Any) -> None: - pass - - async def wait(self, *tasks: Awaitable) -> Any: - return await loop.race(*tasks) - - -DUMMY_CONTEXT = DummyContext() - _PROTOBUF_BUFFER_SIZE = const(8192) WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE) @@ -142,143 +105,8 @@ if __debug__: WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG) -class Context: - def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None: - self.iface = iface - self.sid = sid - self.buffer = buffer - - async def call( - self, - msg: protobuf.MessageType, - expected_type: type[LoadedMessageType], - ) -> LoadedMessageType: - await self.write(msg) - del msg - return await self.read(expected_type) - - async def call_any( - self, msg: protobuf.MessageType, *expected_wire_types: int - ) -> protobuf.MessageType: - await self.write(msg) - del msg - return await self.read_any(expected_wire_types) - - def read_from_wire(self) -> Awaitable[codec_v1.Message]: - return codec_v1.read_message(self.iface, self.buffer) - - async def read(self, expected_type: type[LoadedMessageType]) -> LoadedMessageType: - if __debug__: - log.debug( - __name__, - "%s:%x expect: %s", - self.iface.iface_num(), - self.sid, - expected_type.MESSAGE_NAME, - ) - - # Load the full message into a buffer, parse out type and data payload - msg = await self.read_from_wire() - - # If we got a message with unexpected type, raise the message via - # `UnexpectedMessageError` and let the session handler deal with it. - if msg.type != expected_type.MESSAGE_WIRE_TYPE: - raise UnexpectedMessageError(msg) - - if __debug__: - log.debug( - __name__, - "%s:%x read: %s", - self.iface.iface_num(), - self.sid, - expected_type.MESSAGE_NAME, - ) - - # look up the protobuf class and parse the message - return _wrap_protobuf_load(msg.data, expected_type) - - async def read_any( - self, expected_wire_types: Iterable[int] - ) -> protobuf.MessageType: - if __debug__: - log.debug( - __name__, - "%s:%x expect: %s", - self.iface.iface_num(), - self.sid, - expected_wire_types, - ) - - # Load the full message into a buffer, parse out type and data payload - msg = await self.read_from_wire() - - # If we got a message with unexpected type, raise the message via - # `UnexpectedMessageError` and let the session handler deal with it. - if msg.type not in expected_wire_types: - raise UnexpectedMessageError(msg) - - # find the protobuf type - exptype = protobuf.type_for_wire(msg.type) - - if __debug__: - log.debug( - __name__, - "%s:%x read: %s", - self.iface.iface_num(), - self.sid, - exptype.MESSAGE_NAME, - ) - - # parse the message and return it - return _wrap_protobuf_load(msg.data, exptype) - - async def write(self, msg: protobuf.MessageType) -> None: - if __debug__: - log.debug( - __name__, - "%s:%x write: %s", - self.iface.iface_num(), - self.sid, - msg.MESSAGE_NAME, - ) - - # cannot write message without wire type - assert msg.MESSAGE_WIRE_TYPE is not None - - msg_size = protobuf.encoded_length(msg) - - if msg_size <= len(self.buffer): - # reuse preallocated - buffer = self.buffer - else: - # message is too big, we need to allocate a new buffer - buffer = bytearray(msg_size) - - msg_size = protobuf.encode(buffer, msg) - - await codec_v1.write_message( - self.iface, - msg.MESSAGE_WIRE_TYPE, - memoryview(buffer)[:msg_size], - ) - - def wait(self, *tasks: Awaitable) -> Any: - """ - Wait until one of the passed tasks finishes, and return the result, - while servicing the wire context. If a message comes until one of the - tasks ends, `UnexpectedMessageError` is raised. - """ - return loop.race(self.read_any(()), *tasks) - - -class UnexpectedMessageError(Exception): - def __init__(self, msg: codec_v1.Message) -> None: - super().__init__() - self.msg = msg - - async def _handle_single_message( - ctx: Context, msg: codec_v1.Message, use_workflow: bool + ctx: context.Context, msg: codec_v1.Message, use_workflow: bool ) -> codec_v1.Message | None: """Handle a message that was loaded from USB by the caller. @@ -288,10 +116,10 @@ async def _handle_single_message( If the workflow finished normally or with an error, the return value is None. If an unexpected message had arrived on the wire while the workflow was processing, - the workflow is shut down with an `UnexpectedMessageError`. This is not considered - an "error condition" to return over the wire -- instead the message is processed - as if starting a new workflow. - In such case, the `UnexpectedMessageError` is caught and the message is returned + the workflow is shut down with an `UnexpectedMessage` exception. This is not + considered an "error condition" to return over the wire -- instead the message + is processed as if starting a new workflow. + In such case, the `UnexpectedMessage` is caught and the message is returned to the caller. It will then be processed in the next iteration of the message loop. """ if __debug__: @@ -330,10 +158,10 @@ async def _handle_single_message( # Try to decode the message according to schema from # `req_type`. Raises if the message is malformed. - req_msg = _wrap_protobuf_load(msg.data, req_type) + req_msg = wrap_protobuf_load(msg.data, req_type) # Create the handler task. - task = handler(ctx, req_msg) + task = handler(req_msg) # Run the workflow task. Workflow can do more on-the-wire # communication inside, but it should eventually return a @@ -342,17 +170,17 @@ async def _handle_single_message( if use_workflow: # Spawn a workflow around the task. This ensures that concurrent # workflows are shut down. - res_msg = await workflow.spawn(task) + res_msg = await workflow.spawn(context.with_context(ctx, task)) else: # For debug messages, ignore workflow processing and just await # results of the handler. res_msg = await task - except UnexpectedMessageError as exc: + except context.UnexpectedMessage as exc: # Workflow was trying to read a message from the wire, and # something unexpected came in. See Context.read() for # example, which expects some particular message and raises - # UnexpectedMessageError if another one comes in. + # UnexpectedMessage if another one comes in. # In order not to lose the message, we return it to the caller. # TODO: # We might handle only the few common cases here, like @@ -390,7 +218,7 @@ async def handle_session( else: ctx_buffer = WIRE_BUFFER - ctx = Context(iface, session_id, ctx_buffer) + ctx = context.Context(iface, session_id, ctx_buffer) next_msg: codec_v1.Message | None = None if __debug__ and is_debug_session: diff --git a/core/src/trezor/wire/context.py b/core/src/trezor/wire/context.py new file mode 100644 index 0000000000..8dda31629d --- /dev/null +++ b/core/src/trezor/wire/context.py @@ -0,0 +1,266 @@ +"""Context pseudo-global. + +Each workflow handler runs in a "context" which is tied to a particular communication +session. When the handler needs to communicate with the host, it needs access to that +context. + +To avoid the need to pass a context object around, the context is stored in a +pseudo-global manner: any workflow handler can request access to the context via this +module, and the appropriate context object will be used for it. + +Some workflows don't need a context to exist. This is supported by the `maybe_call` +function, which will silently ignore the call if no context is available. Useful mainly +for ButtonRequests. Of course, `context.wait()` transparently works in such situations. +""" + +from typing import TYPE_CHECKING + +from trezor import log, loop, protobuf + +from . import codec_v1 + +if TYPE_CHECKING: + from typing import ( + Any, + Awaitable, + Callable, + Container, + Coroutine, + Generator, + TypeVar, + overload, + ) + from trezorio import WireInterface + + Msg = TypeVar("Msg", bound=protobuf.MessageType) + HandlerTask = Coroutine[Any, Any, protobuf.MessageType] + Handler = Callable[["Context", Msg], HandlerTask] + + LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType) + + +class UnexpectedMessage(Exception): + """A message was received that is not part of the current workflow. + + Utility exception to inform the session handler that the current workflow + should be aborted and a new one started as if `msg` was the first message. + """ + + def __init__(self, msg: codec_v1.Message) -> None: + super().__init__() + self.msg = msg + + +class Context: + """Wire context. + + Represents USB communication inside a particular session on a particular interface + (i.e., wire, debug, single BT connection, etc.) + """ + + def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None: + self.iface = iface + self.sid = sid + self.buffer = buffer + + def read_from_wire(self) -> Awaitable[codec_v1.Message]: + """Read a whole message from the wire without parsing it.""" + return codec_v1.read_message(self.iface, self.buffer) + + if TYPE_CHECKING: + + @overload + async def read(self, expected_types: Container[int]) -> protobuf.MessageType: + ... + + @overload + async def read( + self, expected_types: Container[int], expected_type: type[LoadedMessageType] + ) -> LoadedMessageType: + ... + + async def read( + self, + expected_types: Container[int], + expected_type: type[protobuf.MessageType] | None = None, + ) -> protobuf.MessageType: + """Read a message from the wire. + + The read message must be of one of the types specified in `expected_types`. + If only a single type is expected, it can be passed as `expected_type`, + to save on having to decode the type code into a protobuf class. + """ + if __debug__: + log.debug( + __name__, + "%s:%x expect: %s", + self.iface.iface_num(), + self.sid, + expected_type.MESSAGE_NAME if expected_type else expected_types, + ) + + # Load the full message into a buffer, parse out type and data payload + msg = await self.read_from_wire() + + # If we got a message with unexpected type, raise the message via + # `UnexpectedMessageError` and let the session handler deal with it. + if msg.type not in expected_types: + raise UnexpectedMessage(msg) + + if expected_type is None: + expected_type = protobuf.type_for_wire(msg.type) + + if __debug__: + log.debug( + __name__, + "%s:%x read: %s", + self.iface.iface_num(), + self.sid, + expected_type.MESSAGE_NAME, + ) + + # look up the protobuf class and parse the message + from . import wrap_protobuf_load + + return wrap_protobuf_load(msg.data, expected_type) + + async def write(self, msg: protobuf.MessageType) -> None: + """Write a message to the wire.""" + if __debug__: + log.debug( + __name__, + "%s:%x write: %s", + self.iface.iface_num(), + self.sid, + msg.MESSAGE_NAME, + ) + + # cannot write message without wire type + assert msg.MESSAGE_WIRE_TYPE is not None + + msg_size = protobuf.encoded_length(msg) + + if msg_size <= len(self.buffer): + # reuse preallocated + buffer = self.buffer + else: + # message is too big, we need to allocate a new buffer + buffer = bytearray(msg_size) + + msg_size = protobuf.encode(buffer, msg) + + await codec_v1.write_message( + self.iface, + msg.MESSAGE_WIRE_TYPE, + memoryview(buffer)[:msg_size], + ) + + +CURRENT_CONTEXT: Context | None = None + + +def wait(*tasks: Awaitable) -> Any: + """ + Wait until one of the passed tasks finishes, and return the result, while servicing + the wire context. + + Used to make sure the device is responsive on USB while waiting for user + interaction. If a message is received before any of the passed in tasks finish, it + raises an `UnexpectedMessage` exception, returning control to the session handler. + """ + if CURRENT_CONTEXT is None: + return loop.race(*tasks) + else: + return loop.race(CURRENT_CONTEXT.read(()), *tasks) + + +async def call( + msg: protobuf.MessageType, + expected_type: type[LoadedMessageType], +) -> LoadedMessageType: + """Send a message to the host and wait for a response of a particular type. + + Raises if there is no context for this workflow.""" + if CURRENT_CONTEXT is None: + raise RuntimeError("No wire context") + + assert expected_type.MESSAGE_WIRE_TYPE is not None + + await CURRENT_CONTEXT.write(msg) + del msg + return await CURRENT_CONTEXT.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type) + + +async def call_any( + msg: protobuf.MessageType, *expected_wire_types: int +) -> protobuf.MessageType: + """Send a message to the host and wait for a response. + + The response can be of any of the types specified in `expected_wire_types`. + + Raises if there is no context for this workflow.""" + if CURRENT_CONTEXT is None: + raise RuntimeError("No wire context") + + await CURRENT_CONTEXT.write(msg) + del msg + return await CURRENT_CONTEXT.read(expected_wire_types) + + +async def maybe_call( + msg: protobuf.MessageType, expected_type: type[LoadedMessageType] +) -> None: + """Send a message to the host and read but ignore the response. + + If there is a context, the function still checks that the response is of the + requested type. If there is no context, the call is ignored. + """ + if CURRENT_CONTEXT is None: + return + + await call(msg, expected_type) + + +def get_context() -> Context: + """Get the current session context. + + Can be needed in case the caller needs raw read and raw write capabilities, which + are not provided by the module functions. + + Result of this function should not be stored -- the context is technically allowed + to change inbetween any `await` statements. + """ + if CURRENT_CONTEXT is None: + raise RuntimeError("No wire context") + return CURRENT_CONTEXT + + +def with_context(ctx: Context, workflow: loop.Task) -> Generator: + """Run a workflow in a particular context. + + Stores the context in a closure and installs it into the global variable every time + the closure is resumed, thus making sure that all calls to `wire.context.*` will + work as expected. + """ + global CURRENT_CONTEXT + send_val = None + send_exc = None + + while True: + CURRENT_CONTEXT = ctx + try: + if send_exc is not None: + res = workflow.throw(send_exc) + else: + res = workflow.send(send_val) + except StopIteration as st: + return st.value + finally: + CURRENT_CONTEXT = None + + try: + send_val = yield res + except BaseException as e: + send_exc = e + else: + send_exc = None