You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
trezor-firmware/core/src/trezor/wire/context.py

267 lines
8.1 KiB

"""Context pseudo-global.
Each workflow handler runs in a "context" which is tied to a particular communication
session. When the handler needs to communicate with the host, it needs access to that
context.
To avoid the need to pass a context object around, the context is stored in a
pseudo-global manner: any workflow handler can request access to the context via this
module, and the appropriate context object will be used for it.
Some workflows don't need a context to exist. This is supported by the `maybe_call`
function, which will silently ignore the call if no context is available. Useful mainly
for ButtonRequests. Of course, `context.wait()` transparently works in such situations.
"""
from typing import TYPE_CHECKING
from trezor import log, loop, protobuf
from . import codec_v1
if TYPE_CHECKING:
    # Names below are used in annotations only; they are imported and defined
    # solely under the TYPE_CHECKING guard.
    from trezorio import WireInterface
    from typing import (
        Any,
        Awaitable,
        Callable,
        Container,
        Coroutine,
        Generator,
        TypeVar,
        overload,
    )

    # Type variable for a protobuf message class, as accepted by a `Handler`.
    Msg = TypeVar("Msg", bound=protobuf.MessageType)
    # Coroutine that completes with a protobuf message.
    HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
    # Callable taking a `Context` and a message and returning a `HandlerTask`.
    Handler = Callable[["Context", Msg], HandlerTask]
    # Type variable linking an `expected_type` class argument to the type of the
    # message returned by `Context.read()` and `call()`.
    LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
class UnexpectedMessage(Exception):
    """A message was received that is not part of the current workflow.

    Utility exception to inform the session handler that the current workflow
    should be aborted and a new one started as if `msg` was the first message.

    The offending message is available as the `msg` attribute.
    """

    def __init__(self, msg: codec_v1.Message) -> None:
        # No arguments are forwarded to Exception; the payload travels in `self.msg`.
        super().__init__()
        self.msg = msg
class Context:
    """Wire context.

    Represents communication with the host inside a particular session on a
    particular interface (i.e., wire, debug, single BT connection, etc.).

    An instance bundles the interface object, the session id and a bytearray
    buffer that is handed to the codec for reads and reused for writes when the
    outgoing message fits into it.
    """

    def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
        self.iface = iface
        self.sid = sid
        self.buffer = buffer

    def read_from_wire(self) -> Awaitable[codec_v1.Message]:
        """Return an awaitable for the next whole, still unparsed wire message."""
        return codec_v1.read_message(self.iface, self.buffer)

    if TYPE_CHECKING:

        @overload
        async def read(
            self, expected_types: Container[int]
        ) -> protobuf.MessageType: ...

        @overload
        async def read(
            self, expected_types: Container[int], expected_type: type[LoadedMessageType]
        ) -> LoadedMessageType: ...

    async def read(
        self,
        expected_types: Container[int],
        expected_type: type[protobuf.MessageType] | None = None,
    ) -> protobuf.MessageType:
        """Read one message from the wire and decode it.

        The wire type of the received message must be contained in
        `expected_types`. When exactly one type is expected, its protobuf class
        can be given as `expected_type`, which skips the lookup of the class from
        the wire type code.

        Raises `UnexpectedMessage` (carrying the raw message) if the received
        message has a type that is not in `expected_types`.
        """
        if __debug__:
            if expected_type:
                expecting = expected_type.MESSAGE_NAME
            else:
                expecting = expected_types
            log.debug(
                __name__,
                "%s:%x expect: %s",
                self.iface.iface_num(),
                self.sid,
                expecting,
            )

        # Fetch the complete raw message; it exposes the wire type and the payload.
        raw = await self.read_from_wire()

        # A message of an unexpected type is handed over to the session handler
        # by raising it wrapped in `UnexpectedMessage`.
        if raw.type not in expected_types:
            raise UnexpectedMessage(raw)

        # Resolve the protobuf class from the wire type unless the caller gave it.
        msg_class = expected_type
        if msg_class is None:
            msg_class = protobuf.type_for_wire(raw.type)

        if __debug__:
            log.debug(
                __name__,
                "%s:%x read: %s",
                self.iface.iface_num(),
                self.sid,
                msg_class.MESSAGE_NAME,
            )

        # Imported at call time, right before decoding the payload.
        from . import wrap_protobuf_load

        return wrap_protobuf_load(raw.data, msg_class)

    async def write(self, msg: protobuf.MessageType) -> None:
        """Encode `msg` and write it to the wire.

        The message class must define a wire type (`MESSAGE_WIRE_TYPE`).
        """
        if __debug__:
            log.debug(
                __name__,
                "%s:%x write: %s",
                self.iface.iface_num(),
                self.sid,
                msg.MESSAGE_NAME,
            )

        # A message without a wire type cannot be sent.
        wire_type = msg.MESSAGE_WIRE_TYPE
        assert wire_type is not None

        # Use the session buffer when the encoded message fits; otherwise
        # allocate a buffer of exactly the needed size.
        needed = protobuf.encoded_length(msg)
        out = self.buffer if needed <= len(self.buffer) else bytearray(needed)

        # Only the encoded prefix of the buffer is sent.
        written = protobuf.encode(out, msg)
        await codec_v1.write_message(
            self.iface,
            wire_type,
            memoryview(out)[:written],
        )
