You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
465 lines
16 KiB
465 lines
16 KiB
"""
|
|
# Wire
|
|
|
|
Handles on-the-wire communication with a host computer. The communication is:
|
|
|
|
- Request / response.
|
|
- Protobuf-encoded, see `protobuf.py`.
|
|
- Wrapped in a simple envelope format, see `trezor/wire/codec_v1.py`.
|
|
- Transferred over USB interface, or UDP in case of Unix emulation.
|
|
|
|
This module:
|
|
|
|
1. Provides an API for registering messages. In other words, it binds which function
   is invoked when a particular message is received. See the `add` function.
2. Runs workflows, also called `handlers`, to process the message.
3. Creates and passes the `Context` object to the handlers. This provides an interface to
   wait, read, write etc. on the wire.
|
|
|
|
## `add` function
|
|
|
|
The `add` function registers which function is invoked when a particular `message_type`
is received. The following example binds the `apps.wallet.get_address` function to
the `GetAddress` message:
|
|
|
|
```python
|
|
wire.add(MessageType.GetAddress, "apps.wallet", "get_address")
|
|
```
|
|
|
|
## Session handler
|
|
|
|
When `wire.setup` is called, the `handle_session` coroutine is scheduled.
`handle_session` waits for messages to be received on a particular interface and
reads each message's header. Once the message type is known, the corresponding
handler is called. This way, `handle_session` goes through all the workflows.
|
|
|
|
"""
|
|
|
|
from storage.cache import InvalidSessionError
|
|
from trezor import log, loop, protobuf, utils, workflow
|
|
from trezor.enums import FailureType
|
|
from trezor.messages import Failure
|
|
from trezor.wire import codec_v1
|
|
from trezor.wire.errors import ActionCancelled, DataError, Error
|
|
|
|
# Import all errors into namespace, so that `wire.Error` is available from
|
|
# other packages.
|
|
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
|
|
|
# Declarations that exist only for static type checking. The `if False:` guard
# means this body is never executed; the names below are referenced from
# annotations elsewhere in this module.
if False:
    from typing import (
        Any,
        Awaitable,
        Callable,
        Container,
        Coroutine,
        Iterable,
        TypeVar,
    )
    from trezorio import WireInterface

    # Any protobuf message class.
    Msg = TypeVar("Msg", bound=protobuf.MessageType)
    # The coroutine obtained by calling a handler; awaiting it yields the
    # response message that `_handle_single_message` writes back.
    HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
    # A message handler: called as `handler(ctx, request_message)`.
    Handler = Callable[["Context", Msg], HandlerTask]


# If set to False protobuf messages marked with "unstable" option are rejected.
# The flag is forwarded to `protobuf.decode` by `_wrap_protobuf_load`.
experimental_enabled: bool = False
|
|
|
|
|
|
def setup(iface: WireInterface, is_debug_session: bool = False) -> None:
    """Initialize the wire stack on passed USB interface.

    Creates a `handle_session` coroutine for `iface` using the codec's
    `SESSION_ID` and hands it to the event loop scheduler.
    """
    session = handle_session(
        iface, codec_v1.SESSION_ID, is_debug_session=is_debug_session
    )
    loop.schedule(session)
|
|
|
|
|
|
# More type-checking-only declarations; never executed because of `if False:`.
if False:
    from typing import Protocol, TypeVar

    # Lets `read`/`call`/`_wrap_protobuf_load` be typed as returning the same
    # message class that the caller passed in as `expected_type`.
    LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)

    # Structural interface for "some wire context". Both `Context` and
    # `DummyContext` in this module provide call/read/write/wait, so code that
    # works with either one can presumably be annotated with this protocol.
    class GenericContext(Protocol):
        # Send `msg`, then read and return a reply of `expected_type`.
        async def call(
            self,
            msg: protobuf.MessageType,
            expected_type: type[protobuf.MessageType],
        ) -> Any:
            ...

        # Read and return a message of `expected_type`.
        async def read(self, expected_type: type[protobuf.MessageType]) -> Any:
            ...

        # Send `msg`.
        async def write(self, msg: protobuf.MessageType) -> None:
            ...

        # Run `tasks` and return the result of the one that finishes.
        # XXX modify type signature so that the return value must be of the same type?
        async def wait(self, *tasks: Awaitable) -> Any:
            ...
