mirror of
https://github.com/trezor/trezor-firmware.git
synced 2025-01-02 03:20:59 +00:00
feat(core): keep track of current context for workflow
This commit is contained in:
parent
78a8b48f1e
commit
fe80793b47
@ -193,6 +193,8 @@ trezor.wire
|
||||
import trezor.wire
|
||||
trezor.wire.codec_v1
|
||||
import trezor.wire.codec_v1
|
||||
trezor.wire.context
|
||||
import trezor.wire.context
|
||||
trezor.wire.errors
|
||||
import trezor.wire.errors
|
||||
trezor.workflow
|
||||
|
@ -42,7 +42,7 @@ from storage.cache import InvalidSessionError
|
||||
from trezor import log, loop, protobuf, utils, workflow
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import Failure
|
||||
from trezor.wire import codec_v1
|
||||
from trezor.wire import codec_v1, context
|
||||
from trezor.wire.errors import ActionCancelled, DataError, Error
|
||||
|
||||
# Import all errors into namespace, so that `wire.Error` is available from
|
||||
@ -53,43 +53,22 @@ from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
||||
if TYPE_CHECKING:
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Container,
|
||||
Coroutine,
|
||||
Iterable,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
)
|
||||
from trezorio import WireInterface
|
||||
|
||||
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||
Handler = Callable[["Context", Msg], HandlerTask]
|
||||
Handler = Callable[[Msg], HandlerTask]
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
|
||||
class GenericContext(Protocol):
|
||||
async def call(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
expected_type: type[protobuf.MessageType],
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
async def read(self, expected_type: type[protobuf.MessageType]) -> Any:
|
||||
...
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
...
|
||||
|
||||
# XXX modify type signature so that the return value must be of the same type?
|
||||
async def wait(self, *tasks: Awaitable) -> Any:
|
||||
...
|
||||
|
||||
|
||||
# If set to False protobuf messages marked with "experimental_message" option are rejected.
|
||||
experimental_enabled = False
|
||||
EXPERIMENTAL_ENABLED = False
|
||||
|
||||
|
||||
def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
|
||||
@ -97,12 +76,12 @@ def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
|
||||
loop.schedule(handle_session(iface, codec_v1.SESSION_ID, is_debug_session))
|
||||
|
||||
|
||||
def _wrap_protobuf_load(
|
||||
def wrap_protobuf_load(
|
||||
buffer: bytes,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
try:
|
||||
msg = protobuf.decode(buffer, expected_type, experimental_enabled)
|
||||
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
|
||||
if __debug__ and utils.EMULATOR:
|
||||
log.debug(
|
||||
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
|
||||
@ -117,22 +96,6 @@ def _wrap_protobuf_load(
|
||||
raise DataError("Failed to decode message")
|
||||
|
||||
|
||||
class DummyContext:
|
||||
async def call(self, *argv: Any) -> None:
|
||||
pass
|
||||
|
||||
async def read(self, *argv: Any) -> None:
|
||||
pass
|
||||
|
||||
async def write(self, *argv: Any) -> None:
|
||||
pass
|
||||
|
||||
async def wait(self, *tasks: Awaitable) -> Any:
|
||||
return await loop.race(*tasks)
|
||||
|
||||
|
||||
DUMMY_CONTEXT = DummyContext()
|
||||
|
||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||
|
||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
@ -142,143 +105,8 @@ if __debug__:
|
||||
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
|
||||
|
||||
|
||||
class Context:
|
||||
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
|
||||
self.iface = iface
|
||||
self.sid = sid
|
||||
self.buffer = buffer
|
||||
|
||||
async def call(
|
||||
self,
|
||||
msg: protobuf.MessageType,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
await self.write(msg)
|
||||
del msg
|
||||
return await self.read(expected_type)
|
||||
|
||||
async def call_any(
|
||||
self, msg: protobuf.MessageType, *expected_wire_types: int
|
||||
) -> protobuf.MessageType:
|
||||
await self.write(msg)
|
||||
del msg
|
||||
return await self.read_any(expected_wire_types)
|
||||
|
||||
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
|
||||
return codec_v1.read_message(self.iface, self.buffer)
|
||||
|
||||
async def read(self, expected_type: type[LoadedMessageType]) -> LoadedMessageType:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x expect: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
expected_type.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# Load the full message into a buffer, parse out type and data payload
|
||||
msg = await self.read_from_wire()
|
||||
|
||||
# If we got a message with unexpected type, raise the message via
|
||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||
if msg.type != expected_type.MESSAGE_WIRE_TYPE:
|
||||
raise UnexpectedMessageError(msg)
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x read: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
expected_type.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# look up the protobuf class and parse the message
|
||||
return _wrap_protobuf_load(msg.data, expected_type)
|
||||
|
||||
async def read_any(
|
||||
self, expected_wire_types: Iterable[int]
|
||||
) -> protobuf.MessageType:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x expect: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
expected_wire_types,
|
||||
)
|
||||
|
||||
# Load the full message into a buffer, parse out type and data payload
|
||||
msg = await self.read_from_wire()
|
||||
|
||||
# If we got a message with unexpected type, raise the message via
|
||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||
if msg.type not in expected_wire_types:
|
||||
raise UnexpectedMessageError(msg)
|
||||
|
||||
# find the protobuf type
|
||||
exptype = protobuf.type_for_wire(msg.type)
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x read: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
exptype.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# parse the message and return it
|
||||
return _wrap_protobuf_load(msg.data, exptype)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x write: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
msg.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# cannot write message without wire type
|
||||
assert msg.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
|
||||
if msg_size <= len(self.buffer):
|
||||
# reuse preallocated
|
||||
buffer = self.buffer
|
||||
else:
|
||||
# message is too big, we need to allocate a new buffer
|
||||
buffer = bytearray(msg_size)
|
||||
|
||||
msg_size = protobuf.encode(buffer, msg)
|
||||
|
||||
await codec_v1.write_message(
|
||||
self.iface,
|
||||
msg.MESSAGE_WIRE_TYPE,
|
||||
memoryview(buffer)[:msg_size],
|
||||
)
|
||||
|
||||
def wait(self, *tasks: Awaitable) -> Any:
|
||||
"""
|
||||
Wait until one of the passed tasks finishes, and return the result,
|
||||
while servicing the wire context. If a message comes until one of the
|
||||
tasks ends, `UnexpectedMessageError` is raised.
|
||||
"""
|
||||
return loop.race(self.read_any(()), *tasks)
|
||||
|
||||
|
||||
class UnexpectedMessageError(Exception):
|
||||
def __init__(self, msg: codec_v1.Message) -> None:
|
||||
super().__init__()
|
||||
self.msg = msg
|
||||
|
||||
|
||||
async def _handle_single_message(
|
||||
ctx: Context, msg: codec_v1.Message, use_workflow: bool
|
||||
ctx: context.Context, msg: codec_v1.Message, use_workflow: bool
|
||||
) -> codec_v1.Message | None:
|
||||
"""Handle a message that was loaded from USB by the caller.
|
||||
|
||||
@ -288,10 +116,10 @@ async def _handle_single_message(
|
||||
If the workflow finished normally or with an error, the return value is None.
|
||||
|
||||
If an unexpected message had arrived on the wire while the workflow was processing,
|
||||
the workflow is shut down with an `UnexpectedMessageError`. This is not considered
|
||||
an "error condition" to return over the wire -- instead the message is processed
|
||||
as if starting a new workflow.
|
||||
In such case, the `UnexpectedMessageError` is caught and the message is returned
|
||||
the workflow is shut down with an `UnexpectedMessage` exception. This is not
|
||||
considered an "error condition" to return over the wire -- instead the message
|
||||
is processed as if starting a new workflow.
|
||||
In such case, the `UnexpectedMessage` is caught and the message is returned
|
||||
to the caller. It will then be processed in the next iteration of the message loop.
|
||||
"""
|
||||
if __debug__:
|
||||
@ -330,10 +158,10 @@ async def _handle_single_message(
|
||||
|
||||
# Try to decode the message according to schema from
|
||||
# `req_type`. Raises if the message is malformed.
|
||||
req_msg = _wrap_protobuf_load(msg.data, req_type)
|
||||
req_msg = wrap_protobuf_load(msg.data, req_type)
|
||||
|
||||
# Create the handler task.
|
||||
task = handler(ctx, req_msg)
|
||||
task = handler(req_msg)
|
||||
|
||||
# Run the workflow task. Workflow can do more on-the-wire
|
||||
# communication inside, but it should eventually return a
|
||||
@ -342,17 +170,17 @@ async def _handle_single_message(
|
||||
if use_workflow:
|
||||
# Spawn a workflow around the task. This ensures that concurrent
|
||||
# workflows are shut down.
|
||||
res_msg = await workflow.spawn(task)
|
||||
res_msg = await workflow.spawn(context.with_context(ctx, task))
|
||||
else:
|
||||
# For debug messages, ignore workflow processing and just await
|
||||
# results of the handler.
|
||||
res_msg = await task
|
||||
|
||||
except UnexpectedMessageError as exc:
|
||||
except context.UnexpectedMessage as exc:
|
||||
# Workflow was trying to read a message from the wire, and
|
||||
# something unexpected came in. See Context.read() for
|
||||
# example, which expects some particular message and raises
|
||||
# UnexpectedMessageError if another one comes in.
|
||||
# UnexpectedMessage if another one comes in.
|
||||
# In order not to lose the message, we return it to the caller.
|
||||
# TODO:
|
||||
# We might handle only the few common cases here, like
|
||||
@ -390,7 +218,7 @@ async def handle_session(
|
||||
else:
|
||||
ctx_buffer = WIRE_BUFFER
|
||||
|
||||
ctx = Context(iface, session_id, ctx_buffer)
|
||||
ctx = context.Context(iface, session_id, ctx_buffer)
|
||||
next_msg: codec_v1.Message | None = None
|
||||
|
||||
if __debug__ and is_debug_session:
|
||||
|
266
core/src/trezor/wire/context.py
Normal file
266
core/src/trezor/wire/context.py
Normal file
@ -0,0 +1,266 @@
|
||||
"""Context pseudo-global.
|
||||
|
||||
Each workflow handler runs in a "context" which is tied to a particular communication
|
||||
session. When the handler needs to communicate with the host, it needs access to that
|
||||
context.
|
||||
|
||||
To avoid the need to pass a context object around, the context is stored in a
|
||||
pseudo-global manner: any workflow handler can request access to the context via this
|
||||
module, and the appropriate context object will be used for it.
|
||||
|
||||
Some workflows don't need a context to exist. This is supported by the `maybe_call`
|
||||
function, which will silently ignore the call if no context is available. Useful mainly
|
||||
for ButtonRequests. Of course, `context.wait()` transparently works in such situations.
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from trezor import log, loop, protobuf
|
||||
|
||||
from . import codec_v1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Container,
|
||||
Coroutine,
|
||||
Generator,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
from trezorio import WireInterface
|
||||
|
||||
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||
Handler = Callable[["Context", Msg], HandlerTask]
|
||||
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
|
||||
|
||||
class UnexpectedMessage(Exception):
|
||||
"""A message was received that is not part of the current workflow.
|
||||
|
||||
Utility exception to inform the session handler that the current workflow
|
||||
should be aborted and a new one started as if `msg` was the first message.
|
||||
"""
|
||||
|
||||
def __init__(self, msg: codec_v1.Message) -> None:
|
||||
super().__init__()
|
||||
self.msg = msg
|
||||
|
||||
|
||||
class Context:
|
||||
"""Wire context.
|
||||
|
||||
Represents USB communication inside a particular session on a particular interface
|
||||
(i.e., wire, debug, single BT connection, etc.)
|
||||
"""
|
||||
|
||||
def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
|
||||
self.iface = iface
|
||||
self.sid = sid
|
||||
self.buffer = buffer
|
||||
|
||||
def read_from_wire(self) -> Awaitable[codec_v1.Message]:
|
||||
"""Read a whole message from the wire without parsing it."""
|
||||
return codec_v1.read_message(self.iface, self.buffer)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
async def read(self, expected_types: Container[int]) -> protobuf.MessageType:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def read(
|
||||
self, expected_types: Container[int], expected_type: type[LoadedMessageType]
|
||||
) -> LoadedMessageType:
|
||||
...
|
||||
|
||||
async def read(
|
||||
self,
|
||||
expected_types: Container[int],
|
||||
expected_type: type[protobuf.MessageType] | None = None,
|
||||
) -> protobuf.MessageType:
|
||||
"""Read a message from the wire.
|
||||
|
||||
The read message must be of one of the types specified in `expected_types`.
|
||||
If only a single type is expected, it can be passed as `expected_type`,
|
||||
to save on having to decode the type code into a protobuf class.
|
||||
"""
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x expect: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
expected_type.MESSAGE_NAME if expected_type else expected_types,
|
||||
)
|
||||
|
||||
# Load the full message into a buffer, parse out type and data payload
|
||||
msg = await self.read_from_wire()
|
||||
|
||||
# If we got a message with unexpected type, raise the message via
|
||||
# `UnexpectedMessageError` and let the session handler deal with it.
|
||||
if msg.type not in expected_types:
|
||||
raise UnexpectedMessage(msg)
|
||||
|
||||
if expected_type is None:
|
||||
expected_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x read: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
expected_type.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# look up the protobuf class and parse the message
|
||||
from . import wrap_protobuf_load
|
||||
|
||||
return wrap_protobuf_load(msg.data, expected_type)
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
"""Write a message to the wire."""
|
||||
if __debug__:
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x write: %s",
|
||||
self.iface.iface_num(),
|
||||
self.sid,
|
||||
msg.MESSAGE_NAME,
|
||||
)
|
||||
|
||||
# cannot write message without wire type
|
||||
assert msg.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
msg_size = protobuf.encoded_length(msg)
|
||||
|
||||
if msg_size <= len(self.buffer):
|
||||
# reuse preallocated
|
||||
buffer = self.buffer
|
||||
else:
|
||||
# message is too big, we need to allocate a new buffer
|
||||
buffer = bytearray(msg_size)
|
||||
|
||||
msg_size = protobuf.encode(buffer, msg)
|
||||
|
||||
await codec_v1.write_message(
|
||||
self.iface,
|
||||
msg.MESSAGE_WIRE_TYPE,
|
||||
memoryview(buffer)[:msg_size],
|
||||
)
|
||||
|
||||
|
||||
CURRENT_CONTEXT: Context | None = None
|
||||
|
||||
|
||||
def wait(*tasks: Awaitable) -> Any:
|
||||
"""
|
||||
Wait until one of the passed tasks finishes, and return the result, while servicing
|
||||
the wire context.
|
||||
|
||||
Used to make sure the device is responsive on USB while waiting for user
|
||||
interaction. If a message is received before any of the passed in tasks finish, it
|
||||
raises an `UnexpectedMessage` exception, returning control to the session handler.
|
||||
"""
|
||||
if CURRENT_CONTEXT is None:
|
||||
return loop.race(*tasks)
|
||||
else:
|
||||
return loop.race(CURRENT_CONTEXT.read(()), *tasks)
|
||||
|
||||
|
||||
async def call(
|
||||
msg: protobuf.MessageType,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
"""Send a message to the host and wait for a response of a particular type.
|
||||
|
||||
Raises if there is no context for this workflow."""
|
||||
if CURRENT_CONTEXT is None:
|
||||
raise RuntimeError("No wire context")
|
||||
|
||||
assert expected_type.MESSAGE_WIRE_TYPE is not None
|
||||
|
||||
await CURRENT_CONTEXT.write(msg)
|
||||
del msg
|
||||
return await CURRENT_CONTEXT.read((expected_type.MESSAGE_WIRE_TYPE,), expected_type)
|
||||
|
||||
|
||||
async def call_any(
|
||||
msg: protobuf.MessageType, *expected_wire_types: int
|
||||
) -> protobuf.MessageType:
|
||||
"""Send a message to the host and wait for a response.
|
||||
|
||||
The response can be of any of the types specified in `expected_wire_types`.
|
||||
|
||||
Raises if there is no context for this workflow."""
|
||||
if CURRENT_CONTEXT is None:
|
||||
raise RuntimeError("No wire context")
|
||||
|
||||
await CURRENT_CONTEXT.write(msg)
|
||||
del msg
|
||||
return await CURRENT_CONTEXT.read(expected_wire_types)
|
||||
|
||||
|
||||
async def maybe_call(
|
||||
msg: protobuf.MessageType, expected_type: type[LoadedMessageType]
|
||||
) -> None:
|
||||
"""Send a message to the host and read but ignore the response.
|
||||
|
||||
If there is a context, the function still checks that the response is of the
|
||||
requested type. If there is no context, the call is ignored.
|
||||
"""
|
||||
if CURRENT_CONTEXT is None:
|
||||
return
|
||||
|
||||
await call(msg, expected_type)
|
||||
|
||||
|
||||
def get_context() -> Context:
|
||||
"""Get the current session context.
|
||||
|
||||
Can be needed in case the caller needs raw read and raw write capabilities, which
|
||||
are not provided by the module functions.
|
||||
|
||||
Result of this function should not be stored -- the context is technically allowed
|
||||
to change inbetween any `await` statements.
|
||||
"""
|
||||
if CURRENT_CONTEXT is None:
|
||||
raise RuntimeError("No wire context")
|
||||
return CURRENT_CONTEXT
|
||||
|
||||
|
||||
def with_context(ctx: Context, workflow: loop.Task) -> Generator:
|
||||
"""Run a workflow in a particular context.
|
||||
|
||||
Stores the context in a closure and installs it into the global variable every time
|
||||
the closure is resumed, thus making sure that all calls to `wire.context.*` will
|
||||
work as expected.
|
||||
"""
|
||||
global CURRENT_CONTEXT
|
||||
send_val = None
|
||||
send_exc = None
|
||||
|
||||
while True:
|
||||
CURRENT_CONTEXT = ctx
|
||||
try:
|
||||
if send_exc is not None:
|
||||
res = workflow.throw(send_exc)
|
||||
else:
|
||||
res = workflow.send(send_val)
|
||||
except StopIteration as st:
|
||||
return st.value
|
||||
finally:
|
||||
CURRENT_CONTEXT = None
|
||||
|
||||
try:
|
||||
send_val = yield res
|
||||
except BaseException as e:
|
||||
send_exc = e
|
||||
else:
|
||||
send_exc = None
|
Loading…
Reference in New Issue
Block a user