parent
fe9167ffa2
commit
0927bbeb68
@ -0,0 +1,195 @@
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_common import InvalidSessionError
|
||||
from trezor import log, loop, protobuf, utils, workflow
|
||||
from trezor.enums import FailureType
|
||||
from trezor.messages import Failure
|
||||
from trezor.wire import context, protocol_common
|
||||
from trezor.wire.errors import ActionCancelled, DataError, Error
|
||||
|
||||
# Import all errors into namespace, so that `wire.Error` is available from
|
||||
# other packages.
|
||||
from trezor.wire.errors import * # isort:skip # noqa: F401,F403
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
from typing import Any, Callable, Container, Coroutine, TypeVar
|
||||
|
||||
Msg = TypeVar("Msg", bound=protobuf.MessageType)
|
||||
HandlerTask = Coroutine[Any, Any, protobuf.MessageType]
|
||||
Handler = Callable[[Msg], HandlerTask]
|
||||
LoadedMessageType = TypeVar("LoadedMessageType", bound=protobuf.MessageType)
|
||||
|
||||
|
||||
# If set to False protobuf messages marked with "experimental_message" option are rejected.
|
||||
EXPERIMENTAL_ENABLED = False
|
||||
|
||||
|
||||
def wrap_protobuf_load(
|
||||
buffer: bytes,
|
||||
expected_type: type[LoadedMessageType],
|
||||
) -> LoadedMessageType:
|
||||
try:
|
||||
msg = protobuf.decode(buffer, expected_type, EXPERIMENTAL_ENABLED)
|
||||
if __debug__ and utils.EMULATOR:
|
||||
log.debug(
|
||||
__name__, "received message contents:\n%s", utils.dump_protobuf(msg)
|
||||
)
|
||||
return msg
|
||||
except Exception as e:
|
||||
if __debug__:
|
||||
log.exception(__name__, e)
|
||||
if e.args:
|
||||
raise DataError("Failed to decode message: " + " ".join(e.args))
|
||||
else:
|
||||
raise DataError("Failed to decode message")
|
||||
|
||||
|
||||
_PROTOBUF_BUFFER_SIZE = const(8192)
|
||||
|
||||
WIRE_BUFFER = bytearray(_PROTOBUF_BUFFER_SIZE)
|
||||
|
||||
if __debug__:
|
||||
PROTOBUF_BUFFER_SIZE_DEBUG = 1024
|
||||
WIRE_BUFFER_DEBUG = bytearray(PROTOBUF_BUFFER_SIZE_DEBUG)
|
||||
|
||||
|
||||
async def handle_single_message(
|
||||
ctx: context.CodecContext, msg: protocol_common.MessageWithId, use_workflow: bool
|
||||
) -> protocol_common.MessageWithId | None:
|
||||
"""Handle a message that was loaded from USB by the caller.
|
||||
|
||||
Find the appropriate handler, run it and write its result on the wire. In case
|
||||
a problem is encountered at any point, write the appropriate error on the wire.
|
||||
|
||||
If the workflow finished normally or with an error, the return value is None.
|
||||
|
||||
If an unexpected message had arrived on the wire while the workflow was processing,
|
||||
the workflow is shut down with an `UnexpectedMessage` exception. This is not
|
||||
considered an "error condition" to return over the wire -- instead the message
|
||||
is processed as if starting a new workflow.
|
||||
In such case, the `UnexpectedMessage` is caught and the message is returned
|
||||
to the caller. It will then be processed in the next iteration of the message loop.
|
||||
"""
|
||||
if __debug__:
|
||||
try:
|
||||
msg_type = protobuf.type_for_wire(msg.type).MESSAGE_NAME
|
||||
except Exception:
|
||||
msg_type = f"{msg.type} - unknown message type"
|
||||
if ctx.channel_id is not None:
|
||||
sid = int.from_bytes(ctx.channel_id, "big")
|
||||
else:
|
||||
sid = -1
|
||||
log.debug(
|
||||
__name__,
|
||||
"%s:%x receive: <%s>",
|
||||
ctx.iface.iface_num(),
|
||||
sid,
|
||||
msg_type,
|
||||
)
|
||||
|
||||
res_msg: protobuf.MessageType | None = None
|
||||
|
||||
# We need to find a handler for this message type. Should not raise.
|
||||
handler = find_handler(ctx.iface, msg.type) # pylint: disable=assignment-from-none
|
||||
|
||||
if handler is None:
|
||||
# If no handler is found, we can skip decoding and directly
|
||||
# respond with failure.
|
||||
await ctx.write(unexpected_message())
|
||||
return None
|
||||
|
||||
if msg.type in workflow.ALLOW_WHILE_LOCKED:
|
||||
workflow.autolock_interrupts_workflow = False
|
||||
|
||||
# Here we make sure we always respond with a Failure response
|
||||
# in case of any errors.
|
||||
try:
|
||||
# Find a protobuf.MessageType subclass that describes this
|
||||
# message. Raises if the type is not found.
|
||||
req_type = protobuf.type_for_wire(msg.type)
|
||||
|
||||
# Try to decode the message according to schema from
|
||||
# `req_type`. Raises if the message is malformed.
|
||||
req_msg = wrap_protobuf_load(msg.data, req_type)
|
||||
|
||||
# Create the handler task.
|
||||
task = handler(req_msg)
|
||||
|
||||
# Run the workflow task. Workflow can do more on-the-wire
|
||||
# communication inside, but it should eventually return a
|
||||
# response message, or raise an exception (a rather common
|
||||
# thing to do). Exceptions are handled in the code below.
|
||||
if use_workflow:
|
||||
# Spawn a workflow around the task. This ensures that concurrent
|
||||
# workflows are shut down.
|
||||
res_msg = await workflow.spawn(context.with_context(ctx, task))
|
||||
else:
|
||||
# For debug messages, ignore workflow processing and just await
|
||||
# results of the handler.
|
||||
res_msg = await task
|
||||
|
||||
except context.UnexpectedMessage as exc:
|
||||
# Workflow was trying to read a message from the wire, and
|
||||
# something unexpected came in. See Context.read() for
|
||||
# example, which expects some particular message and raises
|
||||
# UnexpectedMessage if another one comes in.
|
||||
# In order not to lose the message, we return it to the caller.
|
||||
# TODO:
|
||||
# We might handle only the few common cases here, like
|
||||
# Initialize and Cancel.
|
||||
return exc.msg
|
||||
|
||||
except BaseException as exc:
|
||||
# Either:
|
||||
# - the message had a type that has a registered handler, but does not have
|
||||
# a protobuf class
|
||||
# - the message was not valid protobuf
|
||||
# - workflow raised some kind of an exception while running
|
||||
# - something canceled the workflow from the outside
|
||||
if __debug__:
|
||||
if isinstance(exc, ActionCancelled):
|
||||
log.debug(__name__, "cancelled: %s", exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
log.debug(__name__, "cancelled: loop task was closed")
|
||||
else:
|
||||
log.exception(__name__, exc)
|
||||
res_msg = failure(exc)
|
||||
|
||||
if res_msg is not None:
|
||||
# perform the write outside the big try-except block, so that usb write
|
||||
# problem bubbles up
|
||||
await ctx.write(res_msg)
|
||||
return None
|
||||
|
||||
|
||||
def _find_handler_placeholder(iface: WireInterface, msg_type: int) -> Handler | None:
|
||||
"""Placeholder handler lookup before a proper one is registered."""
|
||||
return None
|
||||
|
||||
|
||||
find_handler = _find_handler_placeholder
|
||||
AVOID_RESTARTING_FOR: Container[int] = ()
|
||||
|
||||
|
||||
def failure(exc: BaseException) -> Failure:
|
||||
if isinstance(exc, Error):
|
||||
return Failure(code=exc.code, message=exc.message)
|
||||
elif isinstance(exc, loop.TaskClosed):
|
||||
return Failure(code=FailureType.ActionCancelled, message="Cancelled")
|
||||
elif isinstance(exc, InvalidSessionError):
|
||||
return Failure(code=FailureType.InvalidSession, message="Invalid session")
|
||||
else:
|
||||
# NOTE: when receiving generic `FirmwareError` on non-debug build,
|
||||
# change the `if __debug__` to `if True` to get the full error message.
|
||||
if __debug__:
|
||||
message = str(exc)
|
||||
else:
|
||||
message = "Firmware error"
|
||||
return Failure(code=FailureType.FirmwareError, message=message)
|
||||
|
||||
|
||||
def unexpected_message() -> Failure:
|
||||
return Failure(code=FailureType.UnexpectedMessage, message="Unexpected message")
|
@ -0,0 +1,106 @@
|
||||
import ustruct
|
||||
from micropython import const
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from storage.cache_thp import SessionThpCache
|
||||
from trezor import loop, protobuf, utils
|
||||
|
||||
from ..protocol_common import Context
|
||||
from . import thp_session
|
||||
from .thp_messages import CONTINUATION_PACKET, ENCRYPTED_TRANSPORT
|
||||
|
||||
# from .thp_session import SessionState, ThpError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from trezorio import WireInterface
|
||||
|
||||
_INIT_DATA_OFFSET = const(5)
|
||||
_CONT_DATA_OFFSET = const(3)
|
||||
|
||||
|
||||
class ChannelContext(Context):
|
||||
def __init__(
|
||||
self, iface: WireInterface, channel_id: int, session_data: SessionThpCache
|
||||
) -> None:
|
||||
super().__init__(iface, channel_id)
|
||||
self.session_data = session_data
|
||||
self.buffer: utils.BufferType
|
||||
self.waiting_for_ack_timeout: loop.Task | None
|
||||
self.is_cont_packet_expected: bool = False
|
||||
self.expected_payload_length: int = 0
|
||||
self.bytes_read = 0
|
||||
|
||||
# ACCESS TO SESSION_DATA
|
||||
|
||||
def get_management_session_state(self):
|
||||
return thp_session.get_state(self.session_data)
|
||||
|
||||
# CALLED BY THP_MAIN_LOOP
|
||||
|
||||
async def receive_packet(self, packet: utils.BufferType):
|
||||
ctrl_byte = packet[0]
|
||||
if _is_ctrl_byte_continuation(ctrl_byte):
|
||||
await self._handle_cont_packet(packet)
|
||||
else:
|
||||
await self._handle_init_packet(packet)
|
||||
|
||||
async def _handle_init_packet(self, packet):
|
||||
ctrl_byte, _, payload_length = ustruct.unpack(">BHH", packet)
|
||||
packet_payload = packet[5:]
|
||||
|
||||
if _is_ctrl_byte_encrypted_transport(ctrl_byte):
|
||||
packet_payload = self._decode(packet_payload)
|
||||
|
||||
# session_id = packet_payload[0] # TODO handle handshake differently
|
||||
self.expected_payload_length = payload_length
|
||||
self.bytes_read = 0
|
||||
|
||||
await self._buffer_packet_data(self.buffer, packet, _INIT_DATA_OFFSET)
|
||||
# TODO Set/Provide different buffer for management session
|
||||
|
||||
if self.expected_payload_length == self.bytes_read:
|
||||
self._finish_message()
|
||||
else:
|
||||
self.is_cont_packet_expected = True
|
||||
|
||||
async def _handle_cont_packet(self, packet):
|
||||
if not self.is_cont_packet_expected:
|
||||
return # Continuation packet is not expected, ignoring
|
||||
await self._buffer_packet_data(self.buffer, packet, _CONT_DATA_OFFSET)
|
||||
|
||||
def _decode(self, payload) -> bytes:
|
||||
return payload # TODO add decryption process
|
||||
|
||||
async def _buffer_packet_data(
|
||||
self, payload_buffer, packet: utils.BufferType, offset
|
||||
):
|
||||
self.bytes_read += utils.memcpy(payload_buffer, self.bytes_read, packet, offset)
|
||||
|
||||
def _finish_message(self):
|
||||
# TODO Provide loaded message to SessionContext or handle it with this ChannelContext
|
||||
self.bytes_read = 0
|
||||
self.expected_payload_length = 0
|
||||
self.is_cont_packet_expected = False
|
||||
|
||||
# CALLED BY WORKFLOW / SESSION CONTEXT
|
||||
|
||||
async def write(self, msg: protobuf.MessageType, session_id: int = 0) -> None:
|
||||
pass
|
||||
# TODO protocol.write(self.iface, self.channel_id, session_id, msg)
|
||||
|
||||
def create_new_session(
|
||||
self,
|
||||
passphrase="",
|
||||
) -> None: # TODO change it to output session data
|
||||
pass
|
||||
# TODO check, wheter a session with this passphrase already exists
|
||||
# if not, create a new session with this passphrase
|
||||
# if yes, what TODO TODO ???
|
||||
|
||||
|
||||
def _is_ctrl_byte_continuation(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0x80 == CONTINUATION_PACKET
|
||||
|
||||
|
||||
def _is_ctrl_byte_encrypted_transport(ctrl_byte: int) -> bool:
|
||||
return ctrl_byte & 0xEF == ENCRYPTED_TRANSPORT
|
@ -0,0 +1,14 @@
|
||||
from trezor import protobuf
|
||||
|
||||
from ..context import Context
|
||||
from .channel_context import ChannelContext
|
||||
|
||||
|
||||
class SessionContext(Context):
|
||||
def __init__(self, channel_context: ChannelContext, session_id: int) -> None:
|
||||
super().__init__(channel_context.iface, channel_context.channel_id)
|
||||
self.channel_context = channel_context
|
||||
self.session_id = session_id
|
||||
|
||||
async def write(self, msg: protobuf.MessageType) -> None:
|
||||
return await self.channel_context.write(msg, self.session_id)
|
Loading…
Reference in new issue