|
|
|
|
|
|
def _wrap_protobuf_load(
    buffer: bytes,
    expected_type: type[LoadedMessageType],
) -> LoadedMessageType:
    """Decode `buffer` as a message of class `expected_type`.

    Decoding honours the module-level `experimental_enabled` flag. In debug
    emulator builds the decoded message contents are logged.

    Raises:
        DataError: if decoding fails for any reason. The arguments of the
            underlying exception, if any, are appended to the error message.
    """
    try:
        msg = protobuf.decode(buffer, expected_type, experimental_enabled)
        if __debug__ and utils.EMULATOR:
            log.debug(
                __name__, "received message contents:\n%s", utils.dump_protobuf(msg)
            )
        return msg
    except Exception as e:
        if __debug__:
            log.exception(__name__, e)
        if e.args:
            # Exception arguments are not guaranteed to be strings (e.g. an
            # errno int). Joining them unconverted would raise TypeError from
            # inside this handler and hide the intended DataError.
            raise DataError(
                "Failed to decode message: " + " ".join(str(arg) for arg in e.args)
            )
        else:
            raise DataError("Failed to decode message")
|
|
|
|
|
|
class DummyContext:
    """Context-like object whose I/O methods do nothing.

    `call`, `read` and `write` accept any positional arguments, ignore them
    and resolve to None. `wait` still races the given tasks via `loop.race`.
    """

    async def call(self, *argv: Any) -> None:
        return None

    async def read(self, *argv: Any) -> None:
        return None

    async def write(self, *argv: Any) -> None:
        return None

    async def wait(self, *tasks: Awaitable) -> Any:
        winner = await loop.race(*tasks)
        return winner


# Shared module-level instance of the no-op context.
DUMMY_CONTEXT = DummyContext()
|
|
|
|
# Size, in bytes, of the preallocated protobuf (de)serialization buffer.
PROTOBUF_BUFFER_SIZE = 8 * 1024

