You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
281 lines
8.6 KiB
281 lines
8.6 KiB
"""Context pseudo-global.
|
|
|
|
Each workflow handler runs in a "context" which is tied to a particular communication
|
|
session. When the handler needs to communicate with the host, it needs access to that
|
|
context.
|
|
|
|
To avoid the need to pass a context object around, the context is stored in a
|
|
pseudo-global manner: any workflow handler can request access to the context via this
|
|
module, and the appropriate context object will be used for it.
|
|
|
|
Some workflows don't need a context to exist. This is supported by the `maybe_call`
|
|
function, which will silently ignore the call if no context is available. Useful mainly
|
|
for ButtonRequests. Of course, `context.wait()` transparently works in such situations.
|
|
"""
|
|
|
|
from typing import TYPE_CHECKING
|
|
|
|
from trezor import log, loop, protobuf
|
|
|
|
from .protocol import WireProtocol
|
|
from .protocol_common import Message
|
|
|
|
if TYPE_CHECKING:
    # Names below exist only for static type checkers; `TYPE_CHECKING` is False
    # at runtime, so none of this is executed.
    from typing import (
        Any,
        Awaitable,
        Callable,
        Container,
        Coroutine,
        Generator,
        TypeVar,
        overload,
    )

    from trezorio import WireInterface

    # Generic type variables.
    T = TypeVar("T")
    Msg = TypeVar("Msg", bound=protobuf.MessageType)
    LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)

    # A handler takes a Context and a request message and returns a coroutine
    # whose result is a protobuf message.
    HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
    Handler = Callable[["Context", Msg], HandlerTask]
|
|
|
|
|
|
class UnexpectedMessage(Exception):
    """Signal that the host sent a message outside the current workflow.

    Utility exception for the session handler: on catching it, the handler should
    abort the current workflow and start a new one, treating `msg` as that new
    workflow's first message.
    """

    def __init__(self, msg: Message) -> None:
        super().__init__()
        # The message that triggered the abort, handed over for re-dispatch.
        self.msg = msg
|
|
|
|
|
|
class Context:
    """Wire context.

    Represents USB communication inside a particular session on a particular interface
    (i.e., wire, debug, single BT connection, etc.)
    """

    def __init__(
        self,
        iface: WireInterface,
        buffer: bytearray,
        session_id: bytearray | None = None,
    ) -> None:
        """Create a context bound to `iface`.

        `buffer` is passed to `WireProtocol.read_message` when reading, and is reused
        for encoding outgoing messages that fit into it. `session_id` (possibly None)
        is attached to every message written through this context.
        """
        self.iface = iface
        self.buffer = buffer
        # Must be an assignment (not an annotation): read()/write() use this attribute.
        self.session_id = session_id

    def _session_id_str(self) -> str:
        """Render the session id for debug logs: hex digits, or "-" when unset."""
        sid = self.session_id
        if sid is None:
            return "-"
        return "".join("%02x" % b for b in sid)

    def read_from_wire(self) -> Awaitable[Message]:
        """Read a whole message from the wire without parsing it."""
        return WireProtocol.read_message(self.iface, self.buffer)

    if TYPE_CHECKING:

        @overload
        async def read(
            self, expected_types: Container[int]
        ) -> protobuf.MessageType: ...

        @overload
        async def read(
            self, expected_types: Container[int], expected_type: type[LoadedMessageType]
        ) -> LoadedMessageType: ...

    async def read(
        self,
        expected_types: Container[int],
        expected_type: type[protobuf.MessageType] | None = None,
    ) -> protobuf.MessageType:
        """Read a message from the wire.

        The read message must be of one of the types specified in `expected_types`.
        If only a single type is expected, it can be passed as `expected_type`,
        to save on having to decode the type code into a protobuf class.

        Raises `UnexpectedMessage` (carrying the raw message) if the received type
        is not in `expected_types`.
        """
        if __debug__:
            log.debug(
                __name__,
                "%s:%s expect: %s",
                self.iface.iface_num(),
                self._session_id_str(),
                expected_type.MESSAGE_NAME if expected_type else expected_types,
            )

        # Load the full message into a buffer, parse out type and data payload
        msg = await self.read_from_wire()

        # If we got a message with unexpected type, raise the message via
        # `UnexpectedMessage` and let the session handler deal with it.
        if msg.type not in expected_types:
            raise UnexpectedMessage(msg)

        # TODO check that the message has the expected session_id. If not, raise UnexpectedMessage
        # (and maybe update ctx.session_id - depends on expected behaviour)

        if expected_type is None:
            expected_type = protobuf.type_for_wire(msg.type)

        if __debug__:
            log.debug(
                __name__,
                "%s:%s read: %s",
                self.iface.iface_num(),
                self._session_id_str(),
                expected_type.MESSAGE_NAME,
            )

        # look up the protobuf class and parse the message
        from . import wrap_protobuf_load

        return wrap_protobuf_load(msg.data, expected_type)

    async def write(self, msg: protobuf.MessageType) -> None:
        """Write a message to the wire.

        The message is encoded into the preallocated buffer when it fits, otherwise
        into a freshly allocated one, and sent tagged with this context's session id.
        """
        if __debug__:
            log.debug(
                __name__,
                "%s:%s write: %s",
                self.iface.iface_num(),
                self._session_id_str(),
                msg.MESSAGE_NAME,
            )

        # cannot write message without wire type
        assert msg.MESSAGE_WIRE_TYPE is not None

        msg_size = protobuf.encoded_length(msg)

        if msg_size <= len(self.buffer):
            # reuse preallocated
            buffer = self.buffer
        else:
            # message is too big, we need to allocate a new buffer
            buffer = bytearray(msg_size)

        # encode() returns the number of bytes actually written
        msg_size = protobuf.encode(buffer, msg)

        await WireProtocol.write_message(
            self.iface,
            Message(
                message_type=msg.MESSAGE_WIRE_TYPE,
                message_data=memoryview(buffer)[:msg_size],
                session_id=self.session_id,
            ),
        )
