1
0
mirror of https://github.com/trezor/trezor-firmware.git synced 2025-01-02 03:20:59 +00:00

feat(core): keep track of current context for workflow

This commit is contained in:
matejcik 2023-06-26 11:36:41 +02:00 committed by matejcik
parent 78a8b48f1e
commit fe80793b47
3 changed files with 284 additions and 188 deletions

View File

@ -193,6 +193,8 @@ trezor.wire
import trezor.wire
trezor.wire.codec_v1
import trezor.wire.codec_v1
trezor.wire.context
import trezor.wire.context
trezor.wire.errors
import trezor.wire.errors
trezor.workflow

View File

@ -42,7 +42,7 @@ from storage.cache import InvalidSessionError
from trezor import log, loop, protobuf, utils, workflow
from trezor.enums import FailureType
from trezor.messages import Failure
from trezor.wire import codec_v1
from trezor.wire import codec_v1, context
from trezor.wire.errors import ActionCancelled, DataError, Error
# Import all errors into namespace, so that `wire.Error` is available from
@ -53,43 +53,22 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403
if TYPE_CHECKING:
from typing import (
Any,
Awaitable,
Callable,
Container,
Coroutine,
Iterable,
Protocol,
TypeVar,
)
from trezorio import WireInterface
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
Handler = Callable[["Context", Msg], HandlerTask]
Handler = Callable[[Msg], HandlerTask]
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
class GenericContext(Protocol):
async def call(
self,
msg: protobuf.MessageType,
expected_type: type[protobuf.MessageType],
) -> Any:
...
async def read(self, expected_type: type[protobuf.MessageType]) -> Any:
...
async def write(self, msg: protobuf.MessageType) -> None:
...
# XXX modify type signature so that the return value must be of the same type?
async def wait(self, *tasks: Awaitable) -> Any:
...
# If set to False protobuf messages marked with "experimental_message" option are rejected.
experimental_enabled = False
EXPERIMENTAL_ENABLED = False
def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
@ -97,12 +76,12 @@ def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session))
def _wrap_protobuf_load(
def wrap_protobuf_load(
buffer: bytes,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
try:
msg = protobuf.decode(buffer, expected_type, experimental_enabled)
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
if __debug__ and utils.EMULATOR:
log.debug(
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
@ -117,22 +96,6 @@ def _wrap_protobuf_load(
raise DataError("Failed to decode message")
class DummyContext:
async def call(self, *argv: Any) -> None:
pass
async def read(self, *argv: Any) -> None:
pass
async def write(self, *argv: Any) -> None:
pass
async def wait(self, *tasks: Awaitable) -> Any:
return await loop.race(*tasks)
DUMMY_CONTEXT = DummyContext()
_PROTOBUF_BUFFER_SIZE = const(8192)
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
@ -142,143 +105,8 @@ if __debug__:
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
class Context:
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
self.iface = iface
self.sid = sid
self.buffer = buffer
async def call(
self,
msg: protobuf.MessageType,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
await self.write(msg)
del msg
return await self.read(expected_type)
async def call_any(
self, msg: protobuf.MessageType, *expected_wire_types: int
) -> protobuf.MessageType:
await self.write(msg)
del msg
return await self.read_any(expected_wire_types)
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
return codec_v1.read_message(self.iface, self.buffer)
async def read(self, expected_type: type[LoadedMessageType]) -> LoadedMessageType:
if __debug__:
log.debug(
__name__,
"%s:%x expect: %s",
self.iface.iface_num(),
self.sid,
expected_type.MESSAGE_NAME,
)
# Load the full message into a buffer, parse out type and data payload
msg = await self.read_from_wire()
# If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it.
if msg.type != expected_type.MESSAGE_WIRE_TYPE:
raise UnexpectedMessageError(msg)
if __debug__:
log.debug(
__name__,
"%s:%x read: %s",
self.iface.iface_num(),
self.sid,
expected_type.MESSAGE_NAME,
)
# look up the protobuf class and parse the message
return _wrap_protobuf_load(msg.data, expected_type)
async def read_any(
self, expected_wire_types: Iterable[int]
) -> protobuf.MessageType:
if __debug__:
log.debug(
__name__,
"%s:%x expect: %s",
self.iface.iface_num(),
self.sid,
expected_wire_types,
)
# Load the full message into a buffer, parse out type and data payload
msg = await self.read_from_wire()
# If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it.
if msg.type not in expected_wire_types:
raise UnexpectedMessageError(msg)
# find the protobuf type
exptype = protobuf.type_for_wire(msg.type)
if __debug__:
log.debug(
__name__,
"%s:%x read: %s",
self.iface.iface_num(),
self.sid,
exptype.MESSAGE_NAME,
)
# parse the message and return it
return _wrap_protobuf_load(msg.data, exptype)
async def write(self, msg: protobuf.MessageType) -> None:
if __debug__:
log.debug(
__name__,
"%s:%x write: %s",
self.iface.iface_num(),
self.sid,
msg.MESSAGE_NAME,
)
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg)
if msg_size <= len(self.buffer):
# reuse preallocated
buffer = self.buffer
else:
# message is too big, we need to allocate a new buffer
buffer = bytearray(msg_size)
msg_size = protobuf.encode(buffer, msg)
await codec_v1.write_message(
self.iface,
msg.MESSAGE_WIRE_TYPE,
memoryview(buffer)[:msg_size],
)
def wait(self, *tasks: Awaitable) -> Any:
"""
Wait until one of the passed tasks finishes, and return the result,
while servicing the wire context. If a message comes until one of the
tasks ends, `UnexpectedMessageError` is raised.
"""
return loop.race(self.read_any(()), *tasks)
class UnexpectedMessageError(Exception):
def __init__(self, msg: codec_v1.Message) -> None:
super().__init__()
self.msg = msg
async def _handle_single_message(
ctx: Context, msg: codec_v1.Message, use_workflow: bool
ctx: context.Context, msg: codec_v1.Message, use_workflow: bool
) -> codec_v1.Message | None:
"""Handle a message that was loaded from USB by the caller.
@ -288,10 +116,10 @@ async def _handle_single_message(
If the workflow finished normally or with an error, the return value is None.
If an unexpected message had arrived on the wire while the workflow was processing,
the workflow is shut down with an `UnexpectedMessageError`. This is not considered
an "error condition" to return over the wire -- instead the message is processed
as if starting a new workflow.
In such case, the `UnexpectedMessageError` is caught and the message is returned
the workflow is shut down with an `UnexpectedMessage` exception. This is not
considered an "error condition" to return over the wire -- instead the message
is processed as if starting a new workflow.
In such case, the `UnexpectedMessage` is caught and the message is returned
to the caller. It will then be processed in the next iteration of the message loop.
"""
if __debug__:
@ -330,10 +158,10 @@ async def _handle_single_message(
# Try to decode the message according to schema from
# `req_type`. Raises if the message is malformed.
req_msg = _wrap_protobuf_load(msg.data, req_type)
req_msg = wrap_protobuf_load(msg.data, req_type)
# Create the handler task.
task = handler(ctx, req_msg)
task = handler(req_msg)
# Run the workflow task. Workflow can do more on-the-wire
# communication inside, but it should eventually return a
@ -342,17 +170,17 @@ async def _handle_single_message(
if use_workflow:
# Spawn a workflow around the task. This ensures that concurrent
# workflows are shut down.
res_msg = await workflow.spawn(task)
res_msg = await workflow.spawn(context.with_context(ctx, task))
else:
# For debug messages, ignore workflow processing and just await
# results of the handler.
res_msg = await task
except UnexpectedMessageError as exc:
except context.UnexpectedMessage as exc:
# Workflow was trying to read a message from the wire, and
# something unexpected came in. See Context.read() for
# example, which expects some particular message and raises
# UnexpectedMessageError if another one comes in.
# UnexpectedMessage if another one comes in.
# In order not to lose the message, we return it to the caller.
# TODO:
# We might handle only the few common cases here, like
@ -390,7 +218,7 @@ async def handle_session(
else:
ctx_buffer = WIRE_BUFFER
ctx = Context(iface, session_id, ctx_buffer)
ctx = context.Context(iface, session_id, ctx_buffer)
next_msg: codec_v1.Message | None = None
if __debug__ and is_debug_session:

View File

@ -0,0 +1,266 @@
"""Context pseudo-global.
Each workflow handler runs in a "context" which is tied to a particular communication
session. When the handler needs to communicate with the host, it needs access to that
context.
To avoid the need to pass a context object around, the context is stored in a
pseudo-global manner: any workflow handler can request access to the context via this
module, and the appropriate context object will be used for it.
Some workflows don't need a context to exist. This is supported by the `maybe_call`
function, which will silently ignore the call if no context is available. Useful mainly
for ButtonRequests. Of course, `context.wait()` transparently works in such situations.
"""
from typing import TYPE_CHECKING
from trezor import log, loop, protobuf
from . import codec_v1
if TYPE_CHECKING:
from typing import (
Any,
Awaitable,
Callable,
Container,
Coroutine,
Generator,
TypeVar,
overload,
)
from trezorio import WireInterface
Msg = TypeVar("Msg", bound=protobuf.MessageType)
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
Handler = Callable[["Context", Msg], HandlerTask]
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
class UnexpectedMessage(Exception):
"""A message was received that is not part of the current workflow.
Utility exception to inform the session handler that the current workflow
should be aborted and a new one started as if `msg` was the first message.
"""
def __init__(self, msg: codec_v1.Message) -> None:
super().__init__()
self.msg = msg
class Context:
"""Wire context.
Represents USB communication inside a particular session on a particular interface
(i.e., wire, debug, single BT connection, etc.)
"""
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
self.iface = iface
self.sid = sid
self.buffer = buffer
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
"""Read a whole message from the wire without parsing it."""
return codec_v1.read_message(self.iface, self.buffer)
if TYPE_CHECKING:
@overload
async def read(self, expected_types: Container[int]) -> protobuf.MessageType:
...
@overload
async def read(
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
) -> LoadedMessageType:
...
async def read(
self,
expected_types: Container[int],
expected_type: type[protobuf.MessageType] | None = None,
) -> protobuf.MessageType:
"""Read a message from the wire.
The read message must be of one of the types specified in `expected_types`.
If only a single type is expected, it can be passed as `expected_type`,
to save on having to decode the type code into a protobuf class.
"""
if __debug__:
log.debug(
__name__,
"%s:%x expect: %s",
self.iface.iface_num(),
self.sid,
expected_type.MESSAGE_NAME if expected_type else expected_types,
)
# Load the full message into a buffer, parse out type and data payload
msg = await self.read_from_wire()
# If we got a message with unexpected type, raise the message via
# `UnexpectedMessageError` and let the session handler deal with it.
if msg.type not in expected_types:
raise UnexpectedMessage(msg)
if expected_type is None:
expected_type = protobuf.type_for_wire(msg.type)
if __debug__:
log.debug(
__name__,
"%s:%x read: %s",
self.iface.iface_num(),
self.sid,
expected_type.MESSAGE_NAME,
)
# look up the protobuf class and parse the message
from . import wrap_protobuf_load
return wrap_protobuf_load(msg.data, expected_type)
async def write(self, msg: protobuf.MessageType) -> None:
"""Write a message to the wire."""
if __debug__:
log.debug(
__name__,
"%s:%x write: %s",
self.iface.iface_num(),
self.sid,
msg.MESSAGE_NAME,
)
# cannot write message without wire type
assert msg.MESSAGE_WIRE_TYPE is not None
msg_size = protobuf.encoded_length(msg)
if msg_size <= len(self.buffer):
# reuse preallocated
buffer = self.buffer
else:
# message is too big, we need to allocate a new buffer
buffer = bytearray(msg_size)
msg_size = protobuf.encode(buffer, msg)
await codec_v1.write_message(
self.iface,
msg.MESSAGE_WIRE_TYPE,
memoryview(buffer)[:msg_size],
)
CURRENT_CONTEXT: Context | None = None
def wait(*tasks: Awaitable) -> Any:
"""
Wait until one of the passed tasks finishes, and return the result, while servicing
the wire context.
Used to make sure the device is responsive on USB while waiting for user
interaction. If a message is received before any of the passed in tasks finish, it
raises an `UnexpectedMessage` exception, returning control to the session handler.
"""
if CURRENT_CONTEXT is None:
return loop.race(*tasks)
else:
return loop.race(CURRENT_CONTEXT.read(()), *tasks)
async def call(
msg: protobuf.MessageType,
expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
"""Send a message to the host and wait for a response of a particular type.
Raises if there is no context for this workflow."""
if CURRENT_CONTEXT is None:
raise RuntimeError("No wire context")
assert expected_type.MESSAGE_WIRE_TYPE is not None
await CURRENT_CONTEXT.write(msg)
del msg
return await CURRENT_CONTEXT.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
async def call_any(
msg: protobuf.MessageType, *expected_wire_types: int
) -> protobuf.MessageType:
"""Send a message to the host and wait for a response.
The response can be of any of the types specified in `expected_wire_types`.
Raises if there is no context for this workflow."""
if CURRENT_CONTEXT is None:
raise RuntimeError("No wire context")
await CURRENT_CONTEXT.write(msg)
del msg
return await CURRENT_CONTEXT.read(expected_wire_types)
async def maybe_call(
msg: protobuf.MessageType, expected_type: type[LoadedMessageType]
) -> None:
"""Send a message to the host and read but ignore the response.
If there is a context, the function still checks that the response is of the
requested type. If there is no context, the call is ignored.
"""
if CURRENT_CONTEXT is None:
return
await call(msg, expected_type)
def get_context() -> Context:
"""Get the current session context.
Can be needed in case the caller needs raw read and raw write capabilities, which
are not provided by the module functions.
Result of this function should not be stored -- the context is technically allowed
to change inbetween any `await` statements.
"""
if CURRENT_CONTEXT is None:
raise RuntimeError("No wire context")
return CURRENT_CONTEXT
def with_context(ctx: Context, workflow: loop.Task) -> Generator:
"""Run a workflow in a particular context.
Stores the context in a closure and installs it into the global variable every time
the closure is resumed, thus making sure that all calls to `wire.context.*` will
work as expected.
"""
global CURRENT_CONTEXT
send_val = None
send_exc = None
while True:
CURRENT_CONTEXT = ctx
try:
if send_exc is not None:
res = workflow.throw(send_exc)
else:
res = workflow.send(send_val)
except StopIteration as st:
return st.value
finally:
CURRENT_CONTEXT = None
try:
send_val = yield res
except BaseException as e:
send_exc = e
else:
send_exc = None