# Preallocated buffer for wire I/O. `Context.write` falls back to a fresh
# allocation only for messages that do not fit into it.
WIRE_BUFFER = bytearray(PROTOBUF_BUFFER_SIZE)
|
|
|
|
|
|
class Context:
    """Wire context bound to one interface and session id.

    Handlers use it to exchange protobuf messages with the host: `write`
    sends, `read` / `read_any` receive and decode, `call` / `call_any` do
    both, and `wait` runs tasks while watching the wire for unexpected
    messages.
    """

    def __init__(self, iface: WireInterface, sid: int, buffer: bytearray) -> None:
        # Interface the messages travel over.
        self.iface = iface
        # Session id; within this class it only appears in debug log lines.
        self.sid = sid
        # Preallocated buffer used for reading and for encoding outgoing messages.
        self.buffer = buffer

    async def call(
        self,
        msg: protobuf.MessageType,
        expected_type: type[LoadedMessageType],
    ) -> LoadedMessageType:
        """Send `msg` and return the reply decoded as `expected_type`."""
        await self.write(msg)
        # Drop the local reference to the request before waiting for the
        # reply (presumably so it can be garbage-collected meanwhile).
        del msg
        return await self.read(expected_type)

    async def call_any(
        self, msg: protobuf.MessageType, *expected_wire_types: int
    ) -> protobuf.MessageType:
        """Send `msg` and return a reply of any of `expected_wire_types`."""
        await self.write(msg)
        del msg
        return await self.read_any(expected_wire_types)

    def read_from_wire(self) -> Awaitable[codec_v1.Message]:
        """Return the codec awaitable that reads the next raw message,
        using `self.buffer`."""
        return codec_v1.read_message(self.iface, self.buffer)

    async def read(self, expected_type: type[LoadedMessageType]) -> LoadedMessageType:
        """Read the next message and decode it as `expected_type`.

        Raises `UnexpectedMessageError` (carrying the raw message) when the
        incoming wire type differs, leaving it to the session handler.
        """
        if __debug__:
            log.debug(
                __name__,
                "%s:%x expect: %s",
                self.iface.iface_num(),
                self.sid,
                expected_type.MESSAGE_NAME,
            )

        raw = await self.read_from_wire()
        if raw.type != expected_type.MESSAGE_WIRE_TYPE:
            # Do not decode a message we did not ask for; hand it upwards.
            raise UnexpectedMessageError(raw)

        if __debug__:
            log.debug(
                __name__,
                "%s:%x read: %s",
                self.iface.iface_num(),
                self.sid,
                expected_type.MESSAGE_NAME,
            )

        workflow.idle_timer.touch()
        return _wrap_protobuf_load(raw.data, expected_type)

    async def read_any(
        self, expected_wire_types: Iterable[int]
    ) -> protobuf.MessageType:
        """Read the next message, accepting any of `expected_wire_types`.

        Raises `UnexpectedMessageError` (carrying the raw message) when the
        incoming wire type is not among them.
        """
        if __debug__:
            log.debug(
                __name__,
                "%s:%x expect: %s",
                self.iface.iface_num(),
                self.sid,
                expected_wire_types,
            )

        raw = await self.read_from_wire()
        if raw.type not in expected_wire_types:
            raise UnexpectedMessageError(raw)

        # Resolve the protobuf class from the wire type number.
        msg_class = protobuf.type_for_wire(raw.type)

        if __debug__:
            log.debug(
                __name__,
                "%s:%x read: %s",
                self.iface.iface_num(),
                self.sid,
                msg_class.MESSAGE_NAME,
            )

        workflow.idle_timer.touch()
        return _wrap_protobuf_load(raw.data, msg_class)

    async def write(self, msg: protobuf.MessageType) -> None:
        """Encode `msg` and send it over the wire."""
        if __debug__:
            log.debug(
                __name__,
                "%s:%x write: %s",
                self.iface.iface_num(),
                self.sid,
                msg.MESSAGE_NAME,
            )

        # cannot write message without wire type
        assert msg.MESSAGE_WIRE_TYPE is not None

        needed = protobuf.encoded_length(msg)
        # Reuse the preallocated buffer when the message fits; otherwise
        # allocate one exactly large enough.
        out = self.buffer if needed <= len(self.buffer) else bytearray(needed)
        written = protobuf.encode(out, msg)

        await codec_v1.write_message(
            self.iface,
            msg.MESSAGE_WIRE_TYPE,
            memoryview(out)[:written],
        )

    def wait(self, *tasks: Awaitable) -> Any:
        """
        Wait until one of the passed tasks finishes and return the result,
        while servicing the wire context. If a message arrives on the wire
        before any of the tasks finishes, `UnexpectedMessageError` is raised.
        """
        # Reading with no accepted wire types turns any incoming message into
        # an UnexpectedMessageError.
        wire_watch = self.read_any(())
        return loop.race(wire_watch, *tasks)
|
|
|
|
|
|
class UnexpectedMessageError(Exception):
    """Raised when a message other than the awaited one arrives on the wire.

    The raw, undecoded message is kept in `msg` so that the session handler
    can process it instead of losing it.
    """

    def __init__(self, msg: codec_v1.Message) -> None:
        # Raw message that interrupted the read.
        self.msg = msg
        super().__init__()