|
|
|
|
|
|
# Context installed for the workflow step currently running, or None when there is
# none. `with_context` sets it before resuming its workflow and resets it to None
# afterwards; the module-level helpers below read it.
CURRENT_CONTEXT: Context | None = None
|
|
|
|
|
|
def wait(task: Awaitable[T]) -> Awaitable[T]:
    """
    Await `task` while keeping the wire context serviced, returning its result.

    Keeps the device responsive on USB during user interaction: if a message arrives
    before `task` completes, an `UnexpectedMessage` exception is raised and control
    goes back to the session handler. Without a context, `task` is returned as-is.
    """
    if CURRENT_CONTEXT is not None:
        # An empty set of expected types makes any incoming message "unexpected",
        # so the read side raises UnexpectedMessage as soon as the host sends anything.
        return loop.race(CURRENT_CONTEXT.read(()), task)
    return task
|
|
|
|
|
|
async def call(
    msg: protobuf.MessageType,
    expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
    """Send `msg` to the host and return its reply, which must be of `expected_type`.

    Raises RuntimeError when no wire context is installed."""
    if CURRENT_CONTEXT is None:
        raise RuntimeError("No wire context")

    wire_type = expected_type.MESSAGE_WIRE_TYPE
    assert wire_type is not None

    await CURRENT_CONTEXT.write(msg)
    # The request is no longer needed; drop the reference before waiting for the reply.
    del msg
    return await CURRENT_CONTEXT.read((wire_type,), expected_type)
|
|
|
|
|
|
async def call_any(
    msg: protobuf.MessageType, *expected_wire_types: int
) -> protobuf.MessageType:
    """Send `msg` to the host and return its reply.

    The reply may be of any wire type listed in `expected_wire_types`.

    Raises RuntimeError when no wire context is installed."""
    if CURRENT_CONTEXT is None:
        raise RuntimeError("No wire context")

    await CURRENT_CONTEXT.write(msg)
    # The request is no longer needed; drop the reference before waiting for the reply.
    del msg
    reply = await CURRENT_CONTEXT.read(expected_wire_types)
    return reply
|
|
|
|
|
|
async def maybe_call(
    msg: protobuf.MessageType, expected_type: type[LoadedMessageType]
) -> None:
    """Send `msg` to the host and discard the reply, if a context exists.

    With a context installed, this behaves like `call()` (so the reply type is still
    checked against `expected_type`) but throws the reply away. Without a context it
    does nothing.
    """
    if CURRENT_CONTEXT is not None:
        await call(msg, expected_type)
|
|
|
|
|
|
def get_context() -> Context:
    """Return the currently installed session context.

    Useful when the caller needs raw read and write access that the module-level
    helpers do not provide.

    Do not keep the returned object around: the context is technically allowed to
    change in between any `await` statements.

    Raises RuntimeError when no wire context is installed.
    """
    ctx = CURRENT_CONTEXT
    if ctx is not None:
        return ctx
    raise RuntimeError("No wire context")
|
|
|
|
|
|
def with_context(ctx: Context, workflow: loop.Task) -> Generator:
    """Run a workflow in a particular context.

    Keeps the context in this generator's frame and installs it into the global
    variable every time the generator is resumed, thus making sure that all calls to
    `wire.context.*` will work as expected.

    Acts as a trampoline around `workflow`: each value `workflow` yields is yielded
    upward, values sent into this generator are forwarded with `send()`, and
    exceptions thrown into it are forwarded with `throw()`. When `workflow` finishes,
    its return value becomes this generator's return value. Any other exception raised
    by `workflow` propagates out of this generator.
    """
    global CURRENT_CONTEXT
    send_val = None  # value to forward into `workflow` on the next step
    send_exc = None  # exception to throw into `workflow` on the next step, if any

    while True:
        # Install our context only for the duration of one step of `workflow`.
        CURRENT_CONTEXT = ctx
        try:
            if send_exc is not None:
                res = workflow.throw(send_exc)
            else:
                res = workflow.send(send_val)
        except StopIteration as st:
            # `workflow` finished; pass its return value through.
            return st.value
        finally:
            # Uninstall on every exit path (yield, return, or exception) so code
            # resumed after this step does not see our context.
            CURRENT_CONTEXT = None

        try:
            send_val = yield res
        except BaseException as e:
            # Catch everything thrown in (including GeneratorExit) and forward it into
            # `workflow` on the next loop iteration rather than handling it here.
            send_exc = e
        else:
            # Normal resume: clear any previously forwarded exception.
            send_exc = None
|