# Context of the workflow that is currently executing, or None when no workflow
# with a context is running. `with_context()` sets it before resuming its
# workflow and resets it to None afterwards.
CURRENT_CONTEXT: Context | None = None
def wait(*tasks: Awaitable) -> Any:
    """
    Race the passed tasks against each other while servicing the wire context.

    Used to make sure the device is responsive on USB while waiting for user
    interaction. When a context is active, a read that expects no message type at
    all takes part in the race, so any message arriving before one of the tasks
    finishes raises an `UnexpectedMessage` exception, returning control to the
    session handler. Without a context, only the passed tasks are raced.

    Returns the object created by `loop.race`.
    """
    ctx = CURRENT_CONTEXT
    if ctx is not None:
        # Empty `expected_types`: every incoming message counts as unexpected.
        return loop.race(ctx.read(()), *tasks)
    return loop.race(*tasks)
async def call(
    msg: protobuf.MessageType,
    expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
    """Write `msg` to the host and return the response decoded as `expected_type`.

    Raises `RuntimeError` if there is no context for this workflow. A response of
    any other type surfaces as whatever `Context.read` raises for it.
    """
    if CURRENT_CONTEXT is None:
        raise RuntimeError("No wire context")

    wire_type = expected_type.MESSAGE_WIRE_TYPE
    assert wire_type is not None

    await CURRENT_CONTEXT.write(msg)
    # Drop the local reference to the outgoing message before waiting for the
    # reply (presumably so that it can be garbage-collected).
    del msg

    # The global is looked up again on purpose instead of being cached across
    # the await above.
    reply = await CURRENT_CONTEXT.read((wire_type,), expected_type)
    return reply
async def call_any(
    msg: protobuf.MessageType, *expected_wire_types: int
) -> protobuf.MessageType:
    """Write `msg` to the host and return its response.

    The response may be of any of the wire types listed in `expected_wire_types`.

    Raises `RuntimeError` if there is no context for this workflow.
    """
    if CURRENT_CONTEXT is None:
        raise RuntimeError("No wire context")

    await CURRENT_CONTEXT.write(msg)
    # Drop the local reference to the outgoing message before waiting for the
    # reply (presumably so that it can be garbage-collected).
    del msg

    # The global is looked up again on purpose instead of being cached across
    # the await above.
    reply = await CURRENT_CONTEXT.read(expected_wire_types)
    return reply
async def maybe_call(
    msg: protobuf.MessageType, expected_type: type[LoadedMessageType]
) -> None:
    """Send `msg` to the host if a context exists, discarding the response.

    With a context, the exchange goes through `call()`, so the response is still
    required to be of `expected_type`; its value is thrown away. Without a
    context, nothing is sent and the function returns immediately.
    """
    if CURRENT_CONTEXT is not None:
        await call(msg, expected_type)
def get_context() -> Context:
    """Return the context of the currently running workflow.

    Intended for callers that need raw read and raw write capabilities, which the
    module-level functions do not provide.

    The returned object should not be stored -- the context is technically
    allowed to change in between any `await` statements.

    Raises `RuntimeError` if there is no context.
    """
    ctx = CURRENT_CONTEXT
    if ctx is None:
        raise RuntimeError("No wire context")
    return ctx
def with_context(ctx: Context, workflow: loop.Task) -> Generator:
    """Drive `workflow` with `ctx` installed as the current context.

    This is a generator wrapped around the workflow. Each time the wrapper is
    resumed it installs `ctx` into the module-level `CURRENT_CONTEXT`, resumes the
    workflow, and resets the global to None as soon as the workflow yields,
    finishes or raises. All `wire.context.*` calls made by the workflow therefore
    see `ctx`.

    Values sent into the wrapper are passed on to the workflow via `send()`;
    exceptions thrown into the wrapper are passed on via `throw()`. Whatever the
    workflow yields is yielded by the wrapper, and the workflow's return value
    becomes the wrapper's return value. Exceptions raised by the workflow
    propagate to the caller.
    """
    global CURRENT_CONTEXT

    pending_value = None
    pending_exc = None

    while True:
        CURRENT_CONTEXT = ctx
        try:
            if pending_exc is None:
                yielded = workflow.send(pending_value)
            else:
                yielded = workflow.throw(pending_exc)
        except StopIteration as done:
            # The workflow has finished; hand its result to the caller.
            return done.value
        finally:
            # The context is installed only while the workflow itself executes.
            CURRENT_CONTEXT = None

        try:
            pending_value = yield yielded
        except BaseException as exc:
            # Remember the exception; it is thrown into the workflow on the next
            # iteration, after the context has been installed again.
            pending_exc = exc
        else:
            pending_exc = None