|
|
|
|
|
|
async def _handle_single_message(
    ctx: Context, msg: codec_v1.Message, use_workflow: bool
) -> codec_v1.Message | None:
    """Handle a message that was loaded from the wire by the caller.

    Find the appropriate handler, run it and write its result on the wire. In case
    a problem is encountered at any point, write the appropriate error on the wire.

    If the workflow finished normally or with an error, the return value is None.

    If an unexpected message arrived on the wire while the workflow was processing,
    the workflow is shut down with an `UnexpectedMessageError`. This is not considered
    an "error condition" to return over the wire -- instead the message is processed
    as if starting a new workflow.
    In such case, the `UnexpectedMessageError` is caught and the message is returned
    to the caller. It will then be processed in the next iteration of the message loop.

    Args:
        ctx: context the message arrived on; replies are written to it.
        msg: raw message (wire type + undecoded data).
        use_workflow: if True, run the handler via `workflow.spawn`; if False
            (the debug session), await the handler task directly.

    Returns:
        The raw message that interrupted the handler, or None.
    """
    if __debug__:
        # Resolve a readable name for logging only; unknown types are tolerated.
        try:
            msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
        except Exception:
            msg_type = f"{msg.type} - unknown message type"
        log.debug(
            __name__,
            "%s:%x receive: <%s>",
            ctx.iface.iface_num(),
            ctx.sid,
            msg_type,
        )

    res_msg: protobuf.MessageType | None = None

    # We need to find a handler for this message type. Should not raise.
    # (It is outside the try block below, so if it did raise, the exception
    # would propagate to the caller.)
    handler = find_handler(ctx.iface, msg.type)  # pylint: disable=assignment-from-none

    if handler is None:
        # If no handler is found, we can skip decoding and directly
        # respond with failure.
        await ctx.write(unexpected_message())
        return None

    # Here we make sure we always respond with a Failure response
    # in case of any errors.
    try:
        # Find a protobuf.MessageType subclass that describes this
        # message. Raises if the type is not found.
        req_type = protobuf.type_for_wire(msg.type)

        # Try to decode the message according to schema from
        # `req_type`. Raises if the message is malformed.
        req_msg = _wrap_protobuf_load(msg.data, req_type)

        # Create the handler task.
        task = handler(ctx, req_msg)

        # Run the workflow task. Workflow can do more on-the-wire
        # communication inside, but it should eventually return a
        # response message, or raise an exception (a rather common
        # thing to do). Exceptions are handled in the code below.
        if use_workflow:
            # Spawn a workflow around the task. This ensures that concurrent
            # workflows are shut down.
            res_msg = await workflow.spawn(task)
        else:
            # For debug messages, ignore workflow processing and just await
            # results of the handler.
            res_msg = await task

    except UnexpectedMessageError as exc:
        # Workflow was trying to read a message from the wire, and
        # something unexpected came in. See Context.read() for
        # example, which expects some particular message and raises
        # UnexpectedMessageError if another one comes in.
        # In order not to lose the message, we return it to the caller.
        # TODO:
        # We might handle only the few common cases here, like
        # Initialize and Cancel.
        return exc.msg

    except BaseException as exc:
        # Either:
        # - the message had a type that has a registered handler, but does not have
        #   a protobuf class
        # - the message was not valid protobuf
        # - workflow raised some kind of an exception while running
        # - something canceled the workflow from the outside
        # BaseException (not Exception) is caught so that cancellation of the
        # workflow is also converted into a Failure reply by `failure()`.
        if __debug__:
            if isinstance(exc, ActionCancelled):
                log.debug(__name__, "cancelled: %s", exc.message)
            elif isinstance(exc, loop.TaskClosed):
                log.debug(__name__, "cancelled: loop task was closed")
            else:
                log.exception(__name__, exc)
        res_msg = failure(exc)

    if res_msg is not None:
        # perform the write outside the big try-except block, so that usb write
        # problem bubbles up
        await ctx.write(res_msg)
    return None
|
|
|
|
|
|
async def handle_session(
    iface: WireInterface, session_id: int, is_debug_session: bool = False
) -> None:
    """Serve wire messages arriving on `iface` for one session.

    Messages are dispatched one at a time via `_handle_single_message`. If that
    returns a message (one that interrupted a running workflow), it is
    processed in the next iteration instead of reading a new one from the wire.

    When there is no pending message and the type of the last handled message
    is not in `AVOID_RESTARTING_FOR`, the event loop is cleared and this
    coroutine returns, so the session is restarted from `main`.
    """
    if __debug__ and is_debug_session:
        # The debug session runs alongside the main session: its handlers are
        # not spawned as workflows, so they do not shut the main one down.
        # Sharing WIRE_BUFFER would let one session read into the buffer while
        # the other is suspended in `Context.write` sending data straight out
        # of it (or is still working with a message it read into it). Give
        # the debug session a buffer of its own. It is allocated per session
        # start, which only ever happens in debug builds.
        ctx_buffer = bytearray(PROTOBUF_BUFFER_SIZE)
    else:
        ctx_buffer = WIRE_BUFFER
    ctx = Context(iface, session_id, ctx_buffer)
    next_msg: codec_v1.Message | None = None

    if __debug__ and is_debug_session:
        import apps.debug

        apps.debug.DEBUG_CONTEXT = ctx

    # Take a mark of modules that are imported at this point, so we can
    # roll back and un-import any others.
    modules = utils.unimport_begin()
    while True:
        try:
            if next_msg is None:
                # If the previous run did not keep an unprocessed message for us,
                # wait for a new one coming from the wire.
                try:
                    msg = await ctx.read_from_wire()
                except codec_v1.CodecError as exc:
                    if __debug__:
                        log.exception(__name__, exc)
                    await ctx.write(failure(exc))
                    continue

            else:
                # Process the message from previous run.
                msg = next_msg
                next_msg = None

            try:
                next_msg = await _handle_single_message(
                    ctx, msg, use_workflow=not is_debug_session
                )
            except Exception as exc:
                # Log and ignore. The session handler can only exit explicitly in the
                # following finally block.
                if __debug__:
                    log.exception(__name__, exc)
            finally:
                if not __debug__ or not is_debug_session:
                    # Unload modules imported by the workflow. Should not raise.
                    # This is not done for the debug session because the snapshot taken
                    # in a debug session would clear modules which are in use by the
                    # workflow running on wire.
                    utils.unimport_end(modules)

                if next_msg is None and msg.type not in AVOID_RESTARTING_FOR:
                    # Shut down the loop if there is no next message waiting.
                    # Let the session be restarted from `main`.
                    loop.clear()
                    return  # pylint: disable=lost-exception

        except Exception as exc:
            # Log and try again. The session handler can only exit explicitly via
            # loop.clear() above.
            if __debug__:
                log.exception(__name__, exc)
|
|
|
|
|
|
def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None:
    """Placeholder handler lookup before a proper one is registered.

    Knows no handlers, so every `(iface, msg_type)` lookup yields None.
    """
    # No handlers are known yet.
    return None


# Handler lookup used by `_handle_single_message`. Starts as the placeholder
# above; expected to be replaced once the real handlers are registered.
find_handler = _find_handler_placeholder

# Message types after which `handle_session` does not clear the event loop
# (i.e. does not trigger a session restart). Empty by default.
AVOID_RESTARTING_FOR: Container[int] = tuple()
|
|
|
|
|
|
def failure(exc: BaseException) -> Failure:
    """Translate an exception into the `Failure` message sent to the host.

    Wire `Error`s keep their own code and message, a closed loop task is
    reported as cancelled, an `InvalidSessionError` as an invalid session,
    and anything else as a generic firmware error.
    """
    if isinstance(exc, Error):
        return Failure(code=exc.code, message=exc.message)
    if isinstance(exc, loop.TaskClosed):
        return Failure(code=FailureType.ActionCancelled, message="Cancelled")
    if isinstance(exc, InvalidSessionError):
        return Failure(code=FailureType.InvalidSession, message="Invalid session")
    # Any other exception is reported generically; its details are not sent.
    return Failure(code=FailureType.FirmwareError, message="Firmware error")
|
|
|
|
|
|
def unexpected_message() -> Failure:
    """Return a generic "Unexpected message" `Failure`.

    `_handle_single_message` sends it when no handler is found for a message.
    """
    code = FailureType.UnexpectedMessage
    return Failure(code=code, message="Unexpected